张量流:在多个检查点运行模型评估

2024-01-12

在我当前的项目中,我训练一个模型并每 100 个迭代步骤保存检查点。检查点文件全部保存到同一目录(model.ckpt-100、model.ckpt-200、model.ckpt-300 等)。之后,我想根据所有已保存检查点(而不仅仅是最新检查点)的验证数据来评估模型。

目前我用于恢复检查点文件的代码如下所示:

ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
ckpt_list = saver.last_checkpoints
print(ckpt_list)
if ckpt and ckpt.model_checkpoint_path:
    print("Reading model parameters from %s" % ckpt.model_checkpoint_path)
    saver.restore(sess, ckpt.model_checkpoint_path)
    # extract global_step from it.
    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
    print('Succesfully loaded model from %s at step=%s.' %
            (ckpt.model_checkpoint_path, global_step))
else:
    print('No checkpoint file found')
    return

但是,这仅恢复最新保存的检查点文件。那么如何在所有保存的检查点文件上编写循环呢?我尝试使用 saver.last_checkpoints 获取检查点文件列表,但是返回的列表为空。

任何帮助将不胜感激,提前致谢!


最快的解决方案:

tensor2tensor有一个模块utils带脚本avg_checkpoints.py https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/utils/avg_checkpoints.py将平均权重保存在新的检查点中。假设您有一个要计算平均值的检查点列表。您有 2 个使用选项:

  1. 从命令行

    TRAIN_DIR=path_to_your_model_folder
    FNC_PATH=path_to_tensor2tensor+'/utils/avg.checkpoints.py'
    CKPTS=model.ckpt-10000,model.ckpt-20000,model.ckpt-100000
    
    python3 $FNC_PATH --prefix=$TRAIN_DIR --checkpoints=$CKPTS \ 
        --output_path="${TRAIN_DIR}averaged.ckpt"
    
  2. 从您自己的代码(使用os.system):

    import os
    os.system(
        "python3 "+FNC_DIR+" --prefix="+TRAIN_DIR+" --checkpoints="+CKPTS+
        " --output_path="+TRAIN_DIR+"averaged.ckpt"
    )
    

作为指定检查点列表并使用--checkpoints参数,你可以使用--num_checkpoints=10计算最后 10 个检查点的平均值。

如果您不想依赖tensor2tensor:

这是一个不依赖的代码片段tensor2tensor,但仍然可以平均检查点数量可变(与特德的回答相反)。认为steps是应该合并的检查点列表(例如[10000, 20000, 30000, 40000]).

Then:

# Restore all sessions and save the weight matrices
values = []
for step in steps:
    tf.reset_default_graph()
    path = model_path+'/model.ckpt-'+str(step)
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(path+'.meta')
        saver.restore(sess, path)
        values.append(sess.run(tf.all_variables()))

