shaoheshaohe 发表于 2019-9-16 19:41:13

LSTM中tf.nn.dynamic_rnn处理过程详解

对于tf.nn.dynamic_rnn处理过程的代码如下,但是每一步缺少细致的解释,本博客旨在帮助小伙伴们详细了解每一的步骤以及为什么要这样做。

lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits)
lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75)
value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)
lstmUnits为神经元的个数,前两行代码比较好理解,第三行代码生成的value和_令我百思不得其解。接着又出现另外几行代码更让我云里雾里。

weight = tf.Variable(tf.truncated_normal())
bias = tf.Variable(tf.constant(0.1, shape=))
value = tf.transpose(value, )
#取最终的结果值
last = tf.gather(value, int(value.get_shape()) - 1)
prediction = (tf.matmul(last, weight) + bias)
看到这里不禁会发问,为什么要对value进行value = tf.transpose(value, )这部分操作,然后last = tf.gather(value, int(value.get_shape()) - 1)这一步又有什么作用?带着这些疑问,我通过不停地百度,参考https://blog.csdn.net/qq_35203425/article/details/79572514这篇文章终于得出解答。



首先tf.nn.dynamic_rnn的输出包括outputs和states两部分。在唐宇迪例子中value相当于outputs,我们需要找outputs的最后一个step的输出。对value进行value = tf.transpose(value, )操作后得到的shape为.而后last = tf.gather(value, int(value.get_shape()) - 1),其中value.get_shape()) - 1找到value经过transpose后的最后一个分片,last = tf.gather(value, int(value.get_shape()) - 1)表示最后一个,也就是lstm最后的输出,这时候weight = tf.Variable(tf.truncated_normal())的shape为,last的shape为,两者相乘的维度为,再与偏置向量相加即可得到。真的输出应该是states.h。

states是由(c,h)组成的tuple,大小均为。所以如果想用dynamic_rnn得到输出后,只需要最后一次的状态输出,直接调用states.h即可,也可以按照上述进行操作

页: [1]
查看完整版本: LSTM中tf.nn.dynamic_rnn处理过程详解