如何使用export_savedmodel保存和恢复tf.estimator.Estimator模型?

2023-11-29

我最近开始使用 Tensorflow,并尝试习惯 tf.estimator.Estimator 对象。我想做一些非常自然的先验事情:在训练了我的分类器之后,即 tf.estimator.Estimator 的实例(带有train方法),我想将其保存在文件中(无论扩展名如何),然后稍后重新加载以预测一些新数据的标签。由于官方文档建议使用 Estimator API,我想应该实现和记录同样重要的事情。

我在其他页面上看到这样做的方法是export_savedmodel (see 官方文档)但我根本不理解文档。没有说明如何使用此方法。论据是什么serving_input_fn?我从来没有遇到过它创建自定义估算器教程或我读过的任何教程。通过进行一些谷歌搜索,我发现大约一年前,估计器是使用其他类定义的(tf.contrib.learn.Estimator)并且看起来 tf.estimator.Estimator 正在重用以前的一些 API。但我在文档中没有找到关于它的明确解释。

有人可以给我一个玩具示例吗?或者解释一下如何定义/找到这个serving_input_fn?

那么如何再次加载训练好的分类器呢?

感谢您的帮助!

Edit:我发现不一定需要使用export_savemodel来保存模型。它实际上是自动完成的。然后,如果我们稍后定义一个具有相同 model_dir 参数的新估计器,它也会自动恢复以前的估计器,如下所示here.


正如您所了解的,估计器会在训练期间自动为您保存并恢复模型。如果您想将模型部署到现场(例如为 Tensorflow Serving 提供最佳模型),export_savemodel 可能会很有用。

这是一个简单的例子:

est.export_savedmodel(export_dir_base=FLAGS.export_dir, serving_input_receiver_fn=serving_input_fn)

def serving_input_fn(): inputs = {'features': tf.placeholder(tf.float32, [None, 128, 128, 3])} return tf.estimator.export.ServingInputReceiver(inputs, inputs)

基本上serving_input_fn 负责用占位符替换数据集管道。在部署中,您可以将数据提供给此占位符,作为模型的输入以进行推理或预测。

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何使用export_savedmodel保存和恢复tf.estimator.Estimator模型? 的相关文章

