Tensorflow - 平均恢复模型的模型权重

2023-12-31

鉴于我在相同的数据上训练了多个不同的模型,并且我训练的所有神经网络都具有相同的架构,我想知道是否可以恢复这些模型,平均它们的权重并使用平均值初始化我的权重。

这是图表外观的示例。基本上我需要的是我要加载的重量的平均值。

import tensorflow as tf
import numpy as np

#init model1 weights
weights = {
    'w1': tf.Variable(),
    'w2': tf.Variable()
}
# init model1 biases
biases = {
    'b1': tf.Variable(),
    'b2': tf.Variable()
}
#init model2 weights
weights2 = {
    'w1': tf.Variable(),
    'w2': tf.Variable()
}
# init model2 biases
biases2 = {
    'b1': tf.Variable(),
    'b2': tf.Variable(),
}

# this the average I want to create
w = {
    'w1': tf.Variable(
        tf.add(weights["w1"], weights2["w1"])/2
    ),
    'w2': tf.Variable(
        tf.add(weights["w2"], weights2["w2"])/2
    ),
    'w3': tf.Variable(
        tf.add(weights["w3"], weights2["w3"])/2
    )
}
# init biases
b = {
    'b1': tf.Variable(
        tf.add(biases["b1"], biases2["b1"])/2
    ),
    'b2': tf.Variable(
        tf.add(biases["b2"], biases2["b2"])/2
    ),
    'b3': tf.Variable(
        tf.add(biases["b3"], biases2["b3"])/2
    )
}

weights_saver = tf.train.Saver({
    'w1' : weights['w1'],
    'w2' : weights['w2'],
    'b1' : biases['b1'],
    'b2' : biases['b2']
    })
weights_saver2 = tf.train.Saver({
    'w1' : weights2['w1'],
    'w2' : weights2['w2'],
    'b1' : biases2['b1'],
    'b2' : biases2['b2']
    })

这就是我运行 tf 会话时想要得到的。 c 包含我想要用于开始训练的权重。

# Create a session for running operations in the Graph.
init_op = tf.global_variables_initializer()
init_op2 = tf.local_variables_initializer()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # Initialize the variables (like the epoch counter).
    sess.run(init_op)
    sess.run(init_op2)
    weights_saver.restore(
        sess,
        'my_model1/model_weights.ckpt'
    )
    weights_saver2.restore(
        sess,
        'my_model2/model_weights.ckpt'
    )
    a = sess.run(weights)
    b = sess.run(weights2)
    c = sess.run(w)

