如何在张量流中使用带有估计器的衰减学习率?

2024-05-26

我正在尝试将 LinearClassifier 与具有衰减学习率的 GradientDescentOptimizer 一起使用。

My code:

def main():
# load data
    features = np.load('data/feature_data.npz')
    tx = features['arr_0']
    y = features['arr_1']

## Prepare logistic regression
    n_point, n_feat = tx.shape

# Input functions
    def get_input_fn_from_numpy(tx, y, num_epochs=None, shuffle=True):
    # Preprocess data
        return tf.estimator.inputs.numpy_input_fn(
        x={"x":tx},
        y=y,
        num_epochs=num_epochs,
        shuffle=shuffle,
        batch_size=128
        )

    cols_label = "x"
    feature_cols = [tf.contrib.layers.real_valued_column(cols_label)]

    my_input_fn_train = get_input_fn_from_numpy(tx, y)

    model_dir = 'data/tmp/' + datetime.datetime.now().strftime("%m-%d_%H:%M:%S")
    global_step = tf.Variable(0, trainable=False)
    learning_rate=tf.train.exponential_decay(0.001*np.ones((20,1), dtype=np.float32), global_step, 10000, 0.95, staircase=False)
    regressor = tf.contrib.learn.LinearClassifier(feature_columns=feature_cols,
                                              model_dir=model_dir,
                                                  optimizer=tf.train.GradientDescentOptimizer(learning_rate=learning_rate))

    regressor.fit(input_fn=get_input_fn_from_numpy(tx_train, y_train), steps=100000)
    results = regressor.evaluate(input_fn=my_input_fn_test)

我收到错误:

  File "training.py", line 71, in <module>
main()
  File "training.py", line 63, in main
regressor.fit(input_fn=get_input_fn_from_numpy(tx_train, y_train), steps=100000)
  File "/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py", line 296, in new_func
return func(*args, **kwargs)
  File "/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 458, in fit
loss = self._train_model(input_fn=input_fn, hooks=hooks)
  File "/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 958, in _train_model
model_fn_ops = self._get_train_ops(features, labels)
 File "/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1165, in _get_train_ops
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.TRAIN)
  File "/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1136, in _call_model_fn
model_fn_results = self._model_fn(features, labels, **kwargs)
  File "/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/linear.py", line 186, in _linear_model_fn
logits=logits)
  File "/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/head.py", line 854, in create_model_fn_ops
enable_centered_bias=self._enable_centered_bias)
  File "/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/head.py", line 649, in _create_model_fn_ops
batch_size, loss_fn, weight_tensor)
  File "/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/head.py", line 1911, in _train_op
train_op = train_op_fn(loss)
  File "/lib/python3.6/site-packages/tensorflow/contrib/learn/python/learn/estimators/linear.py", line 179, in _train_op_fn
zip(grads, my_vars), global_step=global_step))
  File "/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 456, in apply_gradients
update_ops.append(processor.update_op(self, grad))
  File "/lib/python3.6/site-packages/tensorflow/python/training/optimizer.py", line 97, in update_op
return optimizer._apply_dense(g, self._v)  # pylint: disable=protected-access
  File "/lib/python3.6/site-packages/tensorflow/python/training/gradient_descent.py", line 50, in _apply_dense
use_locking=self._use_locking).op
  File "/lib/python3.6/site-packages/tensorflow/python/training/gen_training_ops.py", line 370, in apply_gradient_descent
name=name)
  File "/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py", line 330, in apply_op
g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
  File "/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4262, in _get_graph_from_inputs
_assert_same_graph(original_graph_element, graph_element)
  File "/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 4201, in _assert_same_graph
"%s must be from the same graph as %s." % (item, original_item))
ValueError: Tensor("ExponentialDecay:0", shape=(20, 1), dtype=float32) must be from the same graph as Tensor("linear/x/weight/part_0:0", shape=(20, 1), dtype=float32_ref).

我使用的是张量流1.3。 如果我用一个常数(比如 0.01)替换学习率,它就会起作用。我过去曾将衰减学习率与最小化操作结合使用,但试图在 LinearClassifier 中使用它。 我发现有些东西似乎不一致,因为我没有将全局步骤链接到拟合步骤,但我想知道这是如何工作的。我想我可以按照建议使用占位符here https://stackoverflow.com/questions/33919948/how-to-set-adaptive-learning-rate-for-gradientdescentoptimizer但我不明白为什么如果不需要的话我应该自己编写更新规则。

关于如何解决这个问题有什么建议吗?


您是否尝试过获得global_step通过致电tf.train.get_global_step()?这应该返回global_step由你使用LinearClassifier model.

代替

global_step = tf.Variable(0, trainable=False)

use

global_step = tf.train.get_global_step()

这对我有用,用我自己的Estimator类,我使用的地方tf.train.MomentumOptimizer以尽量减少tf.nn.sparse_softmax_cross_entropy_with_logits.

本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)

如何在张量流中使用带有估计器的衰减学习率? 的相关文章

