打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读
DL之LSTM:tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读
tf.contrib.rnn.BasicLSTMCell(rnn_unit)函数的解读
函数功能解读
"""Basic LSTM recurrent network cell.
The implementation is based on: http://arxiv.org/abs/1409.2329.
We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training.
It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline.  For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell}
that follows.
"""
def __init__(self,
num_units,
forget_bias=1.0,
state_is_tuple=True,
activation=None,
reuse=None,
name=None,
dtype=None):
"""Initialize the basic LSTM cell.
基本LSTM递归网络单元。
实现基于:http://arxiv.org/abs/1409.2329。
我们在遗忘门的偏见中加入了遗忘偏见(默认值:1),以减少训练开始时的遗忘程度。
它不允许细胞剪切(一个投影层),也不使用窥孔连接:它是基本的基线。对于高级模型,请使用完整的@{tf.n .rnn_cell. lstmcell}遵循。
Args:
num_units: int, The number of units in the LSTM cell.
forget_bias: float, The bias added to forget gates (see above).
Must set to `0.0` manually when restoring from CudnnLSTM-trained checkpoints.
state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`.  If False, they are concatenated along the column axis.  The latter behavior will soon be deprecated.
activation: Activation function of the inner states.  Default: `tanh`.
reuse: (optional) Python boolean describing whether to reuse variables in an existing scope.  If not `True`, and the existing scope already has the given variables, an error is raised.
name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases.
dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`.
When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead.
"""
参数:
num_units: int类型, LSTM单元中的单元数。
forget_bias: float类型,偏见添加到忘记门(见上面)。
从cudnnlstm训练的检查点恢复时,必须手动设置为“0.0”。
state_is_tuple: 如果为真,则接受状态和返回状态是' c_state '和' m_state '的二元组。如果为假,则沿着列轴连接它们。后一种行为很快就会被摒弃。
activation: 内部状态的激活功能。默认值tanh激活函数。
reuse: (可选)Python布尔值,描述是否在现有范围内重用变量。如果不是“True”,并且现有范围已经有给定的变量,则会引发错误。
name:字符串,层的名称。具有相同名称的层将共享权重,但是为了避免错误,我们需要在这种情况下重用=True。
dtype:该层的默认dtype(默认为'None’意味着使用第一个输入的类型)。当' build '在' call '之前被调用时是必需的。
从经过cudnnlstm训练的检查点恢复时,必须使用“CudnnCompatibleLSTMCell”。
”“”
函数代码实现
@tf_export("nn.rnn_cell.BasicLSTMCell")class BasicLSTMCell(LayerRNNCell): """Basic LSTM recurrent network cell. The implementation is based on: http://arxiv.org/abs/1409.2329. We add forget_bias (default: 1) to the biases of the forget gate in order to reduce the scale of forgetting in the beginning of the training. It does not allow cell clipping, a projection layer, and does not use peep-hole connections: it is the basic baseline. For advanced models, please use the full @{tf.nn.rnn_cell.LSTMCell} that follows. """ def __init__(self, num_units, forget_bias=1.0, state_is_tuple=True, activation=None, reuse=None, name=None, dtype=None): """Initialize the basic LSTM cell. Args: num_units: int, The number of units in the LSTM cell. forget_bias: float, The bias added to forget gates (see above). Must set to `0.0` manually when restoring from CudnnLSTM-trained checkpoints. state_is_tuple: If True, accepted and returned states are 2-tuples of the `c_state` and `m_state`. If False, they are concatenated along the column axis. The latter behavior will soon be deprecated. activation: Activation function of the inner states. Default: `tanh`. reuse: (optional) Python boolean describing whether to reuse variables in an existing scope. If not `True`, and the existing scope already has the given variables, an error is raised. name: String, the name of the layer. Layers with the same name will share weights, but to avoid mistakes we require reuse=True in such cases. dtype: Default dtype of the layer (default of `None` means use the type of the first input). Required when `build` is called before `call`. When restoring from CudnnLSTM-trained checkpoints, must use `CudnnCompatibleLSTMCell` instead. """ super(BasicLSTMCell, self).__init__(_reuse=reuse, name=name, dtype=dtype) if not state_is_tuple: logging.warn("%s: Using a concatenated state is slower and will soon be " "deprecated. Use state_is_tuple=True.", self) # Inputs must be 2-dimensional. self.input_spec = base_layer.InputSpec(ndim=2) self._num_units = num_units self._forget_bias = forget_bias self._state_is_tuple = state_is_tuple self._activation = activation or math_ops.tanh @property def state_size(self): return (LSTMStateTuple(self._num_units, self._num_units) if self._state_is_tuple else 2 * self._num_units) @property def output_size(self): return self._num_units def build(self, inputs_shape): if inputs_shape[1].value is None: raise ValueError("Expected inputs.shape[-1] to be known, saw shape: %s" % inputs_shape) input_depth = inputs_shape[1].value h_depth = self._num_units self._kernel = self.add_variable( _WEIGHTS_VARIABLE_NAME, shape=[input_depth + h_depth, 4 * self._num_units]) self._bias = self.add_variable( _BIAS_VARIABLE_NAME, shape=[4 * self._num_units], initializer=init_ops.zeros_initializer(dtype=self.dtype)) self.built = True def call(self, inputs, state): """Long short-term memory cell (LSTM). Args: inputs: `2-D` tensor with shape `[batch_size, input_size]`. state: An `LSTMStateTuple` of state tensors, each shaped `[batch_size, num_units]`, if `state_is_tuple` has been set to `True`. Otherwise, a `Tensor` shaped `[batch_size, 2 * num_units]`. Returns: A pair containing the new hidden state, and the new state (either a `LSTMStateTuple` or a concatenated state, depending on `state_is_tuple`). """ sigmoid = math_ops.sigmoid one = constant_op.constant(1, dtype=dtypes.int32) # Parameters of gates are concatenated into one multiply for efficiency. if self._state_is_tuple: c, h = state else: c, h = array_ops.split(value=state, num_or_size_splits=2, axis=one) gate_inputs = math_ops.matmul( array_ops.concat([inputs, h], 1), self._kernel) gate_inputs = nn_ops.bias_add(gate_inputs, self._bias) # i = input_gate, j = new_input, f = forget_gate, o = output_gate i, j, f, o = array_ops.split( value=gate_inputs, num_or_size_splits=4, axis=one) forget_bias_tensor = constant_op.constant(self._forget_bias, dtype=f.dtype) # Note that using `add` and `multiply` instead of `+` and `*` gives a # performance improvement. So using those at the cost of readability. add = math_ops.add multiply = math_ops.multiply new_c = add(multiply(c, sigmoid(add(f, forget_bias_tensor))), multiply(sigmoid(i), self._activation(j))) new_h = multiply(self._activation(new_c), sigmoid(o)) if self._state_is_tuple: new_state = LSTMStateTuple(new_c, new_h) else: new_state = array_ops.concat([new_c, new_h], 1) return new_h, new_state
本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
TensorFlow RNN Cell源码解析
使用tensorflow:LSTM神经网络预测股票(一)
Tensorflow中循环神经网络及其Wrappers
Tensorlfow 实现基于LSTM的语言模型
程序员如何借助 AI 开挂股票神预测?| 技术头条
一文详解如何用 TensorFlow 实现基于 LSTM 的文本分类(附源码)
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服