联邦学习训练期间模型性能没有提高

2023-12-14

我已关注这个emnist教程创建图像分类实验(7 个类别),目的是使用 TFF 框架在 3 个数据孤岛上训练分类器。

在训练开始之前,我使用以下命令将模型转换为 tf keras 模型tff.learning.assign_weights_to_keras_model(model,state.model)评估我的验证集。无论标签如何,模型仅预测一类。这是可以预料的,因为尚未对模型进行训练。但是,我在每轮联合平均后重复此步骤,但问题仍然存在。所有验证图像都被预测为一类。我还在每轮之后保存 tf keras 模型权重,并对测试集进行预测 - 没有任何变化。

我为检查问题根源而采取的一些步骤:

  1. 检查每轮后转换 FL 模型时 tf keras 模型权重是否正在更新 - 它们正在更新。
  2. 确保缓冲区大小大于每个客户端的训练数据集大小。
  3. 将预测与训练数据集中的类别分布进行比较。存在类别不平衡,但模型预测的一个类别不一定是多数类别。而且,它并不总是同一类。大多数情况下,它仅预测 0 类。
  4. 将轮数增加到 5,每轮 epoch 增加到 10。这在计算上非常密集,因为它是一个相当大的模型,每个客户端需要大约 1500 个图像进行训练。
  5. 研究每次训练尝试的 TensorBoard 日志。随着回合的进行,训练损失正在减少。
  6. 尝试了一个更简单的模型 - 具有 2 个转换层的基本 CNN。这使我能够大大增加 epoch 和 rounds 的数量。在测试集上评估该模型时,它预测了 4 个不同的类别,但性能仍然很差。这表明我只需要增加原始模型的轮数和纪元数即可增加预测的变化。这是很困难的,因为这会导致大量的训练时间。

型号详情:

该模型使用 XceptionNet 作为基础模型,权重未冻结。当所有训练图像都汇集到全局数据集中时,这在分类任务中表现良好。我们的目标是希望达到与 FL 相当的性能。

base_model = Xception(include_top=False,
                      weights=weights,
                      pooling='max',
                      input_shape=input_shape)
x = GlobalAveragePooling2D()( x )
predictions = Dense( num_classes, activation='softmax' )( x )
model = Model( base_model.input, outputs=predictions )

这是我的训练代码:

def fit(self):
    """Train FL model"""
    # self.load_data()
    summary_writer = tf.summary.create_file_writer(
        self.logs_dir
    )
    federated_averaging = self._construct_iterative_process()
    state = federated_averaging.initialize()
    tfkeras_model = self._convert_to_tfkeras_model( state )
    print( np.argmax( tfkeras_model.predict( self.val_data ), axis=-1 ) )
    val_loss, val_acc = tfkeras_model.evaluate( self.val_data, steps=100 )

    with summary_writer.as_default():
        for round_num in tqdm( range( 1, self.num_rounds ), ascii=True, desc="FedAvg Rounds" ):

            print( "Beginning fed avg round..." )
            # Round of federated averaging
            state, metrics = federated_averaging.next(
                state,
                self.training_data
            )
            print( "Fed avg round complete" )
            # Saving logs
            for name, value in metrics._asdict().items():
                tf.summary.scalar(
                    name,
                    value,
                    step=round_num
                )
            print( "round {:2d}, metrics={}".format( round_num, metrics ) )
            tff.learning.assign_weights_to_keras_model(
                tfkeras_model,
                state.model
            )
            # tfkeras_model = self._convert_to_tfkeras_model(
            #     state
            # )
            val_metrics = {}
            val_metrics["val_loss"], val_metrics["val_acc"] = tfkeras_model.evaluate(
                self.val_data,
                steps=100
            )
            for name, metric in val_metrics.items():
                tf.summary.scalar(
                    name=name,
                    data=metric,
                    step=round_num
                )
            self._checkpoint_tfkeras_model(
                tfkeras_model,
                round_num,
                self.checkpoint_dir
            )
def _checkpoint_tfkeras_model(self,
                              model,
                              round_number,
                              checkpoint_dir):
    # Obtaining model dir path
    model_dir = os.path.join(
        checkpoint_dir,
        f'round_{round_number}',
    )
    # Creating directory
    pathlib.Path(
        model_dir
    ).mkdir(
        parents=True
    )
    model_path = os.path.join(
        model_dir,
        f'model_file_round{round_number}.h5'
    )
    # Saving model
    model.save(
        model_path
    )

def _convert_to_tfkeras_model(self, state):
    """Converts global TFF modle of TF keras model

    Takes the weights of the global model
    and pushes them back into a standard
    Keras model

    Args:
        state: The state of the FL server
            containing the model and
            optimization state

    Returns:
        (model); TF Keras model

    """
    model = self._load_tf_keras_model()
    model.compile(
        loss=self.loss,
        metrics=self.metrics
    )
    tff.learning.assign_weights_to_keras_model(
        model,
        state.model
    )
    return model

