张量流联合训练和评估期间的 MSE 误差不同

2024-04-06

我正在联合张量流中实现回归模型。我从本教程中使用的 keras 简单模型开始:https://www.tensorflow.org/tutorials/keras/regression https://www.tensorflow.org/tutorials/keras/regression

我更改了模型以使用联邦学习。这是我的模型:

import pandas as pd
import tensorflow as tf

from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_federated as tff

dataset_path = keras.utils.get_file("auto-mpg.data", "http://archive.ics.uci.edu/ml/machine-learning-databases/auto-mpg/auto-mpg.data")

column_names = ['MPG','Cylinders','Displacement','Horsepower','Weight',
                'Acceleration', 'Model Year', 'Origin']
raw_dataset = pd.read_csv(dataset_path, names=column_names,
                      na_values = "?", comment='\t',
                      sep=" ", skipinitialspace=True)

df = raw_dataset.copy()
df = df.dropna()
dfs = [x for _, x in df.groupby('Origin')]


datasets = []
targets = []
for dataframe in dfs:
    target = dataframe.pop('MPG')

    from sklearn.preprocessing import StandardScaler
    standard_scaler_x = StandardScaler(with_mean=True, with_std=True)
    normalized_values = standard_scaler_x.fit_transform(dataframe.values)

    dataset = tf.data.Dataset.from_tensor_slices(({ 'x': normalized_values, 'y': target.values}))
    train_dataset = dataset.shuffle(len(dataframe)).repeat(10).batch(20)
    test_dataset = dataset.shuffle(len(dataframe)).batch(1)
    datasets.append(train_dataset)


def build_model():
  model = keras.Sequential([
    layers.Dense(64, activation='relu', input_shape=[7]),
    layers.Dense(64, activation='relu'),
    layers.Dense(1)
  ])
  return model
dataset_path


import collections


model = build_model()

sample_batch = tf.nest.map_structure(
    lambda x: x.numpy(), iter(datasets[0]).next())

def loss_fn_Federated(y_true, y_pred):
    return tf.reduce_mean(tf.keras.losses.MSE(y_true, y_pred))

def create_tff_model():
  keras_model_clone = tf.keras.models.clone_model(model)
#   adam = keras.optimizers.Adam()
  adam = tf.keras.optimizers.SGD(0.002)
  keras_model_clone.compile(optimizer=adam, loss='mse', metrics=[tf.keras.metrics.MeanSquaredError()])
  return tff.learning.from_compiled_keras_model(keras_model_clone, sample_batch)

print("Create averaging process")
# This command builds all the TensorFlow graphs and serializes them: 
iterative_process = tff.learning.build_federated_averaging_process(model_fn=create_tff_model)

print("Initzialize averaging process")
state = iterative_process.initialize()

print("Start iterations")
for _ in range(10):
  state, metrics = iterative_process.next(state, datasets)
  print('metrics={}'.format(metrics))
Start iterations
metrics=<mean_squared_error=95.8644027709961,loss=96.28633880615234>
metrics=<mean_squared_error=9.511247634887695,loss=9.522096633911133>
metrics=<mean_squared_error=8.26853084564209,loss=8.277074813842773>
metrics=<mean_squared_error=7.975323677062988,loss=7.9771647453308105>
metrics=<mean_squared_error=7.618809700012207,loss=7.644164562225342>
metrics=<mean_squared_error=7.347906112670898,loss=7.340310096740723>
metrics=<mean_squared_error=7.210267543792725,loss=7.210223197937012>
metrics=<mean_squared_error=7.045553207397461,loss=7.045469760894775>
metrics=<mean_squared_error=6.861278533935547,loss=6.878870487213135>
metrics=<mean_squared_error=6.80275297164917,loss=6.817670822143555>
evaluation = tff.learning.build_federated_evaluation(model_fn=create_tff_model)


test_metrics = evaluation(state.model, datasets)
print(test_metrics)
<mean_squared_error=27.308320999145508,loss=27.19877052307129>

我很困惑为什么 10 次迭代后评估的 mse 对于训练集来说更高,而迭代过程返回的 mse 却小得多。我在这里做错了什么?在tensorflow中fml的实现中是否隐藏了一些东西?有人可以向我解释一下吗?


