finetune

2023-11-16

finetune的含义是获取预训练好的网络的部分结构和权重,与自己新增的网络部分一起训练。下面介绍几种finetune的方法。

完整代码:https://github.com/toyow/learn_tensorflow/tree/master/finetune

一,如何恢复预训练的网络

方法一:

思路:恢复原图所有的网络结构(op)以及权重,获取中间层的tensor,自己只需要编写新的网络结构,然后把中间层的tensor作为新网络结构的输入。

存在的问题:

1.这种方法是把原图所有结构载入到新图中,也就是说不需要的那部分也被载入了,浪费资源。

2.在执行优化器操作的时候,如果不锁定共有的结构(layer2=tf.stop_gradient(layer2,name='layer2_stop')),会导致重名提示报错,因为原结构已经有一个优化器操作了,你再优化一下就重名了。

核心代码:
1.把原网络加载到新图里
def train():
#恢复原网络的op tensor
    with tf.Graph().as_default() as g:
        saver=tf.train.import_meta_graph('./my_ckpt_save_dir/wdy_model-15.meta')#把原网络载入到图g中
2.获取原图中间层tensor作为新网络的输入
        x_input=g.get_tensor_by_name('input/x:0')#恢复原op的tensor
        y_input = g.get_tensor_by_name('input/y:0')
        layer2=g.get_tensor_by_name('layer2/layer2:0')
        #layer2=tf.stop_gradient(layer2,name='layer2_stop')#layer2及其以前的op均不进行反向传播

        softmax_linear=inference(layer2)#继续前向传播
        cost=loss(y_input,softmax_linear)

        train_op=tf.train.AdamOptimizer(0.001,name='Adma2').minimize(cost)#重名,所以改名

 3.恢复所有权重
        saver.restore(sess,save_path=tf.train.latest_checkpoint('./my_ckpt_save_dir/'))
     

方法二:

思路:重新定义网络结构,保持共有部分与原来同名。在恢复权重时,只恢复共有部分。

1.自定义网络结构
def inference(x):
    with tf.variable_scope('layer1') as scope:
        weights=weights_variabel('weights',[784,256],0.04)
        bias=bias_variabel('bias',[256],tf.constant_initializer(0.0))
        layer1=tf.nn.relu(tf.add(tf.matmul(x,weights),bias),name=scope.name)
    with tf.variable_scope('layer2') as scope:
        weights=weights_variabel('weights',[256,128],0.02)
        bias=bias_variabel('bias',[128],tf.constant_initializer(0.0))
        layer2=tf.nn.relu(tf.add(tf.matmul(layer1,weights),bias),name=scope.name)
        # layer2=tf.stop_gradient(layer2,name='layer2_stop')#layer2及其以前的op均不进行反向传播
    with tf.variable_scope('layer3') as scope:
        weights=weights_variabel('weights',[128,64],0.001)
        bias=bias_variabel('bias',[64],tf.constant_initializer(0.0))
        layer3=tf.nn.relu(tf.add(tf.matmul(layer2,weights),bias),name=scope.name)
    with tf.variable_scope('softmax_linear_1') as scope:
        weights = weights_variabel('weights', [64, 10], 0.0001)
        bias = bias_variabel('bias', [10], tf.constant_initializer(0.0))
        softmax_linear = tf.add(tf.matmul(layer3, weights), bias,name=scope.name)
    return softmax_linear


2.恢复指定的权重
        variables_to_restore = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)[:4]#这里获取权重列表,只选择自己需要的部分
        saver = tf.train.Saver(variables_to_restore)

    with tf.Session(graph=g) as sess:
        #恢复权重
        saver.restore(sess,save_path=tf.train.latest_checkpoint('./my_ckpt_save_dir/'))#这个时候就是只恢复需要的权重了

二,如何获取锁层部分的变量名称,如何避免名称不匹配的问题。

   锁住了也可以显示所有变量。
   params_1=slim.get_model_variables()#放心大胆地获取纯净的参数变量,包括batchnorm
   
   params_2 = slim.get_variables_to_restore()  # 包含优化函数里面定义的动量等等变量,exclude       只能写全名。
   params_2 = [val for val in params_2 if 'Logits' not in val.name]#剔除含有这个字符的变量

   锁住了(trianable=False)就不显示。
   params_3 = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)  
   params_4 = tf.trainable_variables()
   # 不包含优化器参数

解决方法,利用slim.get_variables_to_restore(),紧跟在原网络结构后面。之后再写自己定义的操作。

 

三,如何给不同层设置不同的学习率

思路:minizie()函数实际由compute_gradients()和apply_gradients()两个步骤完成。

compute_gradients()返回的是(gradent,varibel)元组对的列表,把这个列表varibel对应的gradent乘以学习率,再把新列表传入apply_gradients()就搞定了。

核心代码:

