Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
remove _contrib_format_sequence.
Browse files Browse the repository at this point in the history
  • Loading branch information
zheng-da committed Aug 27, 2018
1 parent f7b5799 commit 0b2734e
Showing 1 changed file with 2 additions and 23 deletions.
25 changes: 2 additions & 23 deletions python/mxnet/gluon/contrib/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,28 +319,6 @@ def hybrid_forward(self, F, inputs, states, i2h_weight,
# pylint: enable= arguments-differ


def _contrib_format_sequence(inputs, layout, in_layout=None):
assert inputs is not None, \
"unroll(inputs=None) has been deprecated. " \
"Please create input variables outside unroll."

axis = layout.find('T')
batch_axis = layout.find('N')
batch_size = 0
in_axis = in_layout.find('T') if in_layout is not None else axis
assert isinstance(inputs, tensor_types)
if isinstance(inputs, symbol.Symbol):
F = symbol
else:
F = ndarray
batch_size = inputs.shape[batch_axis]

if axis != in_axis:
inputs = F.swapaxes(inputs, dim1=axis, dim2=in_axis)

return inputs, axis, F, batch_size


def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
layout='TNC', valid_length=None):
"""Unrolls an RNN cell across time steps.
Expand Down Expand Up @@ -407,7 +385,8 @@ def unroll(cell, inputs, begin_state, drop_inputs=0, drop_outputs=0,
<NDArray 3x2x5 @cpu(0)>
"""

inputs, axis, F, _ = _contrib_format_sequence(inputs, layout)
# Merge is always True, so we don't need length.
inputs, axis, F, _ = _format_sequence(0, inputs, layout, True)
if axis != 0:
axes = list(range(len(layout)))
tmp = axes[0]
Expand Down

0 comments on commit 0b2734e

Please sign in to comment.