使用数据增强层在 Tensorflow 2.7.0 上保存模型

2024-01-20

尝试使用 Tensorflow 版本 2.7.0 保存具有数据增强层的模型时出现错误。

这是数据增强的代码:

input_shape_rgb = (img_height, img_width, 3)
data_augmentation_rgb = tf.keras.Sequential(
  [ 
    layers.RandomFlip("horizontal"),
    layers.RandomFlip("vertical"),
    layers.RandomRotation(0.5),
    layers.RandomZoom(0.5),
    layers.RandomContrast(0.5),
    RandomColorDistortion(name='random_contrast_brightness/none'),
  ]
)

现在我像这样构建我的模型:

# Build the model
input_shape = (img_height, img_width, 3)

model = Sequential([
  layers.Input(input_shape),
  data_augmentation_rgb,
  layers.Rescaling((1./255)),

  layers.Conv2D(16, kernel_size, padding=padding, activation='relu', strides=1, 
     data_format='channels_last'),
  layers.MaxPooling2D(),
  layers.BatchNormalization(),

  layers.Conv2D(32, kernel_size, padding=padding, activation='relu'), # best 4
  layers.MaxPooling2D(),
  layers.BatchNormalization(),

  layers.Conv2D(64, kernel_size, padding=padding, activation='relu'), # best 3
  layers.MaxPooling2D(),
  layers.BatchNormalization(),

  layers.Conv2D(128, kernel_size, padding=padding, activation='relu'), # best 3
  layers.MaxPooling2D(),
  layers.BatchNormalization(),

  layers.Flatten(),
  layers.Dense(128, activation='relu'), # best 1
  layers.Dropout(0.1),
  layers.Dense(128, activation='relu'), # best 1
  layers.Dropout(0.1),
  layers.Dense(64, activation='relu'), # best 1
  layers.Dropout(0.1),
  layers.Dense(num_classes, activation = 'softmax')
 ])

 model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=metrics)
 model.summary()

然后训练完成后我就做:

model.save("./")

我收到此错误:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-84-87d3f09f8bee> in <module>()
----> 1 model.save("./")