你实际上已经击中了非常有趣的现象在联邦学习中。特别是,这里需要问的问题是:训练指标是如何计算的?

通常计算训练指标在当地培训期间;因此,它们是在客户端拟合其本地数据时计算的;在 TFF 中,它们是在执行每个局部步骤之前计算的——这种情况发生here https://github.com/tensorflow/federated/blob/master/tensorflow_federated/python/learning/keras_utils.py#L465在前向传递呼叫期间。如果您想象极端情况,其中指标仅在end对每个客户进行一轮培训后,您会清楚地看到一件事——客户正在报告代表的指标它与他的本地数据的吻合程度如何.

然而,联邦学习必须在每轮训练结束时生成一个单一的全局模型——在联邦平均中,这些局部模型是在参数空间中一起平均。在一般情况下,尚不清楚如何直观地解释这样的步骤 - 参数空间中非线性模型的平均值不会为您提供平均预测或类似的结果。

联合评估采用这个平均模型,并对每个客户端运行本地评估,根本不拟合本地数据。因此,如果您的客户端数据集具有截然不同的分布,则您应该预期从联合评估返回的指标与从一轮联合训练返回的指标有很大不同 - 联合平均正在报告收集的指标适应本地数据的过程中,而联合评估正在报告收集的指标将所有这些本地训练的模型平均后.

事实上,如果您交错调用next迭代过程和评估函数的函数,您将看到如下模式:

train metrics=<mean_squared_error=88.22489929199219,loss=88.6319351196289>
eval metrics=<mean_squared_error=33.69473648071289,loss=33.55160140991211>
train metrics=<mean_squared_error=8.873666763305664,loss=8.882776260375977>
eval metrics=<mean_squared_error=29.235883712768555,loss=29.13833236694336>
train metrics=<mean_squared_error=7.932246208190918,loss=7.918393611907959>
eval metrics=<mean_squared_error=27.9038028717041,loss=27.866817474365234>
train metrics=<mean_squared_error=7.573018550872803,loss=7.576478958129883>
eval metrics=<mean_squared_error=27.600923538208008,loss=27.561887741088867>
train metrics=<mean_squared_error=7.228050708770752,loss=7.224897861480713>
eval metrics=<mean_squared_error=27.46322250366211,loss=27.36537742614746>
train metrics=<mean_squared_error=7.049572944641113,loss=7.03688907623291>
eval metrics=<mean_squared_error=26.755760192871094,loss=26.719152450561523>
train metrics=<mean_squared_error=6.983217716217041,loss=6.954374313354492>
eval metrics=<mean_squared_error=26.756895065307617,loss=26.647253036499023>
train metrics=<mean_squared_error=6.909178256988525,loss=6.923810005187988>
eval metrics=<mean_squared_error=27.047882080078125,loss=26.86684799194336>
train metrics=<mean_squared_error=6.8190460205078125,loss=6.79202938079834>
eval metrics=<mean_squared_error=26.209386825561523,loss=26.10053062438965>
train metrics=<mean_squared_error=6.7200140953063965,loss=6.737307071685791>
eval metrics=<mean_squared_error=26.682661056518555,loss=26.64984703063965>

也就是说,您的联合评估也在下降,只是比您的训练指标慢得多 - 有效地测量客户数据集的变化。您可以通过运行来验证这一点:

eval_metrics = evaluation(state.model, [datasets[0]])
print('eval metrics on 0th dataset={}'.format(eval_metrics))
eval_metrics = evaluation(state.model, [datasets[1]])
print('eval metrics on 1st dataset={}'.format(eval_metrics))
eval_metrics = evaluation(state.model, [datasets[2]])
print('eval metrics on 2nd dataset={}'.format(eval_metrics))

你会看到类似的结果

eval metrics on 0th dataset=<mean_squared_error=9.426984786987305,loss=9.431192398071289>
eval metrics on 1st dataset=<mean_squared_error=34.96992111206055,loss=34.96992492675781>
eval metrics on 2nd dataset=<mean_squared_error=72.94075775146484,loss=72.88787841796875>

