阅读(3393) (10)

TensorFlow函数教程:tf.nn.static_rnn

2019-02-13 10:41:10 更新

tf.nn.static_rnn函数

别名:

  • tf.contrib.rnn.static_rnn
  • tf.nn.static_rnn
tf.nn.static_rnn(
    cell,
    inputs,
    initial_state=None,
    dtype=None,
    sequence_length=None,
    scope=None
)

定义在:tensorflow/python/ops/rnn.py。

创建由RNNCell cell指定的循环神经网络。

生成的最简单的RNN网络形式是:

  state = cell.zero_state(...)
  outputs = []
  for input_ in inputs:
    output, state = cell(input_, state)
    outputs.append(output)
  return (outputs, state)

但是,还有一些其他选项:

可以提供初始状态。如果提供sequence_length向量,则执行动态计算。这种计算方法不计算超过最小批处理的最大序列长度的RNN步骤(从而节省计算时间),并且将示例的序列长度的状态适当地传播到最终状态输出。

在批处理行b的时间t上执行的动态计算:

  (output, state)(b, t) =
    (t >= sequence_length(b))
      ? (zeros(cell.output_size), states(b, sequence_length(b) - 1))
      : cell(input(b, t), state(b, t - 1))

参数:

  • cell:RNNCell的一个实例。
  • inputs:输入的长度为T的列表,每个Tensor具有shape [batch_size, input_size];或这些元素的嵌套元组。
  • initial_state:(可选)RNN的初始状态。如果cell.state_size是整数,则必须是具有适当的类型和shape为[batch_size, cell.state_size]的Tensor。如果cell.state_size是一个元组,这应该是具有shape [batch_size, s]的张量元组,其中s位于cell.state_size。
  • dtype:(可选)初始状态和预期输出的数据类型。如果未提供initial_state或RNN状态具有异构类型,则为必需。
  • sequence_length:指定输入中每个序列的长度。int32或int64向量(张量),大小为[batch_size],值位于[0, T)。
  • scope:用于创建子图的VariableScope;默认为“rnn”。

返回:

(outputs, state)对,其中:

  • outputs的长度为T的列表(每个输入一个),或这些元素的嵌套元组。
  • state是最终状态

可能引发的异常:

  • TypeError:如果cell不是RNNCell的实例。
  • ValueError:如果inputs为None或是一个空列表,或者无法通过形状推断从输入推断输入深度(列大小)。

实例:

import tensorflow as tf

x=tf.Variable(tf.random_normal([2,4,3])) #[batch_size,timesteps,embedding_dim] 
x=tf.unstack(x,axis=1) #按时间步展开 
n_neurons = 5 #输出神经元数量
 
basic_cell = tf.contrib.rnn.BasicRNNCell(num_units=n_neurons)
output_seqs, states = tf.contrib.rnn.static_rnn(basic_cell,x, dtype=tf.float32)
 
print(len(output_seqs)) #四个时间步 
print(output_seqs[0]) #每个时间步输出一个张量 
print(output_seqs[1]) #每个时间步输出一个张量
print(states) #隐藏状态

输出结果如下:

4
Tensor("rnn/basic_rnn_cell/Tanh:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_1:0", shape=(2, 5), dtype=float32)
Tensor("rnn/basic_rnn_cell/Tanh_3:0", shape=(2, 5), dtype=float32)