Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Add support for fast variable-length LSTM #14208

Merged
merged 36 commits into from
May 7, 2019
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
5e47a83
initial commit for variable length sequence support w/ cudnn
stephenrawls Feb 19, 2019
f70f2eb
removing check about all vectors on same context (need to add back in)
stephenrawls Feb 19, 2019
183d5b5
fixing commented-out code to actually comment out what I wanted
stephenrawls Feb 20, 2019
665afab
fixing cudnn layout type to be unpacked in var-length seq case
stephenrawls Feb 20, 2019
8000d13
looks like param.batch_size_ etc weren't previously getting set in cud…
stephenrawls Feb 20, 2019
569553c
must call cudnnSetRNNPaddingMode() to enable unpacked padded sequences
stephenrawls Feb 21, 2019
663c39e
cleaning up & adding unit tests
stephenrawls Feb 26, 2019
27cac24
cleaning up
stephenrawls Feb 26, 2019
5e387e4
cleaning up
stephenrawls Feb 26, 2019
49d2018
removing stringstream and checking for cudnn >= 7.2
stephenrawls Apr 28, 2019
8ee6696
fixing whitespace formatting errors; adding ifdef version guard for c…
stephenrawls Apr 28, 2019
6a78bd7
fixing a few syntax errors
stephenrawls Apr 28, 2019
e10ad25
changing order of arguments in hybrid_forward for backward compatibility
stephenrawls Apr 28, 2019
2101c3e
more build validation fixes
stephenrawls Apr 28, 2019
57acd15
using emplace_back to make linter happy
stephenrawls Apr 28, 2019
6cb2a8c
adding import of mxnet.ndarray
stephenrawls Apr 28, 2019
4ce0862
switching order of sequence_length in hybrid_forward again
stephenrawls Apr 29, 2019
f8ea574
adding __call__ override to rnn layer to handle optional sequence_len…
stephenrawls Apr 29, 2019
a5d85c6
whoops swapped order of args in one place but not the other
stephenrawls Apr 29, 2019
a61f861
changing type() to isinstance() to make linter happy
stephenrawls Apr 29, 2019
862ce69
changing lstm var seq length call to explicitly name sequence_length …
stephenrawls Apr 29, 2019
e500463
fixing bad scope of if-statement checking state outputs
stephenrawls Apr 30, 2019
143ddbb
resolving reviewer comments
stephenrawls May 3, 2019
8ded5da
making linter happy by putting var definitions in appropriate ifdef
stephenrawls May 3, 2019
806cbbc
fixing linter again
stephenrawls May 3, 2019
8625537
fixing whitespace issues with linter
stephenrawls May 3, 2019
36dc234
fixing whitespace issues with linter
stephenrawls May 3, 2019
aeed5f4
fixing some typos that emerged fixing linter
stephenrawls May 3, 2019
ee05388
linter
stephenrawls May 3, 2019
5691c33
fixing more whitespace issues
stephenrawls May 3, 2019
fcf50a2
only access kTempSpace if on gpu
stephenrawls May 4, 2019
b0c978c
removing tabs that slipped in
stephenrawls May 4, 2019
6b14793
fixing too-long line
stephenrawls May 4, 2019
dd5bf2d
changing ifdef guard to be more generic
stephenrawls May 7, 2019
ad73af2
reverting change so whitespace stays same w/ master
stephenrawls May 7, 2019
1ee5155
adding todo comment
stephenrawls May 7, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions cpp-package/example/charRNN.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,9 @@ Symbol LSTMWithBuiltInRNNOp(int num_lstm_layer, int sequence_length, int input_d
auto rnn_h_init = Symbol::Variable("LSTM_init_h");
auto rnn_c_init = Symbol::Variable("LSTM_init_c");
auto rnn_params = Symbol::Variable("LSTM_parameters"); // See explanations near RNNXavier class
auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, num_hidden, num_lstm_layer,
RNNMode::kLstm, false, dropout, !isTrain);
auto variable_sequence_length = Symbol::Variable("sequence_length");
auto rnn = RNN(embed, rnn_params, rnn_h_init, rnn_c_init, variable_sequence_length, num_hidden,
num_lstm_layer, RNNMode::kLstm, false, dropout, !isTrain);
auto hidden = Reshape(rnn[0], Shape(), false, Shape(0, num_hidden), false);