# Average weights
variables = tf.all_variables()
all_assign = []
for ind, var in enumerate(variables):
    weights = np.concatenate(
        [np.expand_dims(w[ind],axis=0)  for w in values],
        axis=0
    )
    all_assign.append(tf.assign(var, np.mean(weights, axis=0))

然后您可以按照您的喜好继续操作,例如保存平均检查点:

# Now save the new values into a separate checkpoint
with tf.Session() as sess_test:
    sess_test.run(all_assign)
    saver = tf.train.Saver() 
    saver.save(sess_test, model_path+'/average_'+str(num_checkpoints))
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

张量流:在多个检查点运行模型评估 的相关文章

随机推荐

  • mysql 连接一列中的多个值

    我需要进行一个查询 创建来自 2 个表的 3 列 这些表具有以下关系 表 1 的列 ID 与表 2 的列 ID2 相关 在表 1 中有一个名为 user 的列 在表 2 中有一个名为名称的列 可以有 1 个唯一用户 但可以有多个与该用户关联
  • ngram 的 dict 函数

    我有这样的文字 library dplyr glimpse text chr 1 11 Welcome to Wikipedia bla Discover Ekopedia the practical encyclopedia about
  • MKPolylineRenderer 上的两条彩色/自定义线

    我正在画 x 条线MKMapView 数据是从网络服务下载的 对于每条需要绘制的线 我正在创建一个MKPolyline 将其添加到overlayArray 某些叠加层有多个折线 最后通过以下方式将该叠加层添加到地图中 self mapVie
  • .Value = "" 和 .ClearContents 之间有什么区别?

    如果我运行以下代码 Sub Test 1 Cells 1 1 ClearContents Cells 2 1 Value End Sub 当我使用公式检查 Cells 1 1 和 Cells 2 1 时ISBLANK 两个结果都返回TRUE
  • 指定的解决方案配置“最新|任何 CPU”无效

    我在 Jenkins 中运行的 MSBuild 脚本遇到此错误 C
  • 将 HTML 导出到 Excel 的单元格背景颜色

    我正在尝试将 HTML 表导出到 Excel 但我无法通过 CSS 设置单元格的背景颜色 我尝试过 Response Write 但这对我的输出没有影响 单元格颜色是否有一些 mso css 属性 尝试使用类似的东西
  • 在 Android 中编写许多 HTTP 请求的良好设计模式

    在我的应用程序中 我有很多 GET POST PUT 请求 现在 我有一个单例类来保存我下载的数据 并且有许多扩展 AsyncTask 的内部类 在我的单例类中 我还有一些这样的接口 Handlers for notifying liste
  • IMagick 检查图像亮度

    我需要能够在图像内自动写入一些文本 根据图像亮度 脚本必须用白色或黑色书写 那么如何使用 Imagick 检查图像的亮度 暗度呢 你可以这样做 Load the image imagick new Imagick image jpg con
  • Django/Python:了解 super 在函数中的使用方式

    我刚刚开始思考什么super以及它是如何在 Django 中基于视图的类中实现的 我试图了解 super 在以下代码中是如何工作的 有人可以尝试为我一点一点地分解它吗 from django views generic detail imp
  • Bash 忽略特定命令的错误

    我正在使用以下选项 set o pipefail set e 在 bash 脚本中 出现错误时停止执行 我有大约 100 行脚本正在执行 我不想检查脚本中每一行的返回代码 但对于一个特定的命令 我想忽略该错误 我怎样才能做到这一点 解决方案
  • 从 SearchView 更改片段提交又名级联到后台堆栈

    我目前正在使用SearchView对象 以便为我的应用程序提供建议输入的功能 然而 这个小部件在提交时使用intent filter开始您的搜索 当我的应用程序在手机上运行时 这非常棒 因为我想做的是启动搜索结果Activity显示响应 H
  • Devise - 如果帐户未确认,则重定向到页面

    如果用户的帐户尚未得到确认 我会尝试重定向用户 所以这涉及到两部分代码 首次创建帐户后重定向用户 如果他们在确认帐户之前尝试登录 请重定向他们 我需要第二个方面的帮助 我首先能够通过放入after inactive sign up path
  • Maven 编译器插件

    我知道默认的 Maven 编译器插件绑定到 compile 测试编译 生命周期 一般在不指定附加配置的情况下 我们不必 在我们的 POM 中明确定义它 但我仍然看到经验丰富的开发人员将诸如 这在他们的 POM 中 例如
  • 如何禁用主干历史记录但仍允许基于哈希的路由?

    假设我执行以下操作 单击主页 上的链接并转到 posts 1 触发事件并前往主干路由 posts 1 1 edit 我点击返回 我需要这样做 以便用户最终回到主页 而不是回到 posts 1 所以我需要允许骨干哈希路由工作但不修改历史记录
  • gcc 抑制警告“太小,无法容纳所有值”

    我需要使用范围枚举 以便我可以将它们作为特定类型传递给我们的序列化程序 我已经为枚举成员给出了明确的整数值Enum1 我已将与上面的描述相匹配的两个作用域枚举放入位字段中 enum class Enum1 value1 0x0 value2
  • Recyclerview 按字母顺序滚动条

    我需要实现一个类似于三星音乐应用程序的recyclerview字母滚动条 由于信誉低 我无法发布图像 我已阅读有关此的所有帖子 但我不想要气泡卷轴 我将所有字母表都放在垂直 LinearLayout 中 我想知道如何滚动到特定项目 你可以用
  • 如何在 JavaScript 中使用 x,y 坐标模拟点击?

    是否可以使用给定的坐标来模拟网页中 JavaScript 的点击 您可以派遣一个click事件 尽管这与真正的点击不同 例如 它不能用于欺骗跨域 iframe 文档 使其认为它已被单击 所有现代浏览器都支持document elementF
  • 在Python中按索引从列表中删除元素的简洁方法

    我有一个字符列表和索引列表 myList a b c d toRemove 0 2 我想通过一次操作得到这个 myList b d 我可以做到这一点 但有没有办法做得更快 toRemove reverse for i in toRemove
  • Java FileWriter 和 BufferedWriter 的区别

    它们之间有什么区别 我刚刚学习 Java ATM 但似乎我可以两种方式写入文件 我没有在这里复制 try catch 块 FileWriter file new FileWriter foo txt file write foobar fi
  • 张量流:在多个检查点运行模型评估

    在我当前的项目中 我训练一个模型并每 100 个迭代步骤保存检查点 检查点文件全部保存到同一目录 model ckpt 100 model ckpt 200 model ckpt 300 等 之后 我想根据所有已保存检查点 而不仅仅是最新检