因此您可以看到,您的平均模型在这三个数据集上的性能显着不同。

最后一点:您可能会注意到,您的最终结果evaluate函数是not你的三次损失的平均值——这是因为evaluate函数将是example- 加权,不client-加权——也就是说,拥有更多数据的客户在平均中获得更大的权重。

希望这可以帮助!

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

张量流联合训练和评估期间的 MSE 误差不同 的相关文章

  • Python Selenium:如何在文本文件中打印网站上的值?

    我正在尝试编写一个脚本 该脚本将从 tulsaspca org 网站获取以下 6 个值并将其打印在 txt 文件中 最终输出应该是 905 4896 7105 23194 1004 42000 放置的动物 的 HTML span class
  • 如何防止用户控件表单在 C# 中处理键盘输入(箭头键)

    我的用户控件包含其他可以选择的控件 我想实现使用箭头键导航子控件的方法 问题是家长控制拦截箭头键并使用它来滚动其视图什么是我想避免的事情 我想自己解决控制内容的导航问题 我如何控制由箭头键引起的标准行为 提前致谢 MTH 这通常是通过重写
  • 如何在发布期间复制未版本化的测试资源:执行?

    我的问题与 Maven 在发布时不会复制未跟踪的资源 https stackoverflow com questions 10378708 maven doesnt copy untracked resources while releas
  • CFdump cfcomponent cfscript

    可以在 cfcomponent 中使用 cfdump 吗 可以在 cfscript 中使用 cfdump 吗 我知道 anser 不是 那么如何发出 insde cfcomponent 函数的值 cf脚本 我用的是CF8 可以在 cfcom
  • 如何确定所有角度2分量都已渲染?

    当所有 Angular2 组件完成渲染时 是否会触发一个角度事件 For jQuery 我们可以用 function 然而 对于 Angular2 当domready事件被触发 html 只包含角度组件标签 每个组件完成渲染后 domrea
  • TIFF 元数据的最大大小是多少?

    TIFF 文件元数据的单个字段中可以合并的元数据数量是否有最大限制 我想在 ImageDescription 字段中存储大文本 最多几 MB 没有具体的最大限制ImageDescription但是 整个 TIFF 文件存在最大文件大小 该最
  • 如何在执行新操作时取消先前操作的执行?

    我有一个动作创建器 它会进行昂贵的计算 并在每次用户输入内容时调度一个动作 基本上是实时更新 但是 如果用户输入多个内容 我不希望之前昂贵的计算完全运行 理想情况下 我希望能够取消执行先前的计算并只执行当前的计算 没有内置功能可以取消Pro
  • 使用.NET技术录制屏幕视频[关闭]

    Closed 这个问题正在寻求书籍 工具 软件库等的推荐 不满足堆栈溢出指南 help closed questions 目前不接受答案 有没有一种方法可以使用 NET 技术来录制屏幕 无论是桌面还是窗口 我的目标是免费的 我喜欢小型 低
  • 如何从日期中查找该月的最后一天?

    如何在 PHP 中获取该月的最后一天 Given a date 2009 11 23 我要2009 11 30 并给出 a date 2009 12 23 我要2009年12月31日 t返回给定日期所在月份的天数 请参阅的文档date ht
  • 如何使用asm.js进行测试和开发?

    最近我读到asm js规范 看起来很酷 但是是否有任何环境 工具来开发和测试这个工具 这还只是处于规范阶段吗 您可以尝试使用 emscripten 和 ASM JS 1 并从侧分支在 firefox 构建中运行它 有关 asm js 的链接
  • 从超立方体图像中获取文本的确切位置

    使用 tesseract 中的 GetHOCRText 0 方法 我能够检索 html 中的文本 并在 webview 中呈现 html 时 我能够获取文本 但图像中文本的位置与输出不同 任何想法都非常有帮助 tesseract gt Se
  • 循环内的异步性

    我正在使用 jQuery getJSON 用于从一组实用程序的给定 URL 检索数据的 API 我真的很想找到一种为每个实用程序重用代码 完全相同 的方法 由于循环的执行与 ajax 调用无关 因此我无法找到保留循环值的方法 我知道这个描述
  • 用于验证目的的动态查找方法

    我正在使用 Ruby on Rails 3 0 7 我想在运行时查找一些记录以进行验证 但为该查找方法传递 设置一个值 也就是说 在我的班级中 我有以下内容 class Group lt lt ActiveRecord Base valid
  • neo4j - python 驱动程序,服务不可用

    我对 neo4j 非常陌生 我正在尝试建立从 python3 6 到 neo4j 的连接 我已经安装了驱动程序 并且刚刚开始执行第一步 导入请求 导入操作系统 导入时间 导入urllib 从 neo4j v1 导入 GraphDatabas
  • 使用 xpath 和 vtd-xml 以字符串形式获取元素的子节点和文本

    这是我的 XML 的一部分
  • 如何使用 Pycharm 安装 tkinter? [复制]

    这个问题在这里已经有答案了 I used sudo apt get install python3 6 tk而且效果很好 如果我在终端中打开 python Tkinter 就可以工作 但我无法将其安装在我的 Pycharm 项目上 pip
  • 如何将输入读取为数字?

    这个问题的答案是社区努力 help privileges edit community wiki 编辑现有答案以改进这篇文章 目前不接受新的答案或互动 Why are x and y下面的代码中使用字符串而不是整数 注意 在Python 2
  • Erlang dict的时间复杂度

    我想知道 Erlang OTP 是否dict模块是作为哈希表实现的 在这种情况下它是否能提供这样的性能 平均情况 Search O 1 n k Insert O 1 Delete O 1 n k 最坏的情况下 Search O n Inse
  • 在 Nexus 7 2013 上更改方向时 CSS 媒体查询不起作用

    我目前正在我的笔记本电脑 台式电脑和 Nexus 7 2013 上测试 CSS 媒体查询 除了 Nexus 7 之外 它们在台式机和笔记本电脑上都运行良好 当我更改方向时 除非刷新页面 否则样式不会应用 例如 以纵向模式握住设备时 页面正常
  • 强制 Listview 不重复使用视图(复选框)

    我做了一个定制Listview 没有覆盖getView 方法 Listview 中的每个项目都具有以下布局 联系布局 xml