def _load_tf_keras_model(self):
    """Loads tf keras models

    Raises:
        KeyError: A model name was not defined
            correctly

    Returns:
        (model): TF keras model object

    """
    model = create_models(
        model_type=self.model_type,
        input_shape=[self.img_h, self.img_w, 3],
        freeze_base_weights=self.freeze_weights,
        num_classes=self.num_classes,
        compile_model=False
    )

    return model

def _define_model(self):
    """Model creation function"""
    model = self._load_tf_keras_model()

    tff_model = tff.learning.from_keras_model(
        model,
        dummy_batch=self.sample_batch,
        loss=self.loss,
        # Using self.metrics throws an error
        metrics=[tf.keras.metrics.CategoricalAccuracy()] )

    return tff_model

def _construct_iterative_process(self):
    """Constructing federated averaging process"""
    iterative_process = tff.learning.build_federated_averaging_process(
        self._define_model,
        client_optimizer_fn=lambda: tf.keras.optimizers.SGD( learning_rate=0.02 ),
        server_optimizer_fn=lambda: tf.keras.optimizers.SGD( learning_rate=1.0 ) )
    return iterative_process

  1. 回合数增加到5...

只运行几个rounds联邦学习听起来还不够。最早的联邦平均论文之一(麦克马汉 2016当 MNIST 数据具有非独立同分布分裂时,需要运行数百轮。最近(雷迪2020)需要数千个rounds对于 CIFAR-100。需要注意的一点是,每一“轮”都是全局模型的一个“步骤”。随着更多客户端时代的到来,该步长可能会更大,但这些是平均的,不同的客户端可能会减少全局步长的幅度。

我还在每轮之后保存 tf keras 模型权重,并对测试集进行预测 - 没有任何变化。

这可能令人担忧。如果您可以共享 FL 训练循环中使用的代码,调试会更容易。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

联邦学习训练期间模型性能没有提高 的相关文章

  • Python 中的 Lanczos 插值与 2D 图像

    我尝试重新缩放 2D 图像 灰度 图像大小为 256x256 所需输出为 224x224 像素值范围从 0 到 1300 我尝试了两种使用 Lanczos 插值来重新调整它们的方法 首先使用PIL图像 import numpy as np
  • Python 的键盘中断不会中止 Rust 函数 (PyO3)

    我有一个使用 PyO3 用 Rust 编写的 Python 库 它涉及一些昂贵的计算 单个函数调用最多需要 10 分钟 从 Python 调用时如何中止执行 Ctrl C 好像只有执行结束后才会处理 所以本质上没什么用 最小可重现示例 Ca
  • Django 管理员在模型编辑时间歇性返回 404

    我们使用 Django Admin 来维护导出到我们的一些站点的一些数据 有时 当单击标准更改列表视图来获取模型编辑表单而不是路由到正确的页面时 我们会得到 Django 404 页面 模板 它是偶尔发生的 我们可以通过重新加载三次来重现它
  • SQLAlchemy 通过关联对象声明式多对多自连接

    我有一个用户表和一个朋友表 它将用户映射到其他用户 因为每个用户可以有很多朋友 这个关系显然是对称的 如果用户A是用户B的朋友 那么用户B也是用户A的朋友 我只存储这个关系一次 除了两个用户 ID 之外 Friends 表还有其他字段 因此
  • 通过最小元素比较对 5 个元素进行排序

    我必须在 python 中使用元素之间的最小比较次数来建模对 5 个元素的列表进行排序的执行计划 除此之外 复杂性是无关紧要的 结果是一个对的列表 表示在另一时间对列表进行排序所需的比较 我知道有一种算法可以通过 7 次比较 总是在元素之间
  • Django:按钮链接

    我是一名 Django 新手用户 尝试创建一个按钮 单击该按钮会链接到我网站中的另一个页面 我尝试了一些不同的例子 但似乎没有一个对我有用 举个例子 为什么这不起作用
  • 如何使用Conda下载python包并随后离线安装?

    我知道通过 pip 我可以使用以下命令下载 Python 包 但 pip install 破坏了我的内部包依赖关系 当我做 pip download
  • 根据列值突出显示数据框中的行?

    假设我有这样的数据框 col1 col2 col3 col4 0 A A 1 pass 2 1 A A 2 pass 4 2 A A 1 fail 4 3 A A 1 fail 5 4 A A 1 pass 3 5 A A 2 fail 2
  • 如何从网页中嵌入的 Tableau 图表中抓取工具提示值

    我试图弄清楚是否有一种方法以及如何使用 python 从网页中的 Tableau 嵌入图形中抓取工具提示值 以下是当用户将鼠标悬停在条形上时带有工具提示的图表示例 我从要从中抓取的原始网页中获取了此网址 https covid19 colo
  • ubuntu 20.04 上无法获取卷积算法错误~tensorflow-gpu

    我有一个 NVIDIA 2070 RTX GPU 我的操作系统是 Ubuntu20 04 我已经使用 conda 安装了tensorflow gpu 包 我有not安装了 CUDA toolkit 我相信它还会安装 CUDA toolkit
  • SQLALchemy .query:类“Car”的未解析属性引用“query”

    我有一个这里已经提到的问题https youtrack jetbrains com issue PY 44557 https youtrack jetbrains com issue PY 44557 但我还没有找到解决方案 我使用 Pyt
  • 以编程方式停止Python脚本的执行? [复制]

    这个问题在这里已经有答案了 是否可以使用命令在任意行停止执行 python 脚本 Like some code quit quit at this point some more code that s not executed sys e
  • Python pickle:腌制对象不等于源对象

    我认为这是预期的行为 但想检查一下 也许找出原因 因为我所做的研究结果是空白 我有一个函数可以提取数据 创建自定义类的新实例 然后将其附加到列表中 该类仅包含变量 然后 我使用协议 2 作为二进制文件将该列表腌制到文件中 稍后我重新运行脚本
  • 如何加速Python中的N维区间树?

    考虑以下问题 给定一组n间隔和一组m浮点数 对于每个浮点数 确定包含该浮点数的区间子集 这个问题已经通过构建一个解决区间树 https en wikipedia org wiki Interval tree 或称为范围树或线段树 已经针对一
  • 如何使用 OpencV 从 Firebase 读取图像?

    有没有使用 OpenCV 从 Firebase 读取图像的想法 或者我必须先下载图片 然后从本地文件夹执行 cv imread 功能 有什么办法我可以使用cv imread link of picture from firebase 您可以
  • Python 的“zip”内置函数的 Ruby 等价物是什么?

    Ruby 是否有与 Python 内置函数等效的东西zip功能 如果不是 做同样事情的简洁方法是什么 一些背景信息 当我试图找到一种干净的方法来进行涉及两个数组的检查时 出现了这个问题 如果我有zip 我可以写这样的东西 zip a b a
  • Pygame:有没有简单的方法可以找到按下的任何字母数字的字母/数字?

    我目前正在开发的游戏需要让人们以自己的名义在高分板上计时 我对如何处理按键有点熟悉 但我只处理过寻找特定的按键 有没有一种简单的方法可以按下任意键的字母 而不必执行以下操作 for event in pygame event get if
  • 使用 Python 绘制 2D 核密度估计

    I would like to plot a 2D kernel density estimation I find the seaborn package very useful here However after searching
  • 使用其构造函数初始化 OrderedDict 以便保留初始数据的顺序的正确方法?

    初始化有序字典 OD 以使其保留初始数据的顺序的正确方法是什么 from collections import OrderedDict Obviously wrong because regular dict loses order d O
  • Statsmodels.formula.api OLS不显示截距的统计值

    我正在运行以下源代码 import statsmodels formula api as sm Add one column of ones for the intercept term X np append arr np ones 50

随机推荐

  • JavaScript 中的继承

    当我使用原型在 Javascript 中实现继承时 我遇到了一个奇怪的错误 我想知道是否有人可以解释这一点 在下面的代码中 我正在尝试从父类派生子类 parent class function byref if parent class p
  • 为什么使用ByRef时变量应该被赋值为“.Value”?

    有什么区别 A Something and A Value Something 我发现这仅在以下情况下才有效 Value用来 function main A Original A B Original B SetByRef1 ref A S
  • 在 Joomla 中添加特定于页面的 javascript 或 CSS

    如何仅在 Joomla 的某篇文章中包含 javascript 或 CSS 文件 我有一篇文章需要 jQuery UI 和相关主题 由于它没有在任何其他页面上使用 因此我只需要在这篇特定的文章中使用它 添加必要的
  • 为什么我的应用程序显示我正在请求通讯录权限?

    我有一个表盘应用程序 显示我正在请求联系人权限 但我没有 我不明白这是为什么 我有应用程序内结算功能 并且可以访问 Google Fit 数据 以及 Google Analytics 以下是我的清单中的权限列表
  • 如何基于现有文件数据库创建具有架构的内存数据库

    我有一个现有的数据库 其结构在整个应用程序中使用 数据库的实例会定期轮换 我有一个数据库文件template sqlite它用作所有新创建的数据库的模板 我想使用它 而不是创建脚本 这样我只需维护一个文件 即空数据库模板本身 我想基于该模板
  • 什么标准调用实际上是宏

    我问了一个问题here about assert它在标准中作为宏而不是函数实现 这给我带来了一个问题 因为这样的方式assert从接受参数的角度来看 它似乎是一个函数 assert true 因此我尝试将其用作 std assert tru
  • Array.fill 和 for 循环创建数组有什么区别[重复]

    这个问题在这里已经有答案了 我正在使用 React js 创建一个地下城爬行游戏 并使用 Array fill 0 初始化棋盘 但是当我在二维数组中设置一个元素时 它将整个数组 列 设置为 player 而不是单一元素 我还有一个creat
  • VFP OleDb 的 Sql 参数化语法错误

    我正在尝试为 DBF 文件创建 SQL 参数化更新命令 Visual Fox Pro 我不知道为什么 但我在 DbCommand ExecuteNonQuery 上有一个 语法错误 异常错误消息是 语法错误 我没有任何额外的信息 strin
  • 如何查找 .NET 命名空间的程序集名称,例如 Microsoft.WindowsAzure.ServiceRuntime

    我有一个一般性问题和具体示例 根据 Stack Overflow 上有关命名空间程序集的所有类似问题 这应该很容易 最常见的答案是在问题中找到的我如何知道导入特定 NET 命名空间时要包含哪些引用 所有 MSDN 文档页面都提到命名空间和程
  • 循环,每次迭代仅在 jQuery 延迟之后发生,何时/然后可能没有递归?

    我想在循环中调用 jQuery 延迟函数 但每次迭代都应该等待上一个迭代使用延迟函数完成when function num of iterations var arr for var i 1 i lt num of iterations i
  • 将鼠标悬停在文本上时显示工具提示

    我想创建扩展 当我将鼠标悬停在文本上时 该扩展允许显示自定义消息 例如 test text 应该给出工具提示 OK 而不是当前的 ITrackin 我试着跟随https learn microsoft com en us visualstu
  • Visual Studio - 不同的断点集

    在 Visual Studio 2015 及更高版本 中 是否可以拥有多组断点 我有几个场景 我需要调试 但对于每个场景 我希望有不同的断点集 手动启用 禁用它们非常耗时 您可以从断点窗口导出和导入断点 然后根据需要导入它们 或者 如果您不
  • 有没有办法通过 .onLongPressGesture 将第三个切换选项添加到开/关状态?

    我已经设置了一个切换开关 如下图所示 可以打开 关闭图像或通过 失败 我正在尝试使用长按手势向图像添加第三种状态 这会将图像变成带有斜杠图标的灰色 我已经在文本元素中实现了这一点 因为 at is 没有 bool 条件 但经过多次搜索后无法
  • 创建未刷新的文件输出缓冲区

    我正在尝试解决在 Linux 上运行的几个不同语言的程序中未刷新的文件 I O 缓冲区出现的问题 刷新缓冲区的解决方案很简单 但是未刷新缓冲区的问题是随机发生的 我对如何创建 重现 和诊断这种情况感兴趣 而不是寻求可能导致这种情况的帮助 这
  • 使用 Selenium WebDriver 进行 PrimeFaces 文件上传测试

    我已经成功测试了 fileUploadSimplehttp www primefaces org showcase ui fileUploadSimple jsf使用 webElement sendKeys 方法 它不适用于自动上传 有没有
  • 使用 JSON 对象作为负载向 REST API 发出 POST 请求

    我正在尝试使用具有 JSON 负载的 POST 请求从 REST API 获取 JSON 响应 应在发送前转换为 URL 编码文本 我已经按照一些教程来实现该过程 但收到状态代码 400 的错误 我可能没有对给定的 JSON 字符串进行编码
  • 如何在 iframe 上设置“X-Frame-Options”?

    如果我创建一个iframe像这样 var dialog div align center div dialog 如何使用 JavaScript 修复以下错误 拒绝展示 https www google com ua gws rd ssl 在
  • 执行 chrome.extension.getBackgroundPage() 时抛出错误

    我正在开发我的第一个扩展 并尝试创建一个简单的扩展来在页面上注入可拖动的 div 这很好用 但我想保留 div 在后台页面上的位置 我也在尝试本地存储 但想了解为什么这不起作用 我不需要按钮 因此没有创建 popup html 文件 我相信
  • XPath 查找节点是否存在

    使用 XPath 查询如何查找节点 标签 是否存在 例如 如果我需要确保网站页面具有正确的基本结构 例如 html body and html head title
  • 联邦学习训练期间模型性能没有提高

    我已关注这个emnist教程创建图像分类实验 7 个类别 目的是使用 TFF 框架在 3 个数据孤岛上训练分类器 在训练开始之前 我使用以下命令将模型转换为 tf keras 模型tff learning assign weights to