如何从 TensorFlow 数据集中提取数据/标签

2023-12-21

有很多如何创建和使用 TensorFlow 数据集的示例,例如

dataset = tf.data.Dataset.from_tensor_slices((images, labels))

我的问题是如何以 numpy 形式从 TF 数据集中获取数据/标签?换句话说,want 是上面一行的反向操作,即我有一个 TF 数据集,并且想要从中获取图像和标签。


万一你的tf.data.Dataset https://www.tensorflow.org/api_docs/python/tf/data/Dataset是批处理的,以下代码将检索所有 y 标签:

y = np.concatenate([y for x, y in ds], axis=0)

快速解释: [y for x, y in ds]在Python中被称为“列表理解”。如果数据集是批处理的,则此表达式将循环遍历每个批次,并将每个批次 y(TF 1D 张量)放入列表中,然后返回它。然后,np.concatenate 将获取此一维张量列表(隐式转换为 numpy)并将其堆叠在 0 轴上以生成单个长向量。总而言之,它只是将一堆一维小向量转换为一个长向量。

Note:如果你的 y 更复杂,这个答案将需要一些小的修改。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何从 TensorFlow 数据集中提取数据/标签 的相关文章

随机推荐