将多个 TensorFlow 数据集交错在一起

2023-11-21

当前的 TensorFlow 数据集交错功能基本上是一个交错平面地图,将单个数据集作为输入。考虑到当前的 API,将多个数据集交错在一起的最佳方法是什么?假设它们已经建成,并且我有一份清单。我想交替地从它们中生成元素,并且我想支持具有超过 2 个数据集的列表(即,堆叠的 zip 和交错会非常难看)。

谢谢! :)

@mrry 或许能帮上忙。


也可以看看:

  • tf.data.Dataset.choose_from_datasets,执行确定性数据集交错。

  • tf.data.Dataset.sample_from_datasets,即使执行随机采样,这也很有用。


尽管这并不“干净”,但这是我想出的唯一解决方法。

datasets = [tf.data.Dataset...]

def concat_datasets(datasets):
    ds0 = tf.data.Dataset.from_tensors(datasets[0])
    for ds1 in datasets[1:]:
        ds0 = ds0.concatenate(tf.data.Dataset.from_tensors(ds1))
    return ds0

ds = tf.data.Dataset.zip(tuple(datasets)).flat_map(
    lambda *args: concat_datasets(args)
)
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

将多个 TensorFlow 数据集交错在一起 的相关文章

随机推荐