您好, 欢迎来到 !    登录 | 注册 | | 设为首页 | 收藏本站

当state_is_tuple = True时如何设置TensorFlow RNN状态?

当state_is_tuple = True时如何设置TensorFlow RNN状态?

Tensorflow占位符的一个问题是,您只能使用python列表或Numpy数组来提供它(我认为)。因此,您无法在LSTMStateTuple的元组中保存两次运行之间的状态。

我通过将状态保存在这样的张量中解决了这个问题

initial_state = np.zeros((num_layers, 2, batch_size, state_size))

在LSTM层中有两个组件, 和 ,这就是“ 2”的含义。(这篇文章很棒:https ://arxiv.org/pdf/1506.00019.pdf)

构建图时,您将解压缩并创建元组状态,如下所示:

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])
l = tf.unpack(state_placeholder, axis=0)
rnn_tuple_state = tuple(
         [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0],l[idx][1])
          for idx in range(num_layers)]
)

然后您以通常的方式获得新状态

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)

outputs, state = tf.nn.dynamic_rnn(cell, series_batch_input, initial_state=rnn_tuple_state)

可能不应该这样……也许他们正在研究解决方案。

其他 2022/1/1 18:38:49 有347人围观

撰写回答


你尚未登录,登录后可以

和开发者交流问题的细节

关注并接收问题和回答的更新提醒

参与内容的编辑和改进,让解决方法与时俱进

请先登录

推荐问题


联系我
置顶