Add support for fast variable-length LSTM (apache#14208)
* initial commit for variable length sequence support w/ cudnn

* removing check about all vectors on same context (need to add back in)

* fixing commented-out code to actually comment out what I wanted

* fixing cudnn layout type to be unpacked in var-length seq case

* looks like param.batch_size_ etc. weren't previously getting set in cudnn operator code. still doesn't fix cudnn error though

* must call cudnnSetRNNPaddingMode() to enable unpacked padded sequences

* cleaning up & adding unit tests

* cleaning up

* cleaning up

* removing stringstream and checking for cudnn >= 7.2

* fixing whitespace formatting errors; adding ifdef version guard for cudnn padding

* fixing a few syntax errors

* changing order of arguments in hybrid_forward for backward compatibility

* more build validation fixes

* using emplace_back to make linter happy

* adding import of mxnet.ndarray

* switching order of sequence_length in hybrid_forward again

* adding __call__ override to rnn layer to handle optional sequence_length argument

* whoops swapped order of args in one place but not the other

* changing type() to isinstance() to make linter happy

* changing lstm var seq length call to explicitly name sequence_length parameter

* fixing bad scope of if-statement checking state outputs

* resolving reviewer comments

* making linter happy by putting var definitions in appropriate ifdef

* fixing linter again

* fixing whitespace issues with linter

* fixing whitespace issues with linter

* fixing some typos that emerged while fixing linter errors

* linter

* fixing more whitespace issues

* only access kTempSpace if on gpu

* removing tabs that slipped in

* fixing too-long line

* changing ifdef guard to be more generic

* reverting change so whitespace stays same w/ master

* adding todo comment
stephenrawls authored and haohuw committed Jun 23, 2019
1 parent 4e972e1 commit eb5c6c7
Showing 6 changed files with 343 additions and 117 deletions.
5 changes: 3 additions & 2 deletions cpp-package/example/charRNN.cpp
@@ -164,8 +164,9 @@ Symbol LSTMWithBuiltInRNNOp(int num_lstm_layer, int sequence_length, int input_d
auto rnn_h_init = Symbol::Variable("LSTM_init_h");
auto rnn_c_init = Symbol::Variable("LSTM_init_c");
auto rnn_params = Symbol::Variable("LSTM_parameters"); // See explanations near RNNXavier class
auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, num_hidden, num_lstm_layer,
RNNMode::kLstm, false, dropout, !isTrain);
auto variable_sequence_length = Symbol::Variable("sequence_length");
auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, variable_sequence_length, num_hidden,
num_lstm_layer, RNNMode::kLstm, false, dropout, !isTrain);
auto hidden = Reshape(rnn[0], Shape(), false, Shape(0, num_hidden), false);

auto cls_weight = Symbol::Variable("cls_weight");
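The updated charRNN example above threads a sequence_length Symbol into the RNN operator. For reference, a rough Python equivalent at the symbol level is sketched below; it is not part of this commit, the variable names and sizes are illustrative, and executing the padded-sequence path requires a GPU build with cuDNN >= 7.2.

import mxnet as mx

# Hedged sketch: build an RNN symbol with the optional sequence_length input
# enabled via use_sequence_length=True (the flag this commit introduces).
data = mx.sym.Variable('data')                # (seq_len, batch, input_size), TNC layout
params = mx.sym.Variable('rnn_params')        # flattened cuDNN parameter blob
state = mx.sym.Variable('state')              # initial hidden state h0
state_cell = mx.sym.Variable('state_cell')    # initial cell state c0 (LSTM only)
seq_len = mx.sym.Variable('sequence_length')  # per-example valid lengths, shape (batch,)

rnn = mx.sym.RNN(data=data, parameters=params, state=state, state_cell=state_cell,
                 sequence_length=seq_len, use_sequence_length=True,
                 state_size=256, num_layers=1, mode='lstm', state_outputs=True)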
41 changes: 29 additions & 12 deletions python/mxnet/gluon/rnn/rnn_layer.py
@@ -37,7 +37,7 @@ def __init__(self, hidden_size, num_layers, layout,
i2h_bias_initializer, h2h_bias_initializer,
mode, projection_size, h2r_weight_initializer,
lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
dtype, **kwargs):
dtype, use_sequence_length=False, **kwargs):
super(_RNNLayer, self).__init__(**kwargs)
assert layout in ('TNC', 'NTC'), \
"Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
@@ -58,6 +58,7 @@ def __init__(self, hidden_size, num_layers, layout,
self._lstm_state_clip_max = lstm_state_clip_max
self._lstm_state_clip_nan = lstm_state_clip_nan
self._dtype = dtype
self._use_sequence_length = use_sequence_length

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

@@ -219,29 +220,39 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
return states

def hybrid_forward(self, F, inputs, states=None, **kwargs):
if F is ndarray:
batch_size = inputs.shape[self._layout.find('N')]
skip_states = states is None
if skip_states:
if F is ndarray:
def __call__(self, inputs, states=None, sequence_length=None, **kwargs):
self.skip_states = states is None
if states is None:
if isinstance(inputs, ndarray.NDArray):
batch_size = inputs.shape[self._layout.find('N')]
states = self.begin_state(batch_size, ctx=inputs.context, dtype=inputs.dtype)
else:
states = self.begin_state(0, func=symbol.zeros)
if isinstance(states, tensor_types):
states = [states]

if self._use_sequence_length:
return super(_RNNLayer, self).__call__(inputs, states, sequence_length, **kwargs)
else:
return super(_RNNLayer, self).__call__(inputs, states, **kwargs)


def hybrid_forward(self, F, inputs, states, sequence_length=None, **kwargs):
if F is ndarray:
batch_size = inputs.shape[self._layout.find('N')]

if F is ndarray:
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
out = self._forward_kernel(F, inputs, states, **kwargs)
out = self._forward_kernel(F, inputs, states, sequence_length, **kwargs)

# out is (output, state)
return out[0] if skip_states else out
return out[0] if self.skip_states else out

def _forward_kernel(self, F, inputs, states, **kwargs):
def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
""" forward using CUDNN or CPU kenrel"""
if self._layout == 'NTC':
inputs = F.swapaxes(inputs, dim1=0, dim2=1)
@@ -261,14 +272,20 @@ def _forward_kernel(self, F, inputs, states, **kwargs):

params = F._internal._rnn_param_concat(*params, dim=0)

rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
projection_size=self._projection_size,
if self._use_sequence_length:
rnn_args = states + [sequence_length]
else:
rnn_args = states

rnn = F.RNN(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
state_size=self._hidden_size, projection_size=self._projection_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode,
lstm_state_clip_min=self._lstm_state_clip_min,
lstm_state_clip_max=self._lstm_state_clip_max,
lstm_state_clip_nan=self._lstm_state_clip_nan)


if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
else:
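At the Gluon level, the variable-length path added above is enabled by constructing the layer with use_sequence_length=True and passing sequence_length when calling it. A minimal usage sketch follows; it assumes a GPU build with cuDNN >= 7.2, and the shapes, lengths, and int32 dtype below are illustrative rather than prescribed by this commit.

import mxnet as mx
from mxnet import gluon

ctx = mx.gpu(0)
seq_len, batch_size, input_size, hidden_size = 5, 3, 8, 16

# Construct the layer with the new flag; 'TNC' keeps time as the leading axis.
lstm = gluon.rnn.LSTM(hidden_size, num_layers=1, layout='TNC',
                      use_sequence_length=True)
lstm.initialize(ctx=ctx)

inputs = mx.nd.random.uniform(shape=(seq_len, batch_size, input_size), ctx=ctx)
# Valid (unpadded) length of each sequence in the batch; later steps are padding.
valid_lengths = mx.nd.array([5, 3, 2], ctx=ctx, dtype='int32')

# sequence_length is passed by keyword, matching the __call__ override above.
outputs = lstm(inputs, sequence_length=valid_lengths)  # (seq_len, batch, hidden_size)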