Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix gluon rnn cell single step unroll
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed May 30, 2019
1 parent 9c5b88f commit 928a7b0
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 27 deletions.
5 changes: 4 additions & 1 deletion python/mxnet/gluon/rnn/rnn_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,10 @@ def _reverse_sequences(sequences, unroll_step, valid_length=None):
reversed_sequences = F.SequenceReverse(F.stack(*sequences, axis=0),
sequence_length=valid_length,
use_sequence_length=True)
reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
if unroll_step > 1 or F is symbol:
reversed_sequences = F.split(reversed_sequences, axis=0, num_outputs=unroll_step, squeeze_axis=True)
else:
reversed_sequences = [reversed_sequences[0]]

return reversed_sequences

Expand Down
53 changes: 27 additions & 26 deletions tests/python/unittest/test_gluon_rnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,32 +634,33 @@ def test_layer_fill_shape():


def test_bidirectional_unroll_valid_length():
# Test BidirectionalCell.
# In 1.3.1 version, after hybridize( ), BidirectionalCell would failed when pass valid_length to unroll( ).

class BiLSTM(gluon.nn.HybridBlock):
def __init__(self, rnn_size, time_step, **kwargs):
super(BiLSTM, self).__init__(**kwargs)
self.time_step = time_step
with self.name_scope():
self.bi_lstm = gluon.rnn.BidirectionalCell(
gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'),
gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'),
output_prefix='lstm_bi_')

def hybrid_forward(self, F, inputs, valid_len):
outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len,
layout='NTC', merge_outputs=True)
return outputs, states

rnn_size, time_step = 100, 3
net = BiLSTM(rnn_size, time_step)
net.initialize()
net.hybridize()
inputs_data = mx.nd.random.uniform(shape=(10, 3, 50))
valid_len = mx.nd.array([1]*10)
outputs, _ = net(inputs_data, valid_len)
assert outputs.shape == (10, 3, 200)
def _check_bidirectional_unroll_valid_length(length):
class BiLSTM(gluon.nn.HybridBlock):
def __init__(self, rnn_size, time_step, **kwargs):
super(BiLSTM, self).__init__(**kwargs)
self.time_step = time_step
with self.name_scope():
self.bi_lstm = gluon.rnn.BidirectionalCell(
gluon.rnn.LSTMCell(rnn_size, prefix='rnn_l0_'),
gluon.rnn.LSTMCell(rnn_size, prefix='rnn_r0_'),
output_prefix='lstm_bi_')

def hybrid_forward(self, F, inputs, valid_len):
outputs, states = self.bi_lstm.unroll(self.time_step, inputs, valid_length=valid_len,
layout='NTC', merge_outputs=True)
return outputs, states

rnn_size = 100
net = BiLSTM(rnn_size, length)
net.initialize()
net.hybridize()
inputs_data = mx.nd.random.uniform(shape=(10, length, 50))
valid_len = mx.nd.array([length]*10)
outputs, _ = net(inputs_data, valid_len)
assert outputs.shape == (10, length, 200)

_check_bidirectional_unroll_valid_length(1)
_check_bidirectional_unroll_valid_length(3)


if __name__ == '__main__':
Expand Down

0 comments on commit 928a7b0

Please sign in to comment.