当前的 TensorFlow 数据集交错功能基本上是一个交错平面地图,将单个数据集作为输入。考虑到当前的 API,将多个数据集交错在一起的最佳方法是什么?假设它们已经建成,并且我有一份清单。我想交替地从它们中生成元素,并且我想支持具有超过 2 个数据集的列表(即,堆叠的 zip 和交错会非常难看)。
谢谢! :)
@mrry 或许能帮上忙。
也可以看看:
尽管这并不“干净”,但这是我想出的唯一解决方法。
datasets = [tf.data.Dataset...]
def concat_datasets(datasets):
ds0 = tf.data.Dataset.from_tensors(datasets[0])
for ds1 in datasets[1:]:
ds0 = ds0.concatenate(tf.data.Dataset.from_tensors(ds1))
return ds0
ds = tf.data.Dataset.zip(tuple(datasets)).flat_map(
lambda *args: concat_datasets(args)
)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)