TFF 加载预训练的 Keras 模型

2023-12-04

我的目标是从 .hdf5 文件加载基本模型(它是 Keras 模型),并继续通过联合学习对其进行训练。以下是我初始化 FL 基本模型的方法:

def model_fn():
    model = tf.keras.load_model(path/to/model.hdf5)
    return tff.learning.from_keras_model(model=model, 
                                         dummy_batch=db, 
                                         loss=loss, 
                                         metrics=metrics)

trainer = tff.learning.build_federated_averaging_process(model_fn)
state = trainer.initialize()

然而,似乎生成的 state.model 权重是随机初始化的,并且与我保存的模型不同。当我在任何联合训练之前评估模型的性能时,它的性能就像随机初始化的模型一样:准确度为 50%。以下是我评估性能的方式:

def evaluate(state):
    keras_model = tf.keras.models.load_model(path/to/model.hdf5, compile=False)
    tff.learning.assign_weights_to_keras_model(keras_model, state.model)
    keras_model.compile(loss=loss, metrics=metrics)
    return keras_model.evaluate(features, values)

如何使用保存的模型权重初始化 tff 模型?


是的,我认为预计initialize将重新运行初始化程序并返回该值。

然而,有一种方法可以用 TFF 来做这样的事情。 TFF 是强类型和函数式的 - 如果我们可以使用正确的值构造一个参数,该参数与上面的联合平均过程所期望的类型相匹配,那么事情应该“正常工作”。所以这里的目标是构造满足这些要求的论证。

您可以查看FileCheckpointManager's加载实施在这里寻求一点灵感,但我认为 Keras 的情况更简单。

假设你有手state就像上面和model您的 Keras 模型,这里有一个解包和重新打包所有内容的快捷方式 - 如中所示本节TFF 的教程之一——即使用tff.learning.state_with_new_model_weights。如果您具有上述状态和模型(并且 TF 处于 eager 模式),则以下内容应该适合您:

state = tff.learning.state_with_new_model_weights(
    state,
    trainable_weights=[v.numpy() for v in model.trainable_weights],
    non_trainable_weights=[
        v.numpy() for v in model.non_trainable_weights
    ])

这应该将模型的权重重新分配给模型的适当元素state object.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TFF 加载预训练的 Keras 模型 的相关文章

