在 keras 中微调预训练模型

2023-12-22

我想在 keras 中使用预训练的 imagenet VGG16 模型,并在顶部添加我自己的小型卷积网络。我只对功能感兴趣,对预测不感兴趣

from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.applications.vgg16 import preprocess_input
import numpy as np
import os
from keras.models import Model
from keras.models import Sequential
from keras.layers import Convolution2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense

从目录加载图像(该目录包含 4 个图像)

IF = '/home/ubu/files/png/'
files = os.listdir(IF)

imgs = [img_to_array(load_img(IF + p, target_size=[224,224])) for p in files]
im = np.array(imgs)

加载基础模型,预处理输入并获取特征

base_model = VGG16(weights='imagenet', include_top=False)

x = preprocess_input(aa)
features = base_model.predict(x)

这有效,我在预训练的 VGG 上获得了图像的特征。

我现在想微调模型并添加一些卷积层。 我读https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html and https://keras.io/applications/ https://keras.io/applications/但无法将它们完全结合在一起。

在顶部添加我的模型:

x = base_model.output
x = Convolution2D(32, 3, 3)(x)
x = Activation('relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Convolution2D(32, 3, 3)(x)
x = Activation('relu')(x)
feat = MaxPooling2D(pool_size=(2, 2))(x)

构建完整的模型

model_complete = Model(input=base_model.input, output=feat)

阻止基础层被学习

for layer in base_model.layers:
layer.trainable = False

新模型

model_complete.compile(optimizer='rmsprop', 
          loss='binary_crossentropy')

现在拟合新模型,模型是 4 张图像,[1,0,1,0] 是类标签。 但这显然是错误的:

model_complete.fit_generator((x, [1,0,1,0]), samples_per_epoch=100, nb_epoch=2)

ValueError: output of generator should be a tuple (x, y, sample_weight) or (x, y). Found: None

这是怎么做到的?

如果我只想替换最后一个卷积块(VGG16 中的 conv block5)而不是添加某些内容,我该怎么做?

我如何只训练瓶颈特征?

特征输出features形状为 (4, 512, 7, 7)。有四个图像,那么另外的维度里有什么呢?我如何将其减少为 (1,x) 数组?


适配型号

生成器代码的问题在于 fit_generator 方法期望生成器函数生成您不提供的拟合数据。 您可以按照链接到的教程中的方式定义生成器,也可以自己创建数据和标签并自行拟合模型:

model_complete.fit(images, labels, batch_size=100, nb_epoch=2)

其中 images 是您生成的训练图像, labels 是相应的标签。

删除最后一层

假设您有一个模型变量和下面描述的“pop”方法,您可以这样做model = pop(model)删除最后一层。

仅训练特定层正如您在代码中所做的那样,您可以执行以下操作:

for layer in base_model.layers:
    layer.trainable = False

然后你可以通过改变它们来“解冻”并分层你想要的trainable财产给True.

改变尺寸

要将输出更改为一维数组,您可以使用压平层 https://keras.io/layers/core/#flatten


流行方法

def pop(model):
    '''Removes a layer instance on top of the layer stack.
    This code is thanks to @joelthchao https://github.com/fchollet/keras/issues/2371#issuecomment-211734276
    '''
    if not model.outputs:
        raise Exception('Sequential model cannot be popped: model is empty.')
    else:
        model.layers.pop()
        if not model.layers:
            model.outputs = []
            model.inbound_nodes = []
            model.outbound_nodes = []
        else:
            model.layers[-1].outbound_nodes = []
            model.outputs = [model.layers[-1].output]
        model.built = False

    return model
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

在 keras 中微调预训练模型 的相关文章

随机推荐

  • 成本较低的哈希算法是什么?

    我对哈希算法了解不多 在将文件转发到远程系统 有点像 S3 之前 我需要计算 Java 中传入文件的哈希值 该系统需要 MD2 MD5 SHA X 中的文件哈希值 计算此哈希值不是出于安全原因 而只是为了一致性校验和 我能够在转发文件时使用
  • 打开本机应用程序。来自野生动物园

    我知道 iPhone 应用程序 像沙箱一样操作 这意味着他们无权访问其他应用程序的文件 我还成功地使用以下命令从 Native App 在 Safari 中打开了一个网站 openURL NSURL URLWithString Websit
  • 禁用 web.config 继承?

    我的网站根目录中有一个内容管理应用程序 我尝试在子文件夹下使用另一个应用程序 计费应用程序 不幸的是 根站点的 web config 干扰了子应用程序 有没有办法只禁用子文件夹的 web config 继承 Update 如链接所示史蒂芬
  • 可以从 UIView 复制 CALayer 吗?

    这是我的设置 我有一个 CALayer 我想向其中添加子层 我通过设置 UILabel 创建这些子层 然后将 UILables 层添加到我的主层 当然 这会使沉重的 UILabel 对象在后台徘徊 是否可以从 UIView 获取图层及其所有
  • 在 OData 中,有没有办法按数组中的第一个元素进行排序?

    我有一个 OData 4 端点 用于在表中显示数据 其中一列包含一组由我的元素内的字符串数组连接而成的数据 有没有办法按数组中的第一个元素排序 我的元素可能如下所示 FirstName John MiddleNames Harry Bobb
  • 如何获取 Windows 8 应用程序的方法名称

    如何在 win 8 WinRT 应用程序中获取当前方法名称 早期在 wp7 中我们可以使用System Reflection MethodBase GetCurrentMethod Name但它不再存在了 谢谢 是的 NETCore 缺少很
  • 如何从命令行使用 GIMP 将 XCF 转换为 PNG?

    作为构建过程的一部分 我需要将许多 XCF GIMP 的本机格式 图像转换为 PNG 格式 我确信使用 GIMP 的批处理模式应该可以实现这一点 但我已经忘记了我以前知道的所有 script fu 我的输入图像有多个图层 因此我需要相当于
  • 将数据传递给 subprocess.check_output

    我想调用一个脚本 将字符串的内容通过管道传输到其标准输入并检索其标准输出 我不想接触真正的文件系统 所以我无法为其创建真正的临时文件 using subprocess check output剧本写什么我都能得到 我怎样才能将输入字符串放入
  • 尽可能快地获取大型文本文件中包含字符串的所有行?

    在Powershell中 如何尽可能快地读取和获取巨大文本文件 大约200000行 30 MB 中包含特定字符串的最后一行 或所有行 我在用着 get content myfile txt select string pattern my
  • GWT Requestfactory 性能建议

    我发现使用 GWT requestfactory 时性能非常糟糕 例如 一个请求需要我的服务层 2 秒才能完成 而 GWT 则需要 20 秒才能序列化 我的服务返回约 100 个实体代理 这些对象中的每一个都有 4 个 ValueProxi
  • 如何生成一组随机颜色,其中没有两种颜色几乎相似?

    我目前使用以下函数来生成颜色的随机十六进制表示 function getRandomColor max r 192 max g 192 max b 192 if max r gt 192 max r 192 if max g gt 192
  • 为什么 VS2010 调试器会挂起?

    这种情况刚刚开始发生在我的工作箱和家里 在 Visual Studio 2010 中 我将启动调试会话 程序将运行到第一个断点 仅此而已 我可以随心所欲地按 F10 11 5 什么都不会发生 退出的唯一方法是 Shift F5 这让我发疯
  • 无法设置访客内存“android_arm”:参数无效

    我花了几天时间尝试启动任何 Android 程序 即使 Hello World 也给我同样的错误 2014 10 28 18 07 14 android19 Android Launch 2014 10 28 18 07 14 androi
  • XDocument 中innerXml 和outerXml 的对应项是什么?

    我正在尝试将一些使用 XmlDocument 类的代码重构为 Linq To Xml 但是 我不确定XDocument 中innerXml 和outerXml 的对应项是什么 根据 MSDN InnerXml http msdn micro
  • 根据选择值禁用 Angular Reactive 表单输入

    我有一个表单 使用 Angular Material 我想根据选择值禁用某些输入字段 我的代码如下所示 HTML
  • 如何查看Excel文件的XML形式?

    如何查看 Excel 的 XML 形式 xlsx file XLSX 文件只是 ZIP 文件 因此请使用您最喜欢的 ZIP 工具解压缩它们
  • 如何通过java代码访问和创建azure存储帐户的生命周期规则/生命周期管理策略

    我想创建一个生命周期规则 or 生命周期管理政策对于特定的 azure 存储帐户 通过java代码 不通过 terraform 或 azure 门户 任何适当的代码片段或参考都会有所帮助 提前致谢 如果您想管理 Azure Blob 存储生
  • xcode Storyboard - ibtoold 解档异常

    例外的是 CompileStoryboard Catwall en lproj MainStoryboard storyboard cd Users guvenozyurt Desktop git catwall ios setenv IB
  • 智能卡读卡器访问时出现未知错误 0x16

    我正在尝试更改 ACR1252U 上的蜂鸣器持续时间 API 链接 http www acs com hk download manual 6402 API ACR1252U 1 09 pdf http www acs com hk dow
  • 在 keras 中微调预训练模型

    我想在 keras 中使用预训练的 imagenet VGG16 模型 并在顶部添加我自己的小型卷积网络 我只对功能感兴趣 对预测不感兴趣 from keras preprocessing image import ImageDataGen