tensorflow模型保存、读取与可训练参数提取

2023-05-16

一、保存、读取说明

  我们创建好模型之后需要保存模型,以方便后续对模型的读取与调用,保存模型我们可能有下面三种需求:1、只保存模型权重参数;2、同时保存模型图结构与权重参数;3、在训练过程的检查点保存模型数据。下面分别对这三种需求进行实现。

 

二、仅保存模型参数

  仅保存模型参数可以用一下的API:

  Model.save_weights(file_path)  # 将文件保存到save_path
  Model.load_weights(file_path)  # 将文件读取到save_path

  注意:由于save_weights只是保存权重w、b的参数值,所以在加载时最好保证我们的模型结构和原来保存的模型结构是相同的,否则可能会报错。.

  模型在保存之后会有多个文件:

  • index类型文件,在分布式计算中,索引文件会指示哪些权重存储在哪个分片。
  • checkpoint类型文件,检查文件点包含: 一个或多个包含模型权重的分片
  • 如果您只在一台机器上训练模型,那么您将有一个带有后缀的分片:.data-00000-of-00001

复制代码

import tensorflow as tf
import os

# 读取数据集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# 数据集归一化
train_images = train_images / 255
train_labels = train_labels / 255  # 进行数据的归一化,加快计算的进程

# 创建模型结构
net_input=tf.keras.Input(shape=(28,28))
fl=tf.keras.layers.Flatten()(net_input)#调用input
l1=tf.keras.layers.Dense(32,activation="relu")(fl)
l2=tf.keras.layers.Dropout(0.5)(l1)
net_output=tf.keras.layers.Dense(10,activation="softmax")(l2)

# 创建模型类
model = tf.keras.Model(inputs=net_input, outputs=net_output)

# 查看模型的结构
model.summary()

# 模型编译
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss="sparse_categorical_crossentropy",
              metrics=['acc'])

# 模型训练
model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1)

# 模型存放路径
save_path = './save_weights/'
model.save_weights(save_path)

# 模型加载
model.load_weights(save_path)

# # 定义一个与原模型结构不同的模型
# net_in=tf.keras.Input(shape=(748,))
# net_out=tf.keras.layers.Dense(10,activation="softmax")(net_in)
# 
# # 用不同结构的模型读取参数,这里会报错
# model2=tf.keras.Model(inputs=net_in,outputs=net_out)
# model2.load_weights(save_path)

复制代码

 

三、同时保存结构与参数

  Keras使用HDF5标准提供基本保存格式,出于我们的目的,可以将保存的模型视为单个二进制blob。

  保存完整的模型非常有用,使我们可以在TensorFlow.js(HDF5, Saved Model) 中加载它们,然后在Web浏览器中训练和运行它们,或者使用TensorFlow Lite(HDF5, Saved Model)将它们转换为在移动设备上运行。

复制代码

# 模型训练
model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1)

# 保存模型
model.save('net_model.h5')

# 模型加载
new_model=tf.keras.models.load_model('net_model.h5')

复制代码

 

四、在训练过程的检查点保存模型数据

  在训练过程的检查点保存模型数据有两个作用:1、我们可以保存训练各个节点的数据,便于我们把训练效果最好的节点的模型挑选出来。2、可以随时先暂停训练模型,当想要训练时继续训练。

  在训练的检查点保存模型需要用到tf.keras.callbacks.ModelCheckpoint()类,这个是一个回调类,可以以列表形式传入到fit()方法的callbacks参数中。

  回调中类,文件名以.ckpt作为后缀,如文件路径'./checkpoint/train.ckpt',会在checkpoint生成三个文件,后缀与Model.save_weights()方法创建的文件后缀相同,意义也相同。以下为回调类的参数:

 

  tf.keras.callbacks.ModelCheckpoint()

  • filepath:string,保存模型文件的路径。
  • monitor:监控:要监控的数量。
  • verbose详细:详细模式,0或1。
  • save_best_only:如果save_best_only = True,则不会覆盖根据监控数量的最新最佳模型。
  • save_weights_only:如果为True,则只有模型的权重
    保存(model.save_weights(filepath)),否则保存完整模型(model.save(filepath))。
  • mode:{auto,min,max}之一。 如果save_best_only =
    True,则根据监控数量的最大化或最小化来决定覆盖当前保存文件。
    对于val_acc,这应该是max,对于val_loss,这应该是min等。在自动模式下,从监控量的名称自动推断方向。
  • period:检查点之间的间隔(时期数)。

  

复制代码

# 模型编译
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss="sparse_categorical_crossentropy",
              metrics=['acc'])

# 创建一个保存模型的回调函数,每5个周期保存一次权重
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='./checkpoint/train.ckpt',
    verbose=1,
    save_weights_only=True,
    period=5
)

# 模型训练
model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1, callbacks=[cp_callback])

# 加载模型
model.load_weights('./checkpoint/train.ckpt')

# # 继续训练模型
# model.fit()

复制代码

 

五、模型可训练参数的提取

  有时候我们需要查看模型的参数,但是模型参数的显示有时候由于数据过多不能再控制台全部显示,所以需要存放到文件来查看。以下是提取可训练参数的方法:

复制代码

# 参看可训练参数
import numpy as np
model.trainable_variables

# 设置全部可训练参数可打印,不然数据过多,有一部分会以省略号的形式显示
np.set_printoptions(threshold=np.inf)

# 可训练参数保存到文件
with open('trainable.txt', mode='w',encoding = "utf-8") as f:
    for t_v in model.trainable_variables:
        f.writelines(str(t_v.name) + '\n')  # 保存参数名字
        f.writelines(str(t_v.shape) + '\n')  # 保存参数形状
        f.writelines(str(t_v.numpy()) + '\n')  # 保存参数数值
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