softmax_linear=inference(x_input)#继续前向传播
cost=loss(y_input,softmax_linear)
train_op=tf.train.AdamOptimizer()
grads=train_op.compute_gradients(cost)#返回的是(gradent,varibel)元组对的列表
variables_low_LR = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)[:4]#获取低学习率的变量列表
low_rate=0.0001
high_rate=0.001
new_grads_varible=[]#新的列表
for grad in grads:#对属于低学习率的变量的梯度,乘以一个低学习率
    if grad[1] in variables_low_LR:
        new_grads_varible.append((low_rate*grad[0],grad[1]))
    else:
        new_grads_varible.append((high_rate * grad[0], grad[1]))
apply_gradient_op = train_op.apply_gradients(new_grads_varible)
sess.run(apply_gradient_op,feed_dict={x_input:x_train_batch,y_input:y_train_batch})

三,关于PB文件

一,保存:

ckpt类型文件,是把结构(mate)与权重(checkpoint)分开保存,恢复的时候也是可以单独恢复。而PB文件是把结构与权重保存进了一个文件里。其中权重被固化成了常量,无法进行再次训练了。

可以看到,我指定保存最后一个tensor。只保存了之前的结构和权重,甚至y都没保存。

核心代码:

graph = convert_variables_to_constants(sess,sess.graph_def,['softmax_linear/softmax_linear'])
tf.train.write_graph(graph,'.','graph.pb',as_text=False)

二,恢复

恢复的思路跟ckpt恢复网络结构类似,不过因为只保存了我指定tensor之前的结构,所以自然也只能恢复保存了的网络结构。

