Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix unroll of BidirectionalCell
Browse files (browse the repository at this point in the history)
  • Loading branch information
BeyonderXX committed Dec 11, 2018
1 parent 2540d3a commit c6fab04
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions python/mxnet/gluon/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,21 @@ def _mask_sequence_variable_length(F, data, length, valid_length, time_axis, mer
squeeze_axis=True))
return outputs

def _reverse_sequence(sequences, time_step, valid_length=None):
def _reverse_sequences(sequences, unroll_step, valid_length=None):
if isinstance(sequences[0], symbol.Symbol):
F = symbol
else:
F = ndarray

if valid_length is None:
reversed_inputs = list(reversed(sequences))
reversed_sequences = list(reversed(sequences))
else:
reversed_inputs = F.SequenceReverse(F.stack(*sequences, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
reversed_inputs = F.split(reversed_inputs, axis=0, num_outputs=time_step, squeeze_axis=True)
reversed_sequences = F.SequenceReverse(F.stack(*sequences, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)

return reversed_inputs
return reversed_sequences


class RecurrentCell(Block):
Expand Down Expand Up @@ -1052,7 +1052,8 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
self.reset()

inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
reversed_inputs = list(_reverse_sequence(inputs, length, valid_length))
reversed_inputs = list(_reverse_sequences(inputs, length, valid_length))
begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

states = begin_state
l_cell, r_cell = self._children.values()
Expand All @@ -1065,7 +1066,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
begin_state=states[len(l_cell.state_info(batch_size)):],
layout=layout, merge_outputs=False,
valid_length=valid_length)
reversed_r_outputs = _reverse_sequence(r_outputs, length, valid_length)
reversed_r_outputs = _reverse_sequences(r_outputs, length, valid_length)

if merge_outputs is None:
merge_outputs = isinstance(l_outputs, tensor_types)
Expand Down

0 comments on commit c6fab04

Please sign in to comment.