查看: 1752|回复: 0

LSTM中tf.nn.dynamic_rnn处理过程详解

[复制链接]

665

主题

1234

帖子

6695

积分

xdtech

Rank: 5Rank: 5

积分
6695
发表于 2019-9-16 19:41:13 | 显示全部楼层 |阅读模式
对于tf.nn.dynamic_rnn处理过程的代码如下,但是每一步缺少细致的解释,本博客旨在帮助小伙伴们详细了解每一的步骤以及为什么要这样做。

lstmCell = tf.contrib.rnn.BasicLSTMCell(lstmUnits)
lstmCell = tf.contrib.rnn.DropoutWrapper(cell=lstmCell, output_keep_prob=0.75)
value, _ = tf.nn.dynamic_rnn(lstmCell, data, dtype=tf.float32)
lstmUnits为神经元的个数,前两行代码比较好理解,第三行代码生成的value和_令我百思不得其解。接着又出现另外几行代码更让我云里雾里。

weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))
bias = tf.Variable(tf.constant(0.1, shape=[numClasses]))
value = tf.transpose(value, [1, 0, 2])
#取最终的结果值
last = tf.gather(value, int(value.get_shape()[0]) - 1)
prediction = (tf.matmul(last, weight) + bias)
看到这里不禁会发问,为什么要对value进行value = tf.transpose(value, [1, 0, 2])这部分操作,然后last = tf.gather(value, int(value.get_shape()[0]) - 1)这一步又有什么作用?带着这些疑问,我通过不停地百度,参考https://blog.csdn.net/qq_35203425/article/details/79572514这篇文章终于得出解答。



首先tf.nn.dynamic_rnn的输出包括outputs和states两部分。在唐宇迪例子中value相当于outputs,我们需要找outputs的最后一个step的输出。对value进行value = tf.transpose(value, [1, 0, 2])操作后得到的shape为[step,batch_size,lstmUnits].而后last = tf.gather(value, int(value.get_shape()[0]) - 1),其中value.get_shape()[0]) - 1找到value经过transpose后的最后一个分片,last = tf.gather(value, int(value.get_shape()[0]) - 1)表示最后一个[batch_size,lstmUnits],也就是lstm最后的输出,这时候weight = tf.Variable(tf.truncated_normal([lstmUnits, numClasses]))的shape为[lstmUnits,numClasses],last的shape为[batch_size,lstmUnits],两者相乘的维度为[batch_size,numClasses],再与偏置向量相加即可得到。真的输出应该是states.h。

states是由(c,h)组成的tuple,大小均为[batch,lstmUnits]。所以如果想用dynamic_rnn得到输出后,只需要最后一次的状态输出,直接调用states.h即可,也可以按照上述进行操作

回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表