tf.data.Dataset.window
返回一个新的数据集,其元素是数据集,这些嵌套数据集的元素是所需大小的窗口。如果您有一个数据集(例如,Dataset.range(10)
并想要一个像这样的窗口数据集[0 1 2] [1 2 3] ... [7 8 9]
),有一个技巧可以做到这一点window
plus flat_map
:
>>> d = tf.data.Dataset.range(10).window(3, shift=1, drop_remainder=True).flat_map(lambda x: x.batch(3))
>>> print(list(d))
[<tf.Tensor: shape=(3,), dtype=int64, numpy=array([0, 1, 2])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([1, 2, 3])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([2, 3, 4])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([3, 4, 5])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([4, 5, 6])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([5, 6, 7])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([6, 7, 8])>, <tf.Tensor: shape=(3,), dtype=int64, numpy=array([7, 8, 9])>]
但是,那flat_map
导致数据集丢失基数信息:
>>> d.cardinality.numpy()
<tf.Tensor: shape=(), dtype=int64, numpy=-2>
(-2 is UNKNOWN_CARDINALITY https://www.tensorflow.org/api_docs/python/tf/data#UNKNOWN_CARDINALITY; see Tensorflow 2.0:flat_map() 压平数据集的数据集返回基数 -2 https://stackoverflow.com/questions/66287320/tensorflow-2-0-flat-map-to-flatten-dataset-of-dataset-returns-cardinality-2/68281606#68281606)
我想创建此类窗口的数据集,同时保留基数信息。使用未知基数的数据集的一个小烦恼是 Keras 训练进度条需要先运行一个 epoch,然后才能生成 ETA。我试过.take(n_windows)
我在哪里计算n_windows
我自己,但仍然返回了一个数据集UNKNOWN_CARDINALITY
.
有没有某种方法可以在不丢失基数信息的情况下对数据集进行窗口化?