Skip to content
This repository has been archived by the owner on Jan 15, 2024. It is now read-only.

Commit

Permalink
Extend unit tests to include one-step decoding
Browse files Browse the repository at this point in the history
  • Loading branch information
leezu committed Oct 28, 2019
1 parent 94e7d73 commit 46f5b64
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 2 deletions.
2 changes: 1 addition & 1 deletion scripts/machine_translation/gnmt.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.
"""Encoder and decoder used in sequence-to-sequence learning."""
__all__ = ['GNMTEncoder', 'GNMTDecoder', 'get_gnmt_encoder_decoder']
__all__ = ['GNMTEncoder', 'GNMTDecoder', 'GNMTOneStepDecoder', 'get_gnmt_encoder_decoder']

import mxnet as mx
from mxnet.base import _as_list
Expand Down
50 changes: 49 additions & 1 deletion scripts/tests/test_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from mxnet.test_utils import assert_almost_equal
from ..machine_translation.gnmt import *
from gluonnlp.model.transformer import *
from gluonnlp.model.transformer import TransformerDecoder
from gluonnlp.model.transformer import TransformerDecoder, TransformerOneStepDecoder


def test_gnmt_encoder():
Expand Down Expand Up @@ -65,6 +65,11 @@ def test_gnmt_encoder_decoder():
output_attention=output_attention, use_residual=use_residual, prefix='gnmt_decoder_')
decoder.initialize(ctx=ctx)
decoder.hybridize()
one_step_decoder = GNMTOneStepDecoder(cell_type="lstm", num_layers=3, hidden_size=num_hidden,
dropout=0.0, output_attention=output_attention,
use_residual=use_residual, prefix='gnmt_decoder_',
params=decoder.collect_params())
one_step_decoder.hybridize()
for batch_size in [4]:
for src_seq_length, tgt_seq_length in [(5, 10), (10, 5)]:
src_seq_nd = mx.nd.random.normal(0, 1, shape=(batch_size, src_seq_length, 4), ctx=ctx)
Expand Down Expand Up @@ -98,6 +103,25 @@ def test_gnmt_encoder_decoder():
else:
assert(len(additional_outputs) == 0)

# Test one-step forwarding
output, new_states, additional_outputs = one_step_decoder(
tgt_seq_nd[:, 0, :], decoder_states)
assert(output.shape == (batch_size, num_hidden))
if output_attention:
assert(len(additional_outputs) == 1)
attention_out = additional_outputs[0].asnumpy()
assert(attention_out.shape == (batch_size, 1, src_seq_length))
for i in range(batch_size):
mem_v_len = int(src_valid_length_npy[i])
if mem_v_len < src_seq_length - 1:
assert((attention_out[i, :, mem_v_len:] == 0).all())
if mem_v_len > 0:
assert_almost_equal(attention_out[i, :, :].sum(axis=-1),
np.ones(attention_out.shape[1]))
else:
assert(len(additional_outputs) == 0)


def test_transformer_encoder():
ctx = mx.current_context()
for num_layers in range(1, 3):
Expand Down Expand Up @@ -186,3 +210,27 @@ def test_transformer_encoder_decoder(output_attention, use_residual, batch_size,
np.ones(attention_out.shape[1:3]))
else:
assert(len(additional_outputs) == 0)

        # Test one-step forwarding
decoder = TransformerOneStepDecoder(num_layers=3, units=units, hidden_size=32,
num_heads=8, max_length=10, dropout=0.0,
output_attention=output_attention,
use_residual=use_residual,
prefix='transformer_decoder_',
params=decoder.collect_params())
decoder.hybridize()
output, new_states, additional_outputs = decoder(tgt_seq_nd[:, 0, :], decoder_states)
assert(output.shape == (batch_size, units))
if output_attention:
assert(len(additional_outputs) == 3)
attention_out = additional_outputs[0][1].asnumpy()
assert(attention_out.shape == (batch_size, 8, 1, src_seq_length))
for i in range(batch_size):
mem_v_len = int(src_valid_length_npy[i])
if mem_v_len < src_seq_length - 1:
assert((attention_out[i, :, :, mem_v_len:] == 0).all())
if mem_v_len > 0:
assert_almost_equal(attention_out[i, :, :, :].sum(axis=-1),
np.ones(attention_out.shape[1:3]))
else:
assert(len(additional_outputs) == 0)

0 comments on commit 46f5b64

Please sign in to comment.