I have n
(例如:n=3)范围和x
(例如:x=4)每个作用域中定义的变量数量。
范围是:
model/generator_0
model/generator_1
model/generator_2
计算损失后,我想根据运行时的标准从其中一个范围中提取并提供所有变量。因此范围的索引idx
我选择的是一个 argmin 张量转换为 int32
<tf.Tensor 'model/Cast:0' shape=() dtype=int32>
我已经尝试过:
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'model/generator_'+tf.cast(idx, tf.string))
这显然不起作用。
有什么办法可以得到所有的x
属于该特定范围的变量使用 idx 传递到优化器中。
提前致谢!
维涅什·斯里尼瓦桑
您可以在 TF 1.0 rc1 或更高版本中执行类似的操作:
v = tf.Variable(tf.ones(()))
loss = tf.identity(v)
with tf.variable_scope('adamoptim') as vs:
optim = tf.train.AdamOptimizer(learning_rate=0.1).minimize(loss)
optim_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope=vs.name)
print([v.name for v in optim_vars]) #=> prints lists of vars created
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系:hwhale#tublm.com(使用前将#替换为@)