我已关注这个emnist教程创建图像分类实验(7 个类别),目的是使用 TFF 框架在 3 个数据孤岛上训练分类器。
在训练开始之前,我使用以下命令将模型转换为 tf keras 模型tff.learning.assign_weights_to_keras_model(model,state.model)
评估我的验证集。无论标签如何,模型仅预测一类。这是可以预料的,因为尚未对模型进行训练。但是,我在每轮联合平均后重复此步骤,但问题仍然存在。所有验证图像都被预测为一类。我还在每轮之后保存 tf keras 模型权重,并对测试集进行预测 - 没有任何变化。
我为检查问题根源而采取的一些步骤:
- 检查每轮后转换 FL 模型时 tf keras 模型权重是否正在更新 - 它们正在更新。
- 确保缓冲区大小大于每个客户端的训练数据集大小。
- 将预测与训练数据集中的类别分布进行比较。存在类别不平衡,但模型预测的一个类别不一定是多数类别。而且,它并不总是同一类。大多数情况下,它仅预测 0 类。
- 将轮数增加到 5,每轮 epoch 增加到 10。这在计算上非常密集,因为它是一个相当大的模型,每个客户端需要大约 1500 个图像进行训练。
- 研究每次训练尝试的 TensorBoard 日志。随着回合的进行,训练损失正在减少。
- 尝试了一个更简单的模型 - 具有 2 个转换层的基本 CNN。这使我能够大大增加 epoch 和 rounds 的数量。在测试集上评估该模型时,它预测了 4 个不同的类别,但性能仍然很差。这表明我只需要增加原始模型的轮数和纪元数即可增加预测的变化。这是很困难的,因为这会导致大量的训练时间。
型号详情:
该模型使用 XceptionNet 作为基础模型,权重未冻结。当所有训练图像都汇集到全局数据集中时,这在分类任务中表现良好。我们的目标是希望达到与 FL 相当的性能。
base_model = Xception(include_top=False,
weights=weights,
pooling='max',
input_shape=input_shape)
x = GlobalAveragePooling2D()( x )
predictions = Dense( num_classes, activation='softmax' )( x )
model = Model( base_model.input, outputs=predictions )
这是我的训练代码:
def fit(self):
"""Train FL model"""
# self.load_data()
summary_writer = tf.summary.create_file_writer(
self.logs_dir
)
federated_averaging = self._construct_iterative_process()
state = federated_averaging.initialize()
tfkeras_model = self._convert_to_tfkeras_model( state )
print( np.argmax( tfkeras_model.predict( self.val_data ), axis=-1 ) )
val_loss, val_acc = tfkeras_model.evaluate( self.val_data, steps=100 )
with summary_writer.as_default():
for round_num in tqdm( range( 1, self.num_rounds ), ascii=True, desc="FedAvg Rounds" ):
print( "Beginning fed avg round..." )
# Round of federated averaging
state, metrics = federated_averaging.next(
state,
self.training_data
)
print( "Fed avg round complete" )
# Saving logs
for name, value in metrics._asdict().items():
tf.summary.scalar(
name,
value,
step=round_num
)
print( "round {:2d}, metrics={}".format( round_num, metrics ) )
tff.learning.assign_weights_to_keras_model(
tfkeras_model,
state.model
)
# tfkeras_model = self._convert_to_tfkeras_model(
# state
# )
val_metrics = {}
val_metrics["val_loss"], val_metrics["val_acc"] = tfkeras_model.evaluate(
self.val_data,
steps=100
)
for name, metric in val_metrics.items():
tf.summary.scalar(
name=name,
data=metric,
step=round_num
)
self._checkpoint_tfkeras_model(
tfkeras_model,
round_num,
self.checkpoint_dir
)
def _checkpoint_tfkeras_model(self,
model,
round_number,
checkpoint_dir):
# Obtaining model dir path
model_dir = os.path.join(
checkpoint_dir,
f'round_{round_number}',
)
# Creating directory
pathlib.Path(
model_dir
).mkdir(
parents=True
)
model_path = os.path.join(
model_dir,
f'model_file_round{round_number}.h5'
)
# Saving model
model.save(
model_path
)
def _convert_to_tfkeras_model(self, state):
"""Converts global TFF modle of TF keras model
Takes the weights of the global model
and pushes them back into a standard
Keras model
Args:
state: The state of the FL server
containing the model and
optimization state
Returns:
(model); TF Keras model
"""
model = self._load_tf_keras_model()
model.compile(
loss=self.loss,
metrics=self.metrics
)
tff.learning.assign_weights_to_keras_model(
model,
state.model
)
return model
def _load_tf_keras_model(self):
"""Loads tf keras models
Raises:
KeyError: A model name was not defined
correctly
Returns:
(model): TF keras model object
"""
model = create_models(
model_type=self.model_type,
input_shape=[self.img_h, self.img_w, 3],
freeze_base_weights=self.freeze_weights,
num_classes=self.num_classes,
compile_model=False
)
return model
def _define_model(self):
"""Model creation function"""
model = self._load_tf_keras_model()
tff_model = tff.learning.from_keras_model(
model,
dummy_batch=self.sample_batch,
loss=self.loss,
# Using self.metrics throws an error
metrics=[tf.keras.metrics.CategoricalAccuracy()] )
return tff_model
def _construct_iterative_process(self):
"""Constructing federated averaging process"""
iterative_process = tff.learning.build_federated_averaging_process(
self._define_model,
client_optimizer_fn=lambda: tf.keras.optimizers.SGD( learning_rate=0.02 ),
server_optimizer_fn=lambda: tf.keras.optimizers.SGD( learning_rate=1.0 ) )
return iterative_process