随机推荐

  • 使用下拉菜单动态过滤 WordPress 帖子(使用 php 和 ajax)

    目标 我想制作一个动态页面 允许访问者从下拉菜单中选择月份和年份 并根据所选值更改页面上的内容 帖子 我目前正在使用以下代码来显示特定月份和年份的特定类别的帖子 ul li li ul
  • 检查 UTD 参数在存储过程中是否有值

    我想检查具有某些用户定义表类型的参数是否具有值或为 NULL 但我收到以下错误 Msg 137 Level 16 State 1 Procedure SearchByWord Line 63 Batch Start Line 7 Must
  • 从 Gradle、Spring 和 DB2 开始的挑战

    我对该项目的想法是使用 Gradle 编写一个简单的 Spring 纯 java 应用程序 该应用程序将连接到 DB2 数据库并提取一些数据并在控制台上打印 首先 我使用 Eclipse Luna 创建了一个 Gradle 项目 我的挑战
  • C++ 中的重载 +=

    如果我已经重载了operator 和operator 我还需要重载吗 运算符 这样的东西可以工作 MyClass mc1 mc2 mc1 mc2 是的 您也需要定义它 然而 一个常见的技巧是定义operator 然后实施operator 就
  • 异常消息:当前应用程序配置不支持 WebSockets

    最近我升级到 Windows 2012 Standard 服务器 64 位 这样我就可以使用 Web 套接字 我已通过服务器管理器添加了角色 功能 我注意到我的 Windows 日志中有这样的消息 Event code 3005 Event
  • 如何使用 RGB 像素值绘制直方图?

    我正在netbeans平台上制作应用程序 我想画直方图 我有红色 绿色和蓝色的图像像素 那么 请有人向我询问如何使用该像素值绘制直方图 我的代码如下 其中我采用图像的红色 绿色和蓝色像素值 enter code here import ja
  • 未定义的局部变量或方法 - 使用烧杯测试 Puppet 模块

    我对这一切都很陌生 我正在尝试使用烧杯测试木偶模块 我不断得到这个 NoMethodError undefined method describe for Beaker TestCase 0x007fd6f95e6460 Users use
  • 如何在 Java 中按值(ArrayList)大小对 Map 进行排序?

    我有以下地图 Map
  • 为 Google.Apis.YouTube.v3 设置代理

    我有以下代码来调用 YouTubeService service new YouTubeService new BaseClientService Initializer ApiKey AppSettings Variables YouTu
  • docker内部和外部用户之间的混淆

    所以 我正在内部使用 apache2 构建一个 docker 容器 但我遇到权限问题 我不知道如何解决它 如果我运行没有 user 规范的容器 它运行良好 但我想外在地能够将其分配给用户并限制该用户只能读取和写入特定目录 我使用 v 映射的
  • 所有 OpenMP 任务在同一线程上运行

    我使用 OpenMP 中的任务编写了一个递归并行函数 虽然它给了我正确的答案并且运行良好 但我认为并行性存在问题 与串行解决方案相比 运行时间在我在没有任务的情况下解决的相同其他并行问题中无法扩展 当打印任务的每个线程时 它们都在线程 0
  • 打印给定 pid 的子进程 (MINIX)

    我目前正在开发一个项目 作为该项目的一部分 我需要在 MINIX 中实现系统调用 库函数 作为其中的一部分 我需要能够使用给定进程的 pid 打印其子进程列表 我想我已经找到了我需要的部分内容 但我坚持让它与给定的 pid 一起工作 str
  • Python 的 SSH 隧道自动关闭

    我需要一些关于我的程序结构的建议 我正在使用连接到外部 MySQL 数据库ssh隧道 现在它可以正常工作 我可以发出 SQL 命令并获取结果 但前提是这些命令与打开连接的函数相同 如果它们处于不同的功能 隧道会在我使用之前自动关闭 参见下面
  • GemBox 从电子表格或 Flexcel 检索计算值

    根据他们的文档 GemBox Spreadsheet可以读取和写入公式 但不能计算公式结果 当您在 MS Excel 中打开 XLS 文件时 将自动计算公式结果 因此 如果我创建一个包含一些注入值的电子表格并将其保存到磁盘 如果我在 Exc
  • 如何获取准确的拨出电话接听时间?

    我是安卓新手 我正在实现一个与来电和去电详细信息相关的应用程序 我通过使用广播接收器获取拨出电话和来电详细信息 问题是当有来电时广播接收器会上升 我拨打广播接收器拨打的电话 很好 但是当我单击绿色按钮时 拨出电话就会开始 但是 我想要接听对
  • 深拷贝和浅拷贝有什么区别?

    这个问题的答案是社区努力 编辑现有答案以改进这篇文章 目前不接受新的答案或互动 深拷贝和浅拷贝有什么区别 广度与深度 考虑以对象作为根节点的引用树 Shallow 变量 A 和 B 引用不同的内存区域 当将 B 分配给 A 时 这两个变量引
  • SQL Server 2008 - 高级搜索/排序

    我需要对列进行搜索并按特定顺序对结果进行排序 搜索条件和排序顺序如下 给定搜索文本的至少 x 和至多所有字符必须匹配 结果应按开头 然后按匹配的字符数排序 和包含 然后按匹配的字符数和字母顺序排序 进行分组 例如 搜索文本 联盟A 数据库中
  • Java 在 JTextPane 上设置缩进大小

    我想设置制表符 t 的大小JTextPane宽度为 4 个空格 经过一番谷歌搜索后 我发现了一些东西 我将在这里包含我所尝试过的东西以及它们失败的原因 如何在 JEditorPane 中设置选项卡大小 JTextPane不是一个普通的文档
  • Windows Phone 7 和 System.Xml.Linq 库

    我正在尝试遵循有关 WP7 开发的教程 http mobile tutsplus com tutorials windows introduction to windows mobile 7 development 它谈论的是使用 XEle
  • TFF 加载预训练的 Keras 模型

    我的目标是从 hdf5 文件加载基本模型 它是 Keras 模型 并继续通过联合学习对其进行训练 以下是我初始化 FL 基本模型的方法 def model fn model tf keras load model path to model