TensorFlow Custom Estimator 预测投掷值误差

2023-12-03

注意:这个问题有一个附带的、记录的Colab笔记本。

有时,TensorFlow 的文档还有很多不足之处。一些针对较低级别 api 的旧文档似乎已被删除,而大多数较新的文档都指向使用较高级别的 api,例如 TensorFlow 的子集keras or estimators。如果较高级别的 api 不经常紧密依赖于较低级别的 API,那么这不会有太大问题。举个例子,estimators(特别是input_fn使用 TensorFlow Records 时)。

通过以下 Stack Overflow 帖子:

  • Tensorflow v1.10:将图像存储为字节字符串或每个通道?
  • Tensorflow 1.10 TFRecordDataset - 恢复 TFRecord
  • Tensorflow v1.10+ 为什么在没有输入服务接收器功能的情况下创建检查点时需要输入服务接收器功能?
  • TensorFlow 1.10+ 自定义估算器通过 train_and_evaluate 提前停止
  • 训练后调用评估时 TensorFlow 自定义估计器卡住

在 TensorFlow / StackOverflow 社区的慷慨帮助下,我们离 TensorFlow 所做的事情又近了一步“创建自定义估算器”指南还没有,演示如何制作一个可以在实践中实际使用的估计器(而不是玩具示例),例如其中:

  • 有一个验证集,可以在性能恶化时提前停止,
  • 从 TF Records 中读取,因为许多数据集都大于 TensorFlow 建议的内存 1Gb,并且
  • 在训练时保存其最佳版本

虽然我对此仍然有很多疑问(从将数据编码到 TF 记录中的最佳方式,到到底是什么)serving_input_fn期望),有一个问题比其他问题更突出:

如何用我们刚刚制作的自定义估计器进行预测?

根据文档predict, 它指出:

input_fn:构造特征的函数。预测持续到input_fn引发输入结束异常(tf.errors.OutOfRangeError or StopIteration)。有关详细信息,请参阅预制估算​​器。该函数应构造并返回以下内容之一:

  • tf.data.Dataset 对象:Dataset 对象的输出必须具有与以下相同的约束。
  • features:tf.Tensor 或 Tensor 的字符串特征名称字典。特征由 model_fn 使用。它们应该满足输入中 model_fn 的期望。
  • 一个元组,在这种情况下,第一项将被提取为特征。

(也许)最有可能的是,如果一个人正在使用estimator.predict,他们正在使用内存中的数据,例如密集张量(因为保留的测试集可能会通过evaluate).

所以我在随附的Colab,创建一个密集示例,将其包装在tf.data.Dataset,并致电predict得到一个ValueError.

如果有人能向我解释我该如何做,我将不胜感激:

  1. 加载我保存的估算器
  2. 给定一个密集的内存示例,使用估计器预测输出

to_predict = random_onehot((1, SEQUENCE_LENGTH, SEQUENCE_CHANNELS))\
        .astype(tf_type_string(I_DTYPE))
pred_features = {'input_tensors': to_predict}

pred_ds = tf.data.Dataset.from_tensor_slices(pred_features)
predicted = est.predict(lambda: pred_ds, yield_single_examples=True)

next(predicted)

ValueError:Tensor("IteratorV2:0", shape=(), dtype=resource) 必须来自与 Tensor("TensorSliceDataset:0", shape=(), dtype=variant) 相同的图。

当您使用tf.data.Dataset模块,它实际上定义了一个独立于模型图的输入图。这里发生的是,您首先通过调用创建了一个小图tf.data.Dataset.from_tensor_slices(),然后估算器 API 通过调用创建了第二个图dataset.make_one_shot_iterator()自动地。这两个图无法通信,因此会引发错误。

为了避免这种情况,您永远不应该在 estimator.train/evaluate/predict 之外创建数据集。这就是为什么所有相​​关数据都包含在输入函数中的原因。

def predict_input_fn(data, batch_size=1):
  dataset = tf.data.Dataset.from_tensor_slices(data)
  return dataset.batch(batch_size).prefetch(None)

predicted = est.predict(lambda: predict_input_fn(pred_features), yield_single_examples=True)
next(predicted)

现在,图表不是在预测调用之外创建的。

我还添加了dataset.batch()因为代码的其余部分需要批处理数据,并且它抛出了形状错误。预取只是加快速度。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

TensorFlow Custom Estimator 预测投掷值误差 的相关文章

随机推荐