Keras:如何保存模型或权重?

2024-05-21

如果这个问题看起来很简单,我很抱歉。但是阅读 Keras 保存和恢复帮助页面:

https://www.tensorflow.org/beta/tutorials/keras/save_and_restore_models https://www.tensorflow.org/beta/tutorials/keras/save_and_restore_models

我不明白如何使用“ModelCheckpoint”在训练期间保存。帮助文件提到它应该提供 3 个文件,我只看到一个,MODEL.ckpt。

这是我的代码:

checkpoint_dir = FolderName + "/tmp/model.ckpt"
cp_callback = k.callbacks.ModelCheckpoint(checkpoint_dir,verbose=1,save_weights_only=True)    
parallel_model.compile(optimizer=tf.keras.optimizers.Adam(lr=learning_rate),loss=my_cost_MSE, metrics=['accuracy])
    parallel _model.fit(image, annotation, epochs=epoch, 
    batch_size=batch_size, steps_per_epoch=10,
                                 validation_data=(image_val,annotation_val),validation_steps=num_batch_val,callbacks=callbacks_list)

另外,当我想在训练后加载重量时:

model = k.models.load_model(file_checkpoint)

我收到错误:

"raise ValueError('Unknown ' + printable_module_name + ':' + object_name) 
ValueError: Unknown loss function:my_cost_MSE"

my-cost_MSE 是我在训练中使用的成本函数。


首先,看起来您正在使用tf.keras(来自张量流)实现而不是keras(来自 keras-team/keras 存储库)。在这种情况下,如tf.keras 指南 https://www.tensorflow.org/guide/keras#import_tfkeras :

保存模型权重时,tf.keras 默认为检查点 格式。传递 save_format='h5' 以使用 HDF5。

另一方面,请注意添加回调ModelCheckpoint通常,大致相当于调用model.save(...)在每个纪元结束时,这就是为什么您应该期望保存三个文件(根据检查点格式 https://www.tensorflow.org/guide/checkpoints).

它不这样做的原因是因为,通过使用该选项save_weights_only=True,您只节省了权重。大致相当于替换调用model.save for model.save_weights在每个纪元结束时。因此,唯一被保存的文件是带有权重的文件。

从这里开始,您可以通过两种不同的方式进行操作:

仅存储权重

您需要预先加载模型(比如说结构),然后调用model.load_weights代替keras.models.load_model:

model = MyModel(...)  # Your model definition as used in training
model.load_weights(file_checkpoint)

请注意,在这种情况下,自定义定义不会出现问题(my_cost_MSE)因为您只是加载模型权重。

存储整个模型

另一种方法是存储整个模型并相应地加载它:

cp_callback = k.callbacks.ModelCheckpoint(
    checkpoint_dir,verbose=1,
    save_weights_only=False
)    
parallel_model.compile(
    optimizer=tf.keras.optimizers.Adam(lr=learning_rate),
    loss=my_cost_MSE,
    metrics=['accuracy']
)

model.fit(..., callbacks=[cp_callback])

然后你可以通过以下方式加载它:

model = k.models.load_model(file_checkpoint, custom_objects={"my_cost_MSE": my_cost_MSE})

请注意,在后一种情况下,您需要指定custom_objects因为需要它的定义来反序列化模型。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Keras:如何保存模型或权重? 的相关文章

随机推荐