Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix the bug of BidirectionalCell #13575

Merged
merged 8 commits into from
Dec 13, 2018
Merged
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ List of Contributors
* [Rahul Padmanabhan](https://github.com/rahul3)
* [Yuxi Hu](https://github.com/yuxihu)
* [Harsh Patel](https://github.com/harshp8l)
* [Xiao Wang](https://github.com/BeyonderXX)

Label Bot
---------
Expand Down
37 changes: 20 additions & 17 deletions python/mxnet/gluon/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,23 @@ def _mask_sequence_variable_length(F, data, length, valid_length, time_axis, mer
squeeze_axis=True))
return outputs

def _reverse_sequences(sequences, unroll_step, valid_length=None):
if isinstance(sequences[0], symbol.Symbol):
F = symbol
else:
F = ndarray

if valid_length is None:
reversed_sequences = list(reversed(sequences))
else:
reversed_sequences = F.SequenceReverse(F.stack(*sequences, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)

return reversed_sequences


class RecurrentCell(Block):
"""Abstract base class for RNN cells

Expand Down Expand Up @@ -1035,14 +1052,7 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
self.reset()

inputs, axis, F, batch_size = _format_sequence(length, inputs, layout, False)
if valid_length is None:
reversed_inputs = list(reversed(inputs))
else:
reversed_inputs = F.SequenceReverse(F.stack(*inputs, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
reversed_inputs = _as_list(F.split(reversed_inputs, axis=0, num_outputs=length,
squeeze_axis=True))
reversed_inputs = list(_reverse_sequences(inputs, length, valid_length))
begin_state = _get_begin_state(self, F, begin_state, inputs, batch_size)

states = begin_state
Expand All @@ -1056,15 +1066,8 @@ def unroll(self, length, inputs, begin_state=None, layout='NTC', merge_outputs=N
begin_state=states[len(l_cell.state_info(batch_size)):],
layout=layout, merge_outputs=False,
valid_length=valid_length)
if valid_length is None:
reversed_r_outputs = list(reversed(r_outputs))
else:
reversed_r_outputs = F.SequenceReverse(F.stack(*r_outputs, axis=0),
sequence_length=valid_length,
use_sequence_length=True,
axis=0)
reversed_r_outputs = _as_list(F.split(reversed_r_outputs, axis=0, num_outputs=length,
squeeze_axis=True))
reversed_r_outputs = _reverse_sequences(r_outputs, length, valid_length)

if merge_outputs is None:
merge_outputs = isinstance(l_outputs, tensor_types)
l_outputs, _, _, _ = _format_sequence(None, l_outputs, layout, merge_outputs)
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_gluon_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,34 @@ def test_layer_fill_shape():
assert layer.l0_i2h_weight.shape[1] == 7, layer.l0_i2h_weight.shape[1]


def test_bidirectional_unroll_valid_length():
    # Regression test for BidirectionalCell.
    # In version 1.3.1, a hybridized block failed when valid_length was
    # passed to BidirectionalCell.unroll().
    class BiLSTM(gluon.nn.HybridBlock):
        def __init__(self, rnn_size, time_step, **kwargs):
            super(BiLSTM, self).__init__(**kwargs)
            self.time_step = time_step
            with self.name_scope():
                # Build the two directions first, then pair them up.
                left_cell = gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_')
                right_cell = gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_')
                self.bi_lstm = gluon.rnn.BidirectionalCell(
                    left_cell, right_cell, output_prefix='lstm_bi_')

        def hybrid_forward(self, F, inputs, valid_len):
            merged, final_states = self.bi_lstm.unroll(
                self.time_step, inputs, valid_length=valid_len,
                layout='NTC', merge_outputs=True)
            return merged, final_states

    batch_size, input_dim = 10, 50
    rnn_size, time_step = 100, 3

    net = BiLSTM(rnn_size, time_step)
    net.initialize()
    net.hybridize()

    inputs_data = mx.nd.random.uniform(shape=(batch_size, time_step, input_dim))
    valid_len = mx.nd.array([1] * batch_size)
    outputs, _ = net(inputs_data, valid_len)

    # Both directions are concatenated, hence twice the hidden size.
    assert outputs.shape == (batch_size, time_step, 2 * rnn_size)


if __name__ == '__main__':
    # When executed directly as a script, hand this module to nose's runner.
    import nose

    nose.runmodule()