tf.nn.dynamic_rnn
接受一批(具有小批量的含义)不相关的序列。
-
cell
是您要使用的实际单元(LSTM、GRU,...)
-
inputs
形状为batch_size x max_time x input_size
其中 max_time 是最长序列中的步数(但所有序列可以具有相同的长度)
-
sequence_length
是一个大小向量batch_size
其中每个元素给出批次中每个序列的长度(如果所有序列的大小相同,则将其保留为默认值。该参数定义单元展开尺寸。
隐藏状态处理
处理隐藏状态的通常方法是在隐藏状态之前定义一个初始状态张量dynamic_rnn
,例如这样:
hidden_state_in = cell.zero_state(batch_size, tf.float32)
output, hidden_state_out = tf.nn.dynamic_rnn(cell,
inputs,
initial_state=hidden_state_in,
...)
在上面的代码片段中,两个hidden_state_in
and hidden_state_out
具有相同的形状[batch_size, ...]
(实际形状取决于您使用的单元格类型,但重要的是第一个维度是批量大小).
这边走,dynamic_rnn
每个序列都有一个初始隐藏状态。它将在每个序列的时间步长之间传递隐藏状态inputs
参数本身, and hidden_state_out
将包含批次中每个序列的最终输出状态。同一批次的序列之间不会传递任何隐藏状态,而只会在同一序列的时间步之间传递。
什么时候需要手动反馈隐藏状态?
通常,当您进行训练时,每个批次都是无关的,因此您不必在执行训练时反馈隐藏状态session.run(output)
.
但是,如果您正在测试,并且需要每个时间步骤的输出(即您必须执行session.run()
在每个时间步)您将需要使用如下所示的方法来评估并反馈输出隐藏状态:
output, hidden_state = sess.run([output, hidden_state_out],
feed_dict={hidden_state_in:hidden_state})
否则tensorflow将只使用默认值cell.zero_state(batch_size, tf.float32)
在每个时间步,这相当于在每个时间步重新初始化隐藏状态。