随机推荐

  • 从 PYCHARM 运行时使 PYTEST 更安静

    更新 下面显示的消息不受 pytest 各种 q 安静选项控制 它们来自 TeamCity 插件 请参阅下面我的回答 原文 我已经阅读了用于沉默 pytest 的现有堆栈溢出答案 但没有人告诉我如何沉默我收到的大量冗余 测试通过 消息 我有
  • 如何在 Mongoose 中定义排序函数

    我正在开发一个小型 NodeJS Web 应用程序 使用 Mongoose 访问我的 MongoDB 数据库 我的收藏的简化架构如下 var MySchema mongoose Schema content type String loca
  • ASP.Net 将 401 错误代码转换为 302 错误代码

    我有一个自定义处理程序 在某些情况下 我想向用户代理表明他们未获得授权 Http 错误代码 401 if IsAuthorized context context Response StatusCode 401 context Respon
  • 使用表达式树构造 LINQ GroupBy 查询

    我已经在这个问题上坚持了一个星期了 但没有找到解决方案 我有一个像下面这样的 POCO public class Journal public int Id get set public string AuthorName get set
  • PHP/PDO 和 SQL Server 连接以及 i18n 问题

    在我们的网络应用程序中 我们使用 PHP5 2 6 PDO 连接到 SQL Server 2005 数据库并存储俄语文本 数据库排序规则是Cyrillic General CI AS 表排序规则是Cyrillic General CI AS
  • VB6 ActiveX exe - 正确的注册顺序是什么?

    我最近更新了一个 Visual Basic 6 应用程序 它是一个 ActiveX exe 在 Windows XP 上运行 我有几个此应用程序的测试人员 他们已收到 exe 的副本并正在尝试运行它 但是 他们收到一条错误消息 Unexpe
  • 从 Redux 状态删除一个项目

    我想知道如果可能的话你是否能帮我解决这个问题 我正在尝试从 Redux 状态中删除一个项目 我已经传入了用户点击的项目的IDaction data进入减速机 我想知道如何匹配action data使用 Redux 状态中的 ID 之一 然后
  • 从 UIScrollView 中删除所有子视图?

    如何从 UIScrollview 中删除所有子视图 Let scrollView是一个实例UIScrollView 在 Objective C 中 这非常简单 只需致电makeObjectsPerformSelector 像这样 Objec
  • SQL Server 2005 中的计数(*) 与计数(Id)

    我使用 SQLCOUNT函数获取表中的总数或行数 以下两种说法有什么区别吗 SELECT COUNT FROM Table and SELECT COUNT TableId FROM Table 另外 在性能和执行时间方面有什么区别吗 Th
  • 设置 MySQL 触发器

    我听说过有关触发器的事情 我有几个问题 什么是触发器 我该如何设置它们 除了典型的 SQL 内容之外 是否还应该采取任何预防措施 触发器允许您在发生某些事件 例如 插入表 时在数据库中执行某个功能 我无法具体评论mysql 注意事项 触发器
  • 在字符串数组中查找下一个可用日期

    我一直在尝试找出如何根据当前日期获取下一个可用日期 即 如果今天是星期五 则在数组中搜索下一个最近的日期 例如数组值为 1 星期一 2 星期二 4 星期四 6 星期六 那么我的第二天应该是星期六 这是我尝试过的 Here i ll get
  • RecyclerView 上的删除按钮删除了错误的项目

    我正在使用 Firestore 适配器RecyclerView我在使用 删除 按钮时遇到问题 当我按下它时 它会删除错误的项目 而不是我想要的项目 这是我的按钮内部的代码onBindViewHolder protected void onB
  • 了解单目标迷宫的 A* 启发式

    我有一个像下面这样的迷宫 P
  • 传说在北卡罗来纳州地理地图上消失?

    我正在使用 R 编程语言 使用北卡罗来纳州的内置地图 我生成了 3 个随机变量 收入 孩子数量 体重 然后为此数据创建了地图 使用 传单 库 通过循环 library sf library mapview library leaflet l
  • jQuery Mobile 1.4.0:动态更改页面的标题和标题

    动态更改 jQuery Mobile 1 4 0 页面的标题 data role header 和 title 的正确方法是什么 添加方法有很多种toolbars 页眉 页脚 动态 此外 jQuery Mobile 1 4 提供intern
  • 检测堆栈已满

    在编写 C 代码时 我了解到使用堆栈来存储内存是一个好主意 但最近我遇到了一个问题 我有一个实验 其代码如下所示 void fun const unsigned int N float data 1 N N float data 2 N N
  • python:将base64编码的png图像转换为jpg

    我想使用 python 将一些 base64 编码的 png 图像转换为 jpg 我知道如何从 Base64 解码回原始 import base64 pngraw base64 decodestring png b64text 但现在我怎样
  • 土耳其语字符显示不正确[重复]

    这个问题在这里已经有答案了 MySql 数据库使用 utf 8 编码 数据存储正确 我使用 set name utf8 查询来确保调用的数据是 utf 8 编码 只要标头字符集是 utf 8 数据库中的所有变量都可以正常工作 但静态html
  • 无需发送消息即可获取 GCM 规范注册 ID

    我在使用 GCM 的应用程序时遇到问题 情况如下 该应用程序已安装 应用程序调用 GCM 注册方法获取注册 ID RID 1 该应用程序已卸载 再次安装该应用程序 应用程序再次调用 GCM 注册方法 获取注册 ID RID 2 在第 5 步
  • 如何在张量流中使用带有估计器的衰减学习率?

    我正在尝试将 LinearClassifier 与具有衰减学习率的 GradientDescentOptimizer 一起使用 My code def main load data features np load data feature