with tf.Graph().as_default() as g:
    x_place = tf.placeholder(tf.float32, shape=[None, 784], name='x')
    y_place = tf.placeholder(tf.float32, shape=[None, 10], name='y')
    with open('./graph.pb','rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())
        tf.import_graph_def(graph_def, name='')
#恢复tensor
        graph_op = tf.import_graph_def(graph_def,name='',input_map={'input/x:0':x_place},
                                       return_elements=['layer2/layer2:0','layer1/weights:0'])

或者可以用

# x_place = g.get_tensor_by_name('input/x:0')
 #y_place = g.get_tensor_by_name('input/y:0')
 #layer2 = g.get_tensor_by_name('layer2/layer2:0')

 

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

finetune 的相关文章

  • 矩阵论——正交向量

    向量正交 向量 u u u与向量 v v v正交 u
  • WPF中判断滚动条滚动条滑倒了最底端

    先是获取滚动条的方法 可以获取到空间内部自带的ScrollViewer region 获取所有控件子级元素的方法 返回该类型的List集合 public static List
  • 自己写不出代码我该怎么办

    在作业和练习中 自己写不出代码 这是一个在学习中经常出现的问题 那么该怎么解决这个问题呢 1 先分析实现的思路 拿到作业 按照要实现的功能 先分析去实现的思路 如果完全不知到该怎么去实现 完全是一头雾水 那最好就先看看其他人是如何实现的 或
  • simulink的模块封装与解封装

    MATLAB Simulink 使用技巧 模块封装 简单 1 新建或者打开Simulink仿真环境 2 选择需要封装的部分 单击鼠标右键选择 create subsystem 选项即可封装模块 MATLAB Simulink 使用技巧 模块
  • 在微信小程序中 使用uView rules 表单校验 validator 不起作用(无效)

    注意 如果需要兼容微信小程序 最好通过setRules方法设置rules规则 onReady 如果需要兼容微信小程序 并且校验规则中含有方法等 只能通过setRules方法设置规则 this refs form1 setRules this
  • Sublime4如何自定义代码补全内容

    1 先进入如下文件夹 2 这里举C 为例子 其他语言类似 创建C 文件夹 并在C 文件夹内创建Snippets文件夹 3 在Snippets文件夹下创建以 sublime snippet为后缀的文件 4 在文件中自定义代码补全的信息

随机推荐

  • vue 在style标签中使用变量

    1 定义变量 export default data return 背景y颜色 backgroundColor 00f 字体颜色 fontColor f00 2 在HTML中设置CSS使用的变量
  • [深入研究4G/5G/6G专题-22]: 5G NR开机流程3.4 - MAC层对SIB1的调度 - SIB1消息的格式与内容

    作者主页 文火冰糖的硅基工坊 文火冰糖 王文兵 的博客 文火冰糖的硅基工坊 CSDN博客 本文网址 目录 前言 前置条件 第1章 SIB1消息的格式 内容解析
  • Android sqlite常见sql语句

    创建一个测试表man select from man 查询man表所有信息 select from man where name like 四 删除操作 delete from man where name lucy2 部分字段查询 sel
  • java.sql.SQLException: The user specified as a definer (‘combined‘@‘%‘) does not exist

    java sql SQLException The user specified as a definer combined does not exist 今天我把公司的项目拷回来处理一些遗留的问题 文明的我 爆了粗口TMD 罪过罪过 话不
  • React中高阶组件、Render props、hooks

    这三者都是react中解决代码复用的主要方式 1 HOC 在官方解释中 高阶组件 HOC 是 React 中复用组件逻辑的一种高级技巧 HOC自身不是 React API 的一部分 它是一种基于 React 的组合特性而形成的一种设计模式
  • springCloud整合 Hystrix熔断器(配置)

    springCloud整合 Hystrix熔断器 文章目录 springCloud整合 Hystrix熔断器 前言 一 添加Hystrix依赖 二 properties文件开启熔断器 三 为调用另一个服务的接口添加实现类 前言 在分布式环境
  • Qt Installer Framework使用教程:

    步骤一 下载并安装Qt Installer Framework工具 http download qt io official releases qt installer framework 将安装目录添加到环境变量 如安装D盘时D Qt Q
  • 狂神说 MyBatis 笔记

    这里写目录标题 Mybatis 1 简介 1 1 什么是MyBaits 1 2 持久话 1 3 持久层 1 4 为什么需要Mybatis 2 第一个Mybatis程序 2 1 搭建环境 2 2 创建一个模块 2 3 编写代码 2 4 测试
  • 二分插入排序(c语言)

    一 什么是二分插入排序 二分法插入排序 简称二分排序 是在插入第i个元素时 对前面的0 i 1元素进行折半 先跟他们中间的那个元素比 如果小 则对前半再进行折半 否则对后半进行折半 直到left
  • <02-01-01> Spring IoC容器与Bean介绍(Introduction to the Spring IoC Container and Beans)

    上一篇 02 01 控制反转容器 The IoC Container 本章介绍了Spring 框架对控制反转 Inversion of Control IoC 设计原则的实现 IoC也被称为依赖注入 Dependency Injection
  • SpringBoot配置多个mysql数据源

    当我们在进行数据库分库分表操作是可能会需要到多个数据库 那么我们就需要对多个数据库的数据源进行配置 整理一下 今天在SpringBoot框架下多个数据源的配置过程 两个为例 1 配置数据库信息 在yml配置文件中配置需要的数据库信息 spr
  • 分布式理论基础:CAP和BASE

    CAP定理 分区 在分布式系统中 不同的节点分布在不同的子网络中 由于一些特殊的原因 这些子节点之间出现了网络不通的状态 但他们的内部子网络是正常的 从而导致了整个系统的环境被切分成了若干个孤立的区域 这就是分区 CAP定理 CAP原则又称
  • opencv-python中 boundingRect(cnt)以及cv2.rectangle用法

    转自 http blog csdn net zhangxb35 article details 47275277 矩形边框 Bounding Rectangle 是说 用一个最小的矩形 把找到的形状包起来 还有一个带旋转的矩形 面积会更小
  • 实用常识

    WolframAlpha是开发计算数学应用软件的沃尔夫勒姆 Wolfram 研究公司基于科学计算软件Mathematica开发出的新一代的搜索引擎 试图挑战Google搜索引擎的地位 能根据问题直接给出标准化答案的网站 比如输入一种材料名称
  • iOS开发 多线程的高级应用-信号量semaphore

    在iOS开发的道路上 多线程的重要性不言而喻 大部分我们都停留在基础的使用上面 缺乏高级应用 缺乏提升 是因为我们面对他太少 复杂的事情重复做 复杂的事务基础化 差距就是这样拉开了 言归正传 今天讲讲GCD的高级应用之信号量篇 一 信号量的
  • CMake使用小结

    CMake使用小结 指定本地库的位置 set Qt5 DIR path list APPEND CMAKE PREFIX PATH Qt5 DIR 设置编译输出的路径 set CMAKE ARCHIVE OUTPUT DIRECTORY D
  • 一阶低通滤波

    一阶低通滤波 前言 在使用单片机开发中 常常会用到的外设包括ADC采样 而采样必然会伴随这随机干扰引起的毛刺噪声 对于需要捕捉采样值突变的系统来说尤其需要减小毛刺突变的影响 从硬件电路和软件算法上都能一定程度的减少噪声达到滤波的目的 本文主
  • VSCode配置文件“.vscode/c_cpp_properties.json”不断被覆盖的原因及解决方法

    一 问题现象 昨天 我在用VSCode写一个小算法程序 使用CMake配置文件 CMakeLists txt 进行工程管理 算法测试倒还顺利 但VSCode出现了一个令人恼火的问题 每次重新打开VSCode后 配置文件 vscode c c
  • android 反编译

    使用工具 CSDN上下载地址 apktool 资源文件获取 下载 dex2jar 源码文件获取 下载 jd gui 源码查看 下载 Android反编译整合工具包 最新 下载 官方最新版本下载地址 apktool google code d
  • finetune

    finetune的含义是获取预训练好的网络的部分结构和权重 与自己新增的网络部分一起训练 下面介绍几种finetune的方法 完整代码 https github com toyow learn tensorflow tree master