diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
index 1b9afee14bf2..3ec1bab45f79 100644
--- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py
+++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py
@@ -22,6 +22,8 @@
 from ...rnn import BidirectionalCell, SequentialRNNCell, ModifierCell, HybridRecurrentCell
 from ...rnn.rnn_cell import _format_sequence, _get_begin_state, _mask_sequence_variable_length
 from ... import tensor_types
+from .... import symbol, ndarray
+from ....base import _as_list
 
 class VariationalDropoutCell(ModifierCell):
     """
@@ -315,3 +317,132 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
 
         return next_r, [next_r, next_c]
     # pylint: enable= arguments-differ
+
+
+def _contrib_format_sequence(inputs, layout, in_layout=None):
+    """Prepare a merged input tensor for :func:`unroll`.
+
+    This helper deliberately has its own name: the module already imports
+    ``_format_sequence`` from ``...rnn.rnn_cell``, and defining another
+    function with that name here would rebind it for the rest of the module.
+
+    Parameters
+    ----------
+    inputs : Symbol or NDArray
+        The input sequence as one tensor (must be one of ``tensor_types``).
+    layout : str
+        Layout of the sequence; the positions of ``'T'`` and ``'N'`` give the
+        time axis and the batch axis.
+    in_layout : str, optional
+        Layout `inputs` currently has. If its time axis differs from the one
+        in `layout`, the two axes are swapped.
+
+    Returns
+    -------
+    inputs : Symbol or NDArray
+        The input, axis-swapped if required.
+    axis : int
+        Index of the time axis in `layout`.
+    F : module
+        ``symbol`` if `inputs` is a Symbol, ``ndarray`` otherwise.
+    batch_size : int
+        Size of the batch axis for NDArray input, 0 for Symbol input.
+    """
+    assert inputs is not None, \
+        "unroll(inputs=None) has been deprecated. " \
+        "Please create input variables outside unroll."
+
+    axis = layout.find('T')
+    batch_axis = layout.find('N')
+    batch_size = 0
+    in_axis = in_layout.find('T') if in_layout is not None else axis
+    assert isinstance(inputs, tensor_types)
+    if isinstance(inputs, symbol.Symbol):
+        F = symbol
+    else:
+        F = ndarray
+        batch_size = inputs.shape[batch_axis]
+
+    if axis != in_axis:
+        inputs = F.swapaxes(inputs, dim1=axis, dim2=in_axis)
+
+    return inputs, axis, F, batch_size
+
+
+def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
+           layout='NTC', valid_length=None):
+    """Unroll an RNN cell over a whole sequence using ``contrib.foreach``.
+
+    Parameters
+    ----------
+    cell : callable
+        Called as ``cell(step_input, states)`` and expected to return
+        ``(output, new_states)``.
+    inputs : Symbol or NDArray
+        The whole input sequence as one tensor, laid out as `layout`.
+    begin_state : Symbol, NDArray, or list of them
+        Initial states passed to `cell`. The caller's list is not modified.
+    drop_inputs : float, default 0
+        Dropout rate applied to the inputs (``axes`` is set to the time axis).
+    drop_outputs : float, default 0
+        Dropout rate applied to the outputs (``axes`` is set to the time axis).
+    layout : str, default 'NTC'
+        Layout of `inputs` and of the returned outputs.
+    valid_length : Symbol, NDArray or None
+        Number of valid steps of each sample. When given, a sample's states
+        stop changing once its valid length is reached and the outputs are
+        passed through ``SequenceMask``.
+
+    Returns
+    -------
+    outputs : Symbol or NDArray
+        Per-step outputs of `cell`, stacked along the time axis of `layout`.
+    states
+        States after the final step, as returned by ``foreach`` (for masked
+        samples, the states from their last valid step).
+    """
+    inputs, axis, F, _ = _contrib_format_sequence(inputs, layout)
+
+    if drop_inputs:
+        inputs = F.Dropout(inputs, p=drop_inputs, axes=(axis,))
+
+    # foreach iterates over axis 0 of its data, so put the time axis there.
+    if axis != 0:
+        inputs = F.swapaxes(inputs, dim1=0, dim2=axis)
+
+    if valid_length is None:
+        states = begin_state
+
+        def loop_body(inputs, states):
+            return cell(inputs, states)
+    else:
+        # Work on a fresh list so that the step counter appended below never
+        # ends up in the caller's begin_state.
+        states = list(_as_list(begin_state))
+        states.append(F.zeros((1,)))
+
+        def loop_body(inputs, states):
+            cell_states = states[:-1]
+            iter_no = states[-1]
+            out, new_states = cell(inputs, cell_states)
+            # A sample that has run past its valid length keeps its previous
+            # state; resetting it would lose the state of its last valid step.
+            still_valid = F.broadcast_greater(valid_length, iter_no)
+            for i, prev_state in enumerate(cell_states):
+                new_states[i] = F.where(still_valid, new_states[i], prev_state)
+            new_states.append(iter_no + 1)
+            return out, new_states
+
+    outputs, states = F.contrib.foreach(loop_body, inputs, states)
+    # Restore the requested layout before using `axis` on the outputs.
+    if axis != 0:
+        outputs = F.swapaxes(outputs, dim1=0, dim2=axis)
+    if drop_outputs:
+        outputs = F.Dropout(outputs, p=drop_outputs, axes=(axis,))
+    if valid_length is None:
+        return outputs, states
+
+    outputs = F.SequenceMask(outputs, sequence_length=valid_length,
+                             use_sequence_length=True, axis=axis)
+    # the last state is the iteration number. We don't need it.
+    return outputs, states[:-1]
diff --git a/tests/python/unittest/test_gluon_contrib.py b/tests/python/unittest/test_gluon_contrib.py
index a1cd8ea537d7..98526e51f093 100644
--- a/tests/python/unittest/test_gluon_contrib.py
+++ b/tests/python/unittest/test_gluon_contrib.py
@@ -17,10 +17,12 @@
 
 from __future__ import print_function
 import mxnet as mx
+import copy
+from mxnet import gluon
 from mxnet.gluon import contrib
 from mxnet.gluon import nn
 from mxnet.gluon.contrib.nn import Concurrent, HybridConcurrent, Identity, SparseEmbedding
-from mxnet.test_utils import almost_equal
+from mxnet.test_utils import almost_equal, default_context, assert_almost_equal
 from common import setup_module, with_seed, teardown
 import numpy as np
 from numpy.testing import assert_allclose
@@ -228,6 +230,99 @@ def test_sampler():
     assert list(interval_sampler) == [0, 3, 6, 9]
 
 
+class TestRNNLayer(gluon.HybridBlock):
+    """Hybrid block that unrolls one RNN cell with the contrib ``unroll``."""
+
+    def __init__(self, cell_type, hidden_size, prefix=None, params=None):
+        super(TestRNNLayer, self).__init__(prefix=prefix, params=params)
+        self.cell = cell_type(hidden_size, prefix='rnn_')
+
+    def hybrid_forward(self, F, inputs, states, valid_length):
+        # An empty list stands for "no valid_length"; see check_unroll.
+        if isinstance(valid_length, list) and not valid_length:
+            valid_length = None
+        return contrib.rnn.rnn_cell.unroll(self.cell, inputs, states,
+                                           valid_length=valid_length, layout='TNC')
+
+
+def check_unroll(cell_type, num_states):
+    """Compare the contrib ``unroll`` with ``cell.unroll`` for one cell type.
+
+    Both paths start from the same parameters and take one SGD step; their
+    outputs, final states and updated weights must agree.
+    """
+    batch_size = 1
+    input_size = 5
+    hidden_size = 3
+    seq_len = 1
+    rnn_data = mx.nd.normal(loc=0, scale=1, shape=(seq_len, batch_size, input_size))
+    # NOTE: with seq_len = 1 and valid_length >= 1 every step is valid, so the
+    # handling of padded steps is not exercised by this check.
+    valid_length = mx.nd.round(mx.nd.random.uniform(low=1, high=10, shape=(batch_size)))
+    state_shape = (batch_size, hidden_size)
+    states = [mx.nd.normal(loc=0, scale=1, shape=state_shape) for _ in range(num_states)]
+
+    # Reference: the cell's own unroll, followed by one SGD step.
+    cell = cell_type(hidden_size, prefix='rnn_')
+    cell.initialize(ctx=default_context())
+    cell(rnn_data[0], states)
+    params1 = cell.collect_params()
+    orig_params1 = copy.deepcopy(params1)
+
+    trainer = gluon.Trainer(params1, 'sgd', {'learning_rate' : 0.03})
+    with mx.autograd.record():
+        res1, states1 = cell.unroll(seq_len, rnn_data, states, valid_length=valid_length,
+                                    layout='TNC', merge_outputs=True)
+    res1.backward()
+    trainer.step(batch_size)
+
+    configs = [
+            #{},
+            {'static_alloc': True},
+            #{'static_alloc': True, 'static_shape': True}
+            ]
+    # We can't pass None to a hybrid block, but it accepts an empty list,
+    # so we use an empty list to represent valid_length if it's None.
+    if valid_length is None:
+        valid_length = []
+    for config in configs:
+        layer = TestRNNLayer(cell_type, hidden_size)
+        layer.initialize(ctx=default_context())
+        layer.hybridize(**config)
+        # First forward pass before the parameters are overwritten with the
+        # reference values (presumably to finish deferred initialization).
+        res2, states2 = layer(rnn_data, states, valid_length)
+        params2 = layer.collect_params()
+        for key, val in orig_params1.items():
+            params2[key].set_data(copy.deepcopy(val.data()))
+
+        trainer = gluon.Trainer(params2, 'sgd', {'learning_rate' : 0.03})
+        with mx.autograd.record():
+            res2, states2 = layer(rnn_data, states, valid_length)
+        assert_almost_equal(res1.asnumpy(), res2.asnumpy(), rtol=0.001, atol=0.0001)
+        assert len(states1) == len(states2)
+        for state1, state2 in zip(states1, states2):
+            assert_almost_equal(state1.asnumpy(), state2.asnumpy(),
+                                rtol=0.001, atol=0.0001)
+        res2.backward()
+        trainer.step(batch_size)
+
+        for key, val in params1.items():
+            weight1 = val.data()
+            weight2 = params2[key].data()
+            assert_almost_equal(weight1.asnumpy(), weight2.asnumpy(),
+                                rtol=0.001, atol=0.0001)
+
+
+@with_seed()
+def test_contrib_unroll():
+    """Run check_unroll for the RNN, LSTM and GRU cells."""
+    cell_types = [(gluon.rnn.RNNCell, 1), (gluon.rnn.LSTMCell, 2),
+                  (gluon.rnn.GRUCell, 1)]
+    for cell_type, num_states in cell_types:
+        check_unroll(cell_type, num_states)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()