tensorflow模型保存、读取与可训练参数提取 的相关文章

随机推荐

  • Re3 : Real-Time Recurrent Regression Networks for Visual Tracking of Generic Objects

    Re3 Real Time Recurrent Regression Networks for Visual Tracking of Generic Objects 2019 10 04 14 42 54 Paper xff1a https
  • 最新Java面试题大全1000+面试题附答案详解,最全面详细,看完稳了

    进大厂是大部分程序员的梦想 xff0c 而进大厂的门槛也是比较高的 xff0c 所以这里整理了一份阿里 美团 滴滴 头条等大厂面试大全 xff0c 其中概括的知识点有 xff1a Java MyBatis ZooKeeper Dubbo E
  • 程序员必须掌握的十种算法---二分查找算法

    二分查找算法核心代码简单 xff0c 但需要数组是经过排序的 span class hljs variable arr span 要查找的数组 span class hljs variable length span 数组的长度 span
  • 笔试题:在一个字符串中查找子字符串的个数

    题目 xff1a 在一个字符串中查找子字符串的个数 要求 xff1a 两个字符串之间以空格隔开 xff0c 前一个为字符串 xff0c 后一个为要查找的子字符串 结果输出字符串中包含的子字符串的个数 例如 xff1a 输入 xff1a ab
  • 笔试题:输入一串数字,以逗号隔开,将数字排序后输出

    题目 xff1a 输入一串数字 xff0c 以逗号隔开 xff0c 将数字排序后输出 具体要求 xff1a 如果有几个数字是连续的 xff0c 只输出最大的和最小的数字 例如 xff1a 输入 xff1a 1 xff0c 4 xff0c 6
  • C语言-函数指针

    1 函数指针 函数原型 xff1a span class hljs keyword int span span class hljs keyword sum span span class hljs keyword int span a s
  • C语言应用小技巧

    1 求字符串长度 span class hljs preprocessor include lt stdlib h gt span size t span class hljs built in strlen span span class
  • C和指针-编程练习

    第六章 1 查找source字符串中匹配chars字符串中任何字符的第一个字符 xff0c 函数返回一个指向source中第1个匹配所找到的位置的指针 如果source中的所有字符均不匹配 xff0c 返回NULL指针 如果任何一个参数为N
  • Bootloader和BIOS、uboot和grub和bootmgr的区别

    版权声明 xff1a 本文章参考了 Bootloader和BIOS Grub uboot概念 未经作者允许 xff0c 严禁用于商业出版 xff0c 否则追究法律责任 网络转载请注明出处 xff0c 这是对原创者的起码的尊重 xff01 x
  • 医学影像常用名词:

    医学影像处理 xff1a MPR xff1a Multiplanarreconstruction allows images to be created from the original axial plane ineither the
  • MsOS——概述

    自己接触的操作系统也有不少了 xff0c 如RT Thread Cos等 xff0c 这些实时操作系统基本的思想就是围绕任务的调度 更像是一个软件平台 xff0c 提供使用者丰富的软件资源 RT Thread是比较成功的应用于消费类产品的国
  • Pixhawk室内自动控制:参数设置

    Pixhawk室内自动控制 xff1a 参数设置 本文针对使用光流传感器和超声波传感器 xff08 或激光雷达 xff09 的Pixhawk用户 ArduCopter目前 xff08 3 52 xff09 已经能够使用光流传感器提供的位置信
  • python函数--capitalize()方法

    capitalize 方法 描述 Python capitalize 将字符串的第一个字母变成大写 其他字母变小写 语法 capitalize 方法语法 xff1a str capitalize 参数 无 返回值 该方法返回一个首字母大写的
  • c# 接口

    1 接口的特点 接口的定义是指定一组函数成员而不实现成员的引用类型 xff0c 其它类型和接口可以继承接口 定义还是很好理解的 xff0c 但是没有反映特点 xff0c 接口主要有以下特点 xff1a span class token pu
  • 在linux上增加swap交换空间

    在虚拟机里面安装oracle11g grid时 发现之前分配的swap交换空间不满足oracke gi安装的最低要求 xff0c 因为我分配的物理内存是8G xff0c 那么就按照要求需要8 12G的swap交换空间 xff0c 而我分配的
  • Elasticsearch7.6.1安装报错及解决过程

    Windows环境Elasticsearch7 6 1安装报错及解决过程 Elasticsearch是一个基于Lucene的搜索服务器 第一次安装ES7 6 1 xff0c 过程中遇到了一些报错 xff0c 把解决方法列出来 xff0c 总
  • 我的 Ubuntu 装机必备软件

    文章目录 我的 Ubuntu 装机必备软件Ubuntu的安装u盘制作添加中科大镜像源NVIDIA显卡驱动的安装卸载旧显卡驱动 安装sogou输入法下载安装配置 安装gitROS kinetic installationgoogle chro
  • GCC源码分析(十三) — 机器描述文件

    版权声明 xff1a 本文为CSDN博主 ashimida 64 的原创文章 xff0c 遵循CC 4 0 BY SA版权协议 xff0c 转载请附上原文出处链接及本声明 原文链接 xff1a https blog csdn net lid
  • VNC登录报错too many security failures解决方法

    桌面进程编号为1 xff0c 可以通过使用 sudo vncserver kill 1 sudo vncserver 1 杀掉并重启解决
  • tensorflow模型保存、读取与可训练参数提取

    一 保存 读取说明 我们创建好模型之后需要保存模型 xff0c 以方便后续对模型的读取与调用 xff0c 保存模型我们可能有下面三种需求 xff1a 1 只保存模型权重参数 xff1b 2 同时保存模型图结构与权重参数 xff1b 3 在训