/usr/local/lib/python3.7/dist-packages/keras/utils/traceback_utils.py in 
 error_handler(*args, **kwargs)
 65     except Exception as e:  # pylint: disable=broad-except
 66       filtered_tb = _process_traceback_frames(e.__traceback__)
 ---> 67       raise e.with_traceback(filtered_tb) from None
 68     finally:
 69       del filtered_tb

 /usr/local/lib/python3.7/dist- 
 packages/tensorflow/python/saved_model/function_serialization.py in 
 serialize_concrete_function(concrete_function, node_ids, coder)
 66   except KeyError:
 67     raise KeyError(
 ---> 68         f"Failed to add concrete function '{concrete_function.name}' to 
 object-"
 69         f"based SavedModel as it captures tensor {capture!r} which is 
 unsupported"
 70         " or not reachable from root. "

 KeyError: "Failed to add concrete function 
 'b'__inference_sequential_46_layer_call_fn_662953'' to object-based SavedModel as it 
 captures tensor <tf.Tensor: shape=(), dtype=resource, value=<Resource Tensor>> which 
 is unsupported or not reachable from root. One reason could be that a stateful 
 object or a variable that the function depends on is not assigned to an attribute of 
 the serialized trackable object (see SaveTest.test_captures_unreachable_variable)."

我通过更改模型的架构检查了出现此错误的原因,我发现原因来自 data_augmentation 层,因为RandomFlip and RandomRotation和其他人改变自layers.experimental.prepocessing.RandomFlip to layers.RandomFlip,但仍然出现错误。


这似乎是 Tensorflow 2.7 使用时的一个错误model.save与参数结合save_format="tf",这是默认设置的。各层RandomFlip, RandomRotation, RandomZoom, and RandomContrast导致问题的原因是它们不可序列化。有趣的是,Rescaling可以毫无问题地保存图层。解决方法是简单地使用较旧的 Keras H5 格式保存模型model.save("test", save_format='h5'):

import tensorflow as tf
import numpy as np

class RandomColorDistortion(tf.keras.layers.Layer):
    def __init__(self, contrast_range=[0.5, 1.5], 
                 brightness_delta=[-0.2, 0.2], **kwargs):
        super(RandomColorDistortion, self).__init__(**kwargs)
        self.contrast_range = contrast_range
        self.brightness_delta = brightness_delta
    
    def call(self, images, training=None):
        if not training:
            return images
        contrast = np.random.uniform(
            self.contrast_range[0], self.contrast_range[1])
        brightness = np.random.uniform(
            self.brightness_delta[0], self.brightness_delta[1])
        
        images = tf.image.adjust_contrast(images, contrast)
        images = tf.image.adjust_brightness(images, brightness)
        images = tf.clip_by_value(images, 0, 1)
        return images
    
    def get_config(self):
        config = super(RandomColorDistortion, self).get_config()
        config.update({"contrast_range": self.contrast_range, "brightness_delta": self.brightness_delta})
        return config
        
input_shape_rgb = (256, 256, 3)
data_augmentation_rgb = tf.keras.Sequential(
  [ 
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomFlip("vertical"),
    tf.keras.layers.RandomRotation(0.5),
    tf.keras.layers.RandomZoom(0.5),
    tf.keras.layers.RandomContrast(0.5),
    RandomColorDistortion(name='random_contrast_brightness/none'),
  ]
)
input_shape = (256, 256, 3)
padding = 'same'
kernel_size = 3
model = tf.keras.Sequential([
  tf.keras.layers.Input(input_shape),
  data_augmentation_rgb,
  tf.keras.layers.Rescaling((1./255)),
  tf.keras.layers.Conv2D(16, kernel_size, padding=padding, activation='relu', strides=1, 
     data_format='channels_last'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.BatchNormalization(),

  tf.keras.layers.Conv2D(32, kernel_size, padding=padding, activation='relu'), # best 4
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.BatchNormalization(),

  tf.keras.layers.Conv2D(64, kernel_size, padding=padding, activation='relu'), # best 3
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.BatchNormalization(),

  tf.keras.layers.Conv2D(128, kernel_size, padding=padding, activation='relu'), # best 3
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.BatchNormalization(),

  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'), # best 1
  tf.keras.layers.Dropout(0.1),
  tf.keras.layers.Dense(128, activation='relu'), # best 1
  tf.keras.layers.Dropout(0.1),
  tf.keras.layers.Dense(64, activation='relu'), # best 1
  tf.keras.layers.Dropout(0.1),
  tf.keras.layers.Dense(5, activation = 'softmax')
 ])

model.compile(loss='categorical_crossentropy', optimizer='adam')
model.summary()
model.save("test", save_format='h5')

使用自定义图层加载模型将如下所示:

model = tf.keras.models.load_model('test.h5', custom_objects={'RandomColorDistortion': RandomColorDistortion})

where RandomColorDistortion是您的自定义图层的名称。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

使用数据增强层在 Tensorflow 2.7.0 上保存模型 的相关文章

随机推荐

  • PL/SQL中如何使用ifexists-ifnotexists?

    我正在尝试将 ifexists 语句从 SQL Server 转换为 PL SQL 但出现错误 我正在尝试检查是否NAME 1我的中不存在table 1 如果它们不存在 那么我正在检查是否COLUMN NAME NAME 2 存在于我的ta
  • 如何使用 REST API 在 keycloak 中重置用户密码

    我想对我的 Keycloak 服务器进行休息调用 根据文档 这应该很容易 https www keycloak org docs api 10 0 rest api index html executeactionsemail https
  • 支持的视频尺寸 MediaRecorder API android

    我正在尝试使用 mediarecorder 和 mediaprojection api 记录屏幕内容 当我尝试在设备上将视频更改为高清时 录制失败 但在 640 x 480 分辨率下工作正常 所以我的问题是如何获得特定设备上支持的视频分辨率
  • Ionic 模拟 android ERR_CONNECTION_REFUSED localhost:8100

    我尝试在 Android 上模拟我的 Ionic 应用程序 一旦我的应用程序在模拟设备中启动 它就会中断并出现以下错误 应用程序错误 净 ERR CONNECTION REFUSED http 本地主机 8100 http localhos
  • 有没有更好的方法在 C# 中创建深克隆和浅克隆?

    我一直在为一个项目创建对象 在某些情况下我必须为此对象创建深层副本 我想出了使用 C 的内置函数 MemberwiseClone 困扰我的问题是 每当我创建一个新类时 我就必须编写一个像下面的代码这样的函数来进行浅拷贝 有人可以帮我改进这部
  • C# 链式ContinueWith不等待上一个任务完成

    我正在测试 C async await 的异步性 并发现了一个惊喜 ContinueWith 的后续代码不会等待上一个任务完成 public async Task
  • 如何在 Android 模拟器中模拟总网络丢失

    我正在尝试编写一个应用程序 需要知道何时没有可用的 IP 网络连接 我正在使用 android net conn CONNECTIVITY CHANGE 广播事件以及 ConnectivityManager 对状态变化做出反应以实现此目的
  • 如何根据文本过滤 VS Code 中的问题?

    我在 Windows 10 x64 上使用 VS Code 1 41 0 在我的代码 使用您可能从未听说过的研究语言 中 我在 问题 面板中收到很多特定类别的警告消息 我想忽略这些消息 消息的文本在不同实例中略有不同 但始终包含 重复 一词
  • 寻找一种非 LL(1) 的语言?

    我最近一直在研究很多非 LL 1 的语法 其中许多可以转换为 LL 1 的语法 然而 我从未见过这样的例子明确的语言这不是 LL 1 换句话说 一种语言的任何明确语法都不是 LL 1 我也不知道如果我不小心偶然发现了一种语言 我将如何证明我
  • Python 中 open 和 codecs.open 的区别

    在 Python 中打开文本文件有两种方法 f open filename And import codecs f codecs open filename encoding utf 8 When is codecs open优于open
  • 如何获取 Firestore 文档大小?

    From Firestore 文档 https firebase google com docs firestore quotas 我们得到Firestore 文档的最大大小 is 文档的最大尺寸1 MiB 1 048 576 字节 QUE
  • Polarion ALM 工具 [关闭]

    Closed 这个问题是无关 help closed questions 目前不接受答案 我们正在公司寻找完整的 ALM 解决方案 我们正在研究 Polarion ALM 和 RTC 有人听说过 Polarion 完整的 ALM 工具吗 如
  • 将 SVN 提交发送到 RSS 源

    所以我最喜欢的网络工具 Subtlety http subtlety errtheblog com 最近已停止使用 这意味着我无法再轻松访问我关注的各种 SVN 项目的提交日志 是否有任何其他工具可以轻松地为公共 SVN 存储库生成 RSS
  • 在 Flutter 中实现视频源的最佳方式是什么?

    我正在 flutter 中构建一个应用程序 其中包含类似 TikTok 中的视频源 您可以想象一个 ListView 您可以在其中滚动浏览一些视频 5 25 秒 这些视频存储在 Google Cloud Platform 中 目前 包含超过
  • 无符号整数差异的意外结果

    我很惊讶这个函数为 dif1 和 dif2 产生不同的值 void test unsigned int x 0 y 1 long long dif1 x y long long dif2 int x y printf dif lld lld
  • 如何关闭 Intellij IDEA 中的自动括号生成?

    当输入函数名称 或自动完成 时 IDEA 会自动在其后面添加括号并将光标放在它们之间 富 我非常不喜欢这个 并且更希望它让我自己输入括号 有什么办法可以做到这一点吗 Update 回复 插入配对支架 设置 所以 这个选项对我来说已经关闭了
  • Graphics CopyFromScreen 方法如何复制位图?

    private void startBot Click object sender EventArgs e Bitmap bmpScreenshot Screenshot this BackgroundImage bmpScreenshot
  • AVSpeechSynthesizer isSpeaking 在 Swift 中不起作用

    因此 更新到 Xcode 12 0 1 后 AVSpeechSynthesizer 现在可以在模拟器上运行 它已经有一段时间没有为我工作了 现在 无论合成器是否正在说话 isSpeaking 变量始终为 false 我想根据合成器是否在说话
  • 如果位置不是美国,则使用 Amazon Mechanical Turk?

    亚马逊土耳其机器人 https www mturk com mturk welcome是一个大规模微外包 API 您可以在其中以相对便宜的价格 例如每张图像 0 10 U 完成大量简单的小任务 例如 此图像中是否有商店 亚马逊似乎认为这项服
  • 使用数据增强层在 Tensorflow 2.7.0 上保存模型

    尝试使用 Tensorflow 版本 2 7 0 保存具有数据增强层的模型时出现错误 这是数据增强的代码 input shape rgb img height img width 3 data augmentation rgb tf ker