From 0b2734e203f7e8ec33829a403442b0ea2e3e5c00 Mon Sep 17 00:00:00 2001 From: Da Zheng Date: Mon, 27 Aug 2018 18:09:01 +0000 Subject: [PATCH] remove _contrib_format_sequence. --- python/mxnet/gluon/contrib/rnn/rnn_cell.py | 25 ++-------------------- 1 file changed, 2 insertions(+), 23 deletions(-) diff --git a/python/mxnet/gluon/contrib/rnn/rnn_cell.py b/python/mxnet/gluon/contrib/rnn/rnn_cell.py index 5883be420797..d0fdd23b36b9 100644 --- a/python/mxnet/gluon/contrib/rnn/rnn_cell.py +++ b/python/mxnet/gluon/contrib/rnn/rnn_cell.py @@ -319,28 +319,6 @@ def hybrid_forward(self, F, inputs, states, i2h_weight, # pylint: enable= arguments-differ -def _contrib_format_sequence(inputs, layout, in_layout=None): - assert inputs is not None, \ - "unroll(inputs=None) has been deprecated. " \ - "Please create input variables outside unroll." - - axis = layout.find('T') - batch_axis = layout.find('N') - batch_size = 0 - in_axis = in_layout.find('T') if in_layout is not None else axis - assert isinstance(inputs, tensor_types) - if isinstance(inputs, symbol.Symbol): - F = symbol - else: - F = ndarray - batch_size = inputs.shape[batch_axis] - - if axis != in_axis: - inputs = F.swapaxes(inputs, dim1=axis, dim2=in_axis) - - return inputs, axis, F, batch_size - - def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0, layout='TNC', valid_length=None): """Unrolls an RNN cell across time steps. @@ -407,7 +385,8 @@ def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0, """ - inputs, axis, F, _ = _contrib_format_sequence(inputs, layout) + # Merge is always True, so we don't need length. + inputs, axis, F, _ = _format_sequence(0, inputs, layout, True) if axis != 0: axes = list(range(len(layout))) tmp = axes[0]