我正在尝试找到在批次之间传递 LSTM 状态的最佳方法。我已经搜索了所有内容,但找不到当前实施的解决方案。想象一下我有类似的东西:
cells = [rnn.LSTMCell(size) for size in [256,256]
cells = rnn.MultiRNNCell(cells, state_is_tuple=True)
init_state = cells.zero_state(tf.shape(x_hot)[0], dtype=tf.float32)
net, new_state = tf.nn.dynamic_rnn(cells, x_hot, initial_state=init_state ,dtype=tf.float32)
现在我想通过new_state
在每个批次中有效,因此无需将其存储回内存,然后使用重新馈送到 tffeed_dict
。更准确地说,我找到的所有解决方案都使用sess.run
评估new_state
and feed-dict
将其传递到init_state
。有没有什么办法可以做到这一点而没有使用瓶颈feed-dict
?
我想我应该使用tf.assign
在某种程度上,但文档不完整,我找不到任何解决方法。
我要感谢所有提前询问的人。
Cheers,
弗朗西斯科·萨维里奥
我在堆栈溢出上找到的所有其他答案都适用于旧版本或使用“feed-dict”方法来传递新状态。例如:
1) TensorFlow:记住下一批的 LSTM 状态(有状态 LSTM) https://stackoverflow.com/questions/38241410/tensorflow-remember-lstm-state-for-next-batch-stateful-lstm这是通过使用“feed-dict”来提供状态占位符来实现的,我想避免这种情况
2) Tensorflow - 批处理内的 LSTM 状态重用 https://stackoverflow.com/questions/42133661/tensorflow-lstm-state-reuse-within-batch这不适用于状态 tuple
3) 在 Tensorflow 中的运行之间保存 LSTM RNN 状态 https://stackoverflow.com/questions/41915056/saving-lstm-rnn-state-between-runs-in-tensorflow同样在这里