auto cls_weight = Symbol::Variable("cls_weight");
Expand Down
41 changes: 29 additions & 12 deletions python/mxnet/gluon/rnn/rnn_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, hidden_size, num_layers, layout,
i2h_bias_initializer, h2h_bias_initializer,
mode, projection_size, h2r_weight_initializer,
lstm_state_clip_min, lstm_state_clip_max, lstm_state_clip_nan,
dtype, **kwargs):
dtype, use_sequence_length=False, **kwargs):
super(_RNNLayer, self).__init__(**kwargs)
assert layout in ('TNC', 'NTC'), \
"Invalid layout %s; must be one of ['TNC' or 'NTC']"%layout
Expand All @@ -58,6 +58,7 @@ def __init__(self, hidden_size, num_layers, layout,
self._lstm_state_clip_max = lstm_state_clip_max
self._lstm_state_clip_nan = lstm_state_clip_nan
self._dtype = dtype
self._use_sequence_length = use_sequence_length

self._gates = {'rnn_relu': 1, 'rnn_tanh': 1, 'lstm': 4, 'gru': 3}[mode]

Expand Down Expand Up @@ -219,29 +220,39 @@ def begin_state(self, batch_size=0, func=ndarray.zeros, **kwargs):
states.append(func(name='%sh0_%d'%(self.prefix, i), **info))
return states

def hybrid_forward(self, F, inputs, states=None, **kwargs):
if F is ndarray:
batch_size = inputs.shape[self._layout.find('N')]
skip_states = states is None
if skip_states:
if F is ndarray:
def __call__(self, inputs, states=None, sequence_length=None, **kwargs):
self.skip_states = states is None
if states is None:
if isinstance(inputs, ndarray.NDArray):
batch_size = inputs.shape[self._layout.find('N')]
states = self.begin_state(batch_size, ctx=inputs.context, dtype=inputs.dtype)
else:
states = self.begin_state(0, func=symbol.zeros)
if isinstance(states, tensor_types):
states = [states]

if self._use_sequence_length:
return super(_RNNLayer, self).__call__(inputs, states, sequence_length, **kwargs)
else:
return super(_RNNLayer, self).__call__(inputs, states, **kwargs)


def hybrid_forward(self, F, inputs, states, sequence_length=None, **kwargs):
if F is ndarray:
batch_size = inputs.shape[self._layout.find('N')]

if F is ndarray:
for state, info in zip(states, self.state_info(batch_size)):
if state.shape != info['shape']:
raise ValueError(
"Invalid recurrent state shape. Expecting %s, got %s."%(
str(info['shape']), str(state.shape)))
out = self._forward_kernel(F, inputs, states, **kwargs)
out = self._forward_kernel(F, inputs, states, sequence_length, **kwargs)

# out is (output, state)
return out[0] if skip_states else out
return out[0] if self.skip_states else out

def _forward_kernel(self, F, inputs, states, **kwargs):
def _forward_kernel(self, F, inputs, states, sequence_length, **kwargs):
""" forward using CUDNN or CPU kernel"""
if self._layout == 'NTC':
inputs = F.swapaxes(inputs, dim1=0, dim2=1)
Expand All @@ -261,14 +272,20 @@ def _forward_kernel(self, F, inputs, states, **kwargs):

params = F._internal._rnn_param_concat(*params, dim=0)

rnn = F.RNN(inputs, params, *states, state_size=self._hidden_size,
projection_size=self._projection_size,
if self._use_sequence_length:
rnn_args = states + [sequence_length]
else:
rnn_args = states

rnn = F.RNN(inputs, params, *rnn_args, use_sequence_length=self._use_sequence_length,
state_size=self._hidden_size, projection_size=self._projection_size,
num_layers=self._num_layers, bidirectional=self._dir == 2,
p=self._dropout, state_outputs=True, mode=self._mode,
lstm_state_clip_min=self._lstm_state_clip_min,
lstm_state_clip_max=self._lstm_state_clip_max,
lstm_state_clip_nan=self._lstm_state_clip_nan)


if self._mode == 'lstm':
outputs, states = rnn[0], [rnn[1], rnn[2]]
else:
Expand Down
1 change: 1 addition & 0 deletions src/imperative/imperative_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ inline Context GetContext(const nnvm::NodeAttrs& attrs,
Context ctx;
if (inputs.size()) {
ctx = inputs[0]->ctx();

szha marked this conversation as resolved.
Show resolved Hide resolved
for (size_t i = 1; i < inputs.size(); ++i) {
CHECK_EQ(inputs[i]->ctx().dev_mask(), ctx.dev_mask())
<< "Operator " << attrs.op->name
Expand Down
Loading