我最近开始使用 Tensorflow,并尝试习惯 tf.estimator.Estimator 对象。我想做一些非常自然的先验事情:在训练了我的分类器之后,即 tf.estimator.Estimator 的实例(带有train
方法),我想将其保存在文件中(无论扩展名如何),然后稍后重新加载以预测一些新数据的标签。由于官方文档建议使用 Estimator API,我想应该实现和记录同样重要的事情。
我在其他页面上看到这样做的方法是export_savedmodel
(see 官方文档)但我根本不理解文档。没有说明如何使用此方法。论据是什么serving_input_fn
?我从来没有遇到过它创建自定义估算器教程或我读过的任何教程。通过进行一些谷歌搜索,我发现大约一年前,估计器是使用其他类定义的(tf.contrib.learn.Estimator
)并且看起来 tf.estimator.Estimator 正在重用以前的一些 API。但我在文档中没有找到关于它的明确解释。
有人可以给我一个玩具示例吗?或者解释一下如何定义/找到这个serving_input_fn
?
那么如何再次加载训练好的分类器呢?
感谢您的帮助!
Edit:我发现不一定需要使用export_savemodel来保存模型。它实际上是自动完成的。然后,如果我们稍后定义一个具有相同 model_dir 参数的新估计器,它也会自动恢复以前的估计器,如下所示here.
正如您所了解的,估计器会在训练期间自动为您保存并恢复模型。如果您想将模型部署到现场(例如为 Tensorflow Serving 提供最佳模型),export_savemodel 可能会很有用。
这是一个简单的例子:
est.export_savedmodel(export_dir_base=FLAGS.export_dir, serving_input_receiver_fn=serving_input_fn)
def serving_input_fn():
inputs = {'features': tf.placeholder(tf.float32, [None, 128, 128, 3])}
return tf.estimator.export.ServingInputReceiver(inputs, inputs)
基本上serving_input_fn 负责用占位符替换数据集管道。在部署中,您可以将数据提供给此占位符,作为模型的输入以进行推理或预测。
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)