随机推荐

  • 自旋锁与信号量

    信号量和自旋锁之间的基本区别是什么 我们什么时候会使用信号量而不是自旋锁 自旋锁和信号量主要有四个不同点 1 它们是什么 A spinlock是锁的一种可能实现 即通过忙等待 旋转 实现的锁 信号量是锁的概括 或者 相反 锁是信号量的特例
  • 带有节标题的列表视图android

    在 android listview gt Headerbar section 中是否有可能不滚动 直到该部分的列表不滚动 就像 iPhone 的桌面视图一样 我使用了部分列表视图 但我想要像这个 iphone 表格视图 有没有可能 谢谢
  • Jenkins 在 ClearCase 中创建视图

    我正在使用 Jenkins 和 ClearCase 进行自动构建 但遇到了问题 我编写了一个批处理脚本 使用cleartool命令mkview在ClearCase中创建视图 当我通过单击脚本来执行该脚本时 一切正常 视图是在 ClearCa
  • 解析在tinyxml中

    如何在 TinyXML 中解析以下内容
  • netstandard 1.5 中的 BinaryFormatter

    根据 NET CoreFx API 及其关联的 NET 平台标准版本列表 https github com dotnet corefx blob master Documentation architecture net platform
  • 在 .Net 中使用私有集初始化属性

    public class Foo public string Name get private set lt Because set is private void Main var bar new Foo Name baz lt This
  • 二维数组作为函数的参数

    为什么不能像处理普通数组一样在函数中声明二维数组参数 void F int bar Ok void Fo int bar Not ok void Foo int bar SIZE Ok 为什么需要声明列的大小 静态数组 你似乎没有完全明白这
  • 如何在yii2高级模板中上传web文件夹中的文件?

    我尝试在后端上传文件 每次上传文件时 它都会成功上传并成功保存在数据库中 但它没有保存到我指定的目录中 因此我的应用程序找不到该文件 并且我已经给出了 777对 web 目录中的 uploads 文件夹的权限 下面是我的代码 处理和保存文件
  • 如何使用 Compact Framework 在 C# 中验证 X.509 证书

    我正在尝试使用 C 和 NetCF 验证 X 509 证书 我有 CA 证书 如果我理解正确的话 我需要使用该 CA 证书中的公钥来解密不受信任的证书的签名 这应该给我不可信证书的计算哈希值 然后我应该自己计算证书的哈希值并确保两个值匹配
  • Swift组合:使用其他发布者(使用CombineLatest)的后续发布者不会“触发”

    我正在尝试复制 WWDC 2019 会议 实践中组合 中给出的 向导学校注册 示例https developer apple com videos play wwdc2019 721 https developer apple com vi
  • 属性的访问器实现

    是否有一些文档说明编译器如何自动生成属性的访问器 当编写自定义访问器 覆盖合成的访问器 时 最好了解原始实现 特别是要查看具有不同 弱 强 保留 复制等 属性的属性的访问器的不同实现 是否有一些文档说明编译器如何自动生成属性的访问器 编译器
  • 从 openstreetmap 获取城市边界

    我正在开发一个网站 我需要根据用户输入获取某个区域的所有边界 例如 用户想知道名为 x 的城市的边界 我应该如何从 openstreetmap 获取它 我听说过 xapi 和 osmosis 但在任何地方都找不到任何例子 谢谢 我在这里尝试
  • 使用media3库时添加MediaItem导致错误

    我正在使用最新的Android Media3库 但是我在使用它时发现了一个问题 我创建了一个媒体会话服务 然后得到MediaController中的Activity 然后当我尝试调用媒体控制器并添加一些 MediaItem 时 发生错误 j
  • Python/PyODBC 通过 IP 与可信连接连接到 SQL Server 2008 DB

    如果有人问这个问题 我提前道歉 尽管我发现了类似的问题 但我找不到正确的答案 我正在尝试通过使用可信连接的 IP 端口来连接到 SQL Server 2008 DB 另外一点复杂性是 数据库位于美国境外 通常我们通过 Citrix 登录 登
  • 告诉编译器泛型返回类型不借用任何对参数的引用?

    tldr gt 给定一个接受通用回调参数并返回关联类型的特征函数 编译器会抱怨关联类型可能从回调函数借用参数 有没有办法告诉编译器事实并非如此 细节 我计划实现一个接受回调参数的特征函数 并希望强制该特征函数的实现实际调用该回调 我通过让回
  • 保证文件关闭

    我有一个类 在构造函数中创建一个文件对象 该类还实现了 finish 方法作为其接口的一部分 在该方法中我关闭了文件对象 问题是 如果我在此之前遇到异常 文件将不会被关闭 相关类还有许多使用文件对象的其他方法 我需要将所有这些包装在一个最后
  • REST API 资源命名约定 - 用户或用户(复数)

    长版 对于某些人 包括我自己 来说 构建 REST API 过程中最痛苦 最令人头疼的部分之一是确定每个资源及其随附端点的名称 当然 这取决于个人喜好 有些事情是受到社区鼓励的 例如 大多数人 包括我 都会将他们的资源名称复数 GET no
  • 如何从日期时间获取时间跨度

    设想 第三方网络服务退货datetime在两个单独的字段中 即日期和时间 我需要一种连接成单个字段的方法 e g startDate 24 06 2012 startTime 1 01 1970 1 00 00 AM Expected re
  • 编辑距离矩阵

    我正在尝试构建一个程序 该程序接受两个字符串并为它们填充编辑距离矩阵 让我困惑的是 对于第二个字符串输入 它跳过了第二个输入 我尝试使用 getch 清除缓冲区 但没有成功 我也尝试过切换到 scanf 但这也导致了一些崩溃 请帮助 Cod
  • 张量流联合训练和评估期间的 MSE 误差不同

    我正在联合张量流中实现回归模型 我从本教程中使用的 keras 简单模型开始 https www tensorflow org tutorials keras regression https www tensorflow org tuto