随机推荐

  • 理解@interface声明中“(Private)”的这种用法

    我见过一些这样写的代码 interface AViewController Private 我想知道是否是这样 Private 提交到 App Store 时意味着什么 一般而言 这意味着什么 这是一个名为 私人 的类别 看看类别和扩展Ob
  • TimeSlider 插件和传单 - 标记未按顺序出现

    更新了一个JSFIDDLE 链接 我正在使用 LeafletJS 构建带有时间轴滑块的网络地图 我正在使用LeafletSlider插件显示一组基于名为的 GEOJSON 属性的标记DATE START 这是我的数据对象的示例 var ca
  • PDF 发送意图上的 Android SecurityException

    我在执行期间遇到以下异常ACTION SEND数据类型的意图application pdf java lang SecurityException Permission Denial starting Intent act android
  • ios - 在 WkWebView 中禁用 Youtube 自动播放

    我在用着WKWebView打开pages in Youtube 问题是 打开后他们开始播放视频并进入全屏 这是不想要的行为 视频未嵌入 它是带有描述 评论等的整个页面 有办法阻止他们玩吗 有用 请阅读评论 import UIKit impo
  • 您是否必须在 Redis 脚本中提前声明您的密钥?

    我的计划是将一些现有的 Redis 键存储在哈希中 稍后从 Redis Lua 脚本中获取并执行操作 我读到 最佳实践是在调用时提供脚本中使用的所有键EVAL 我的问题是 运行运行时没有提供任何密钥的脚本是否安全EVAL 但对从以下位置获取
  • `DS.attr()` 中的嵌套对象不受 `DS.rollbackAttributes()` 影响

    我有一个模型User如下 import DS from ember data const attr Model DS export default Model extend name attr string properties attr
  • templateUrl 更改时 AngularJS Modal 不显示

    到目前为止 我所拥有的是 Angular UI 示例 控制器 var ModalDemoCtrl function scope modal scope open function var modalInstance modal open t
  • 从托管代码中挂钩 LoadLibrary 调用

    我们希望挂钩对 LoadLibrary 的调用 以便下载未找到的程序集 我们有一个 ResolveAssembly 处理程序来处理托管程序集 但我们还需要处理非托管程序集 我们尝试通过 Microsoft Windows 的编程应用程序 中
  • 动态改变JTree中特定节点的图标

    我已经看过很多在树实例化期间更改节点图标的示例 但我想要一种稍后动态更改单个节点图标的方法 因此 在我的主代码中 我将自定义渲染器添加到我的树中 Icon I want to set nodes to later ImageIcon che
  • 在 Groovy 中执行 Unix cat 命令?

    Hallo 我想从 Groovy 程序执行类似 cat path to file1 path to file2 gt path to file3 的内容 我尝试了 cat path to file1 path to file2 gt pat
  • Android - 访问在线数据库SQlite

    我需要从 Android 应用程序打开 读取项目并将其插入到在线 SQLite 数据库中 我知道网址 用户名和密码 在 JavaSE 中我会执行以下操作 Class forName com mysql jdbc Driver Connect
  • 再次使用java进行字符串比较

    新手问题 但我有这个代码 import java util import java io public class Welcome1 main method begins execution of Java application publ
  • pthread - 暂停/取消暂停所有线程

    我正在尝试在我的应用程序中编写暂停 取消暂停所有线程 该线程由 SIGUSR1 暂停 和 SIGUSR2 取消暂停 激活 我想用pthread cond wait 在所有线程中 当收到信号时 使用pthread cond broadcast
  • 如何使用android创建dll

    我是 Android 应用程序开发新手 我想开发一个dll使用安卓 是否可以开发并集成到android应用程序 请告诉我解决方案 如果可以的话请将解决方案一一告诉我 至于我 我曾经为自己写过一篇关于 NDK 的笔记 这里是 Required
  • MySQL 中加密数据的搜索过滤器

    查询说明 假设我有一个数据库表 它以加密形式存储所有用户的数据 我有一个功能 管理员可以搜索用户数据 现在的问题是 管理员将在文本框中输入普通文本 我必须根据管理员的输入过滤用户列表 在每次文本更改时 因此 与此同时 我拥有大量加密形式的数
  • 如何让tableFooterView始终位于UITableView的底部

    我有一个UITableView具有可变数量的部分 每个部分都有不同数量的单元格 并且每个部分都有页眉和页脚 我的UITableView还有一个tableFooterView我想始终将其保留在屏幕底部 除非表格太大而无法在屏幕上显示 然后ta
  • iphone 粘性菜单 jquery onscroll ios 9

    在更新到之前 这段代码在我的 iPhone 上运行良好iOS 9 0 1 13A404 但现在相同的代码似乎只有在手指松开后才能工作 或者在 jQuery 之后onscroll结束 当我快速滑动以使页面滚动时 document on scr
  • odbc_prepare 给出致命错误:允许的内存大小已耗尽

    我有一个 Debian 服务器 64 位 我想通过 PHP 将其连接到 AS400 的数据库 我已经安装了 IBM i Access for Linux 和 unixodbc 我已经遵循了这个教程 https www albertopica
  • 如何在插入语句的目标数据库名称中使用变量?

    我想声明一个服务器名称并在插入语句中使用该名称 到目前为止我收到的只是一条错误消息 declare machine nvarchar 6 declare bar nvarchar 3 set machine Name00 set bar f
  • 如何使用export_savedmodel保存和恢复tf.estimator.Estimator模型?

    我最近开始使用 Tensorflow 并尝试习惯 tf estimator Estimator 对象 我想做一些非常自然的先验事情 在训练了我的分类器之后 即 tf estimator Estimator 的实例 带有train方法 我想将