首先,我假设模型结构完全相同(相同的层数、相同的节点/层数)。如果不是,您将在映射变量时遇到问题(一个模型中有变量,但另一个模型中没有。

你想做的是进行 3 次会议。前 2 个从检查点加载,最后一个将保存平均值。您需要这样做是因为每个会话都将包含变量值的一个版本。

加载模型后使用tf.trainable_variables()获取模型中所有变量的列表。您可以将其传递给sess.run将变量获取为 numpy 数组。计算平均值后,使用 tf.assign 创建操作来更改变量。您还可以使用列表来更改初始值设定项,但这意味着传递到模型(并不总是一个选项)。

Roughly:

graph = tf.Graph()
session1 = tf.Session()
session2 = tf.Session()
session3 = tf.Session()

# Omitted code: Restore session1 and session2.
# Optionally initialize session3.

all_vars = tf.trainable_variables()
values1 = session1.run(all_vars)
values2 = session2.run(all_vars)

all_assign = []
for var, val1, val2 in zip(all_vars, values1, values2):
  all_assign.append(tf.assign(var, tf.reduce_mean([val1,val2], axis=0)))

session3.run(all_assign)

# Do whatever you want with session 3.
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

Tensorflow - 平均恢复模型的模型权重 的相关文章

随机推荐

  • strcpy() 和字符串数组

    我需要将用户的输入存储到字符串数组中 include
  • Ninject 和 DataContext 处置

    我正在使用 Ninject 从内核检索 DataContext 我想知道 Ninject 是否自动处置 DataContext 或者他如何处理 dispose 行为 根据我自己的经验 我知道处理数据上下文非常重要 并且每当您创建 DataC
  • 使用 R 合并重复列

    我有一个包含 4 列的表 其中第 1 3 列和第 2 4 列代表相同的变量 Codes Description Codes Description xxxxx describes xxxxx zzzzz describes zzzzz yy
  • CreateFile()串行通信问题[重复]

    这个问题在这里已经有答案了 我试图通过 USB 端口 名为 COM15 进行一些串行通信 但出现错误 这是发生错误的代码 HANDLE myPortHandle CreateFile COM15 GENERIC READ GENERIC W
  • 如何获取请求头、远程地址和其他HttpServletRequest特定信息?

    我有一个 JSF 2 0 Web 项目 我的 Web 有一个表单 它必须执行以下操作 获取表单的参数并将其保存在Bean中 完成 从 servlet 获取此信息 远程地址 远程主机 区域设置 内容类型 边界 内容长度 字符编码 将Bean数
  • 当后缀缺失时,编译器选择前缀 ++ - 谁说的?

    当您为用户定义类型定义前缀运算符 并且不提供后缀版本时 编译器 至少在 Visual C 中 将在您的代码调用缺少的 POSTFIX 版本时使用 PREFIX 版本 至少它会给你一个警告 但是 我的问题是 为什么它不给你一个未定义成员函数的
  • 如何处理 groovy 方法中的多个返回类型?

    我需要一种方法 在成功时返回 Id 在失败时返回错误列表 前代码片段 def save def errors if Employee save flush true return Employee id else errors add Ca
  • 在 Silverlight 中显示 ® 符号

    Folks 我正在尝试在我的 silverlight 应用程序中显示 和上标 TM 符号 我想将包含符号的文本保存在 resx 文件中 我尝试过的事情 将任何文档中的 符号复制粘贴到 resx 文件中 符号得到 显示在 resx 文件中 但
  • 获取方括号的内容,避免嵌套括号

    第一次发帖 来自 Google 的长期访客 我正在尝试提取一些方括号的内容 但是我遇到了一些麻烦 我已经让它适用于圆括号 如下所示 但我看不出应该如何修改它以适用于方括号 我本以为在这个例子中用圆形替换方形 反之亦然应该可行 但显然不行 它
  • 使用单个 flatMap() 比使用 map().flatMap() 更好吗?

    我想知道两种平面映射情况之间是否存在显着差异 Case 1 someCollection stream map CollectionElement getAnotherCollection flatMap Collection stream
  • 用户安装软件时自动安装依赖项(.Net)

    我正在使用 Net 3 5 c WPF 构建一个软件 我的软件需要用户安装 Net 3 5 和 Media Player 11 我想构建一个安装程序 在用户安装主软件时自动安装这两个组件 我该如何解决这个问题 该组件 1 Net 3 5 2
  • 如何在 C# 中生成 WSDL 而不发出 http 请求

    问候 我想编写一个单元测试来确保我们的 Web 服务没有更改上次已知发布版本的 WSDL 原因是对 WSDL 中对象的任何更改都会导致使用 Apache Axis 的客户端失败 即使您所做的只是添加一个不需要的属性 因此 如果发生更改 则需
  • 组合两个 def 后扁平化类型

    以下是一个玩具示例 用于演示现实生活中遗留方法的形状怪异和问题的要点 如你看到的anotherFunc 映射结束后personList将类型扩展为 Throwable List Throwable String 这不是预期的返回类型 而是效
  • 什么是 deep_ping [关闭]

    Closed 这个问题是无关 help closed questions 目前不接受答案 我不确定这是否是提问的正确论坛 但我也不知道在哪里提问 所以这是我的问题 深平 是什么意思 我尝试了谷歌 但仍然没有得到任何有关它的信息 另外 深度
  • DataTemplate 中的 TextBlock 忽略了 FontSize 样式

    TextBlock 的样式 如下 对 DataTemplate 的 TextBlock 没有影响 如果我在样式和模板中将 TextBlock 更改为 TextBox 则样式将按我的预期应用 为什么 TextBlock 会忽略样式 谢谢你 B
  • Android 撰写文本的自动链接

    有什么办法可以使用吗安卓 自动链接JetPack Compose Text 上的功能 我知道 在一个简单的标签 修饰符中使用此功能可能不是 声明性方式 但也许有一些简单的方法 对于文本样式我可以使用这种方式 val apiString An
  • 获取 R 中均值子组的均值

    我是 R 的新手 我不知道如何让 R 计算子组的平均值 而子组本身就是子组的平均值 我会解释得更清楚 我有一个像这样的数据框 GROUP WORD WLN 1 1 4 1 1 3 1 1 3 1 2 2 1 2 2 1 2 3 2 3 1
  • Python在同一个图上并排箱线图

    我正在尝试在 Python 2 7 中为下面 Pandas 数据框中 E 列中的每个分类值生成一个箱线图 A B C D E 0 0 647366 0 317832 0 875353 0 993592 1 1 0 504790 0 0418
  • Python - 反转列表中字符串的函数

    疯狂地学习Python 并且有很多很多的问题 这次关于函数 我需要创建两个函数 第一个函数用于数字来总结用户在列表中输入的所有内容 第二个函数是用户在列表中输入一些单词 并且函数不触及列表中的单词索引 取每个函数单词并返回相反的单词 在同一
  • Tensorflow - 平均恢复模型的模型权重

    鉴于我在相同的数据上训练了多个不同的模型 并且我训练的所有神经网络都具有相同的架构 我想知道是否可以恢复这些模型 平均它们的权重并使用平均值初始化我的权重 这是图表外观的示例 基本上我需要的是我要加载的重量的平均值 import tenso