This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Variable sequence length not handled correctly for BiDirectional layers #19323

Closed
zjost opened this issue Oct 9, 2020 · 7 comments


@zjost
Contributor

zjost commented Oct 9, 2020

Description

There are two related issues with the use of use_sequence_length in the _RNNLayer.

  1. It seems to be usable only with LSTM, not with GRU or RNN. This is documented in more detail here.
  2. For LSTM, it only seems to work correctly on GPU. On CPU, appending extra padding elements changes the output even though the same sequence lengths are passed in. As a result, the output for a given sequence changes depending on the maximum sequence length in the batch.

To Reproduce

import mxnet as mx
from mxnet import gluon

#ctx = [mx.cpu()]
ctx = [mx.gpu(0)]

class TestModel(gluon.nn.HybridBlock):
    def __init__(self, bidirectional=True):
        super(TestModel, self).__init__(prefix="TestModel_")
        with self.name_scope():
            self.rnn = gluon.rnn.LSTM(hidden_size=1, bidirectional=bidirectional, use_sequence_length=True)

    def hybrid_forward(self, F, x, x_len):
        x = x.expand_dims(2)  # add a feature dimension
        x = x.transpose((1, 0, 2))  # to make it (max_sequence_length, batch_size, other_feature_dims)
        out = self.rnn(x, sequence_length=x_len)
        out = F.SequenceLast(out, sequence_length=x_len, use_sequence_length=True)
        return out

mx.random.seed(0)  # seed before initialization so the weights are reproducible
net = TestModel(bidirectional=True)
net.initialize(mx.init.Xavier(), ctx=ctx, force_reinit=True)

pad_val = -1
example_codes = [[1, 2], [1, pad_val]]
example_len = [2, 1]
x_input = mx.nd.array(example_codes).as_in_context(ctx[0])
x_len_input = mx.nd.array(example_len).as_in_context(ctx[0])

# Original
out1 = net(x_input, x_len_input)

# One extra padding token appended to every sequence; sequence lengths unchanged
x_input2 = mx.nd.array([k + [pad_val] for k in example_codes]).as_in_context(ctx[0])
out2 = net(x_input2, x_len_input)

# Note: out1 != out2 in the second column (backward cell) when ctx = CPU
print(out1)
print(out2)

Steps to reproduce

Run the code above with both a CPU and a GPU context and compare the second column of the output (i.e. the output of the backward LSTM cell).

What have you tried to solve it?

I've found that if you apply x = F.SequenceMask(x, sequence_length=x_len, use_sequence_length=True) before passing x to the RNN, the outputs match. This suggests that the CPU implementation reverses the entire padded sequence for the backward LSTM cell, rather than reversing only the first x_len elements.
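The hypothesis above can be illustrated without MXNet. The plain-Python sketch below assumes the CPU path reverses the whole padded sequence; `reverse_full` and `reverse_valid` are hypothetical helper names, and `reverse_valid` mimics what SequenceReverse does with use_sequence_length=True.

```python
# Hypothetical illustration (not MXNet code): two ways of preparing the input
# for the backward (right-to-left) direction of a bidirectional RNN.

def reverse_full(seq):
    """Reverse the whole padded sequence, padding included."""
    return seq[::-1]

def reverse_valid(seq, length):
    """Reverse only the first `length` elements; padding stays at the end."""
    return seq[:length][::-1] + seq[length:]

PAD = -1
short = [1, PAD]          # valid length 1, padded to 2
longer = [1, PAD, PAD]    # same record, padded to 3

print(reverse_full(short), reverse_full(longer))
# [-1, 1] [-1, -1, 1]  -> the backward cell reads 1 or 2 pads before the token
print(reverse_valid(short, 1), reverse_valid(longer, 1))
# [1, -1] [1, -1, -1]  -> the valid token is always read first
```

With `reverse_full`, the backward state at the valid position depends on how many pads were processed first, which matches the observed dependence on the batch's maximum length.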

Note: I suspect #14208 is relevant, given that the intended behavior works only for GPU/LSTM.

Environment

Environment Information
----------Python Info----------
Version      : 3.6.5
Compiler     : GCC 7.2.0
Build        : ('default', 'Apr 29 2018 16:14:56')
Arch         : ('64bit', '')
------------Pip Info-----------
Version      : 10.0.1
Directory    : /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/pip
----------MXNet Info-----------
Version      : 1.6.0
Directory    : /home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet
Commit Hash   : 6eec9da55c5096079355d1f1a5fa58dcf35d6752
Library      : ['/home/ec2-user/anaconda3/envs/mxnet_p36/lib/python3.6/site-packages/mxnet/libmxnet.so']
Build features:
✔ CUDA
✔ CUDNN
✔ NCCL
✔ CUDA_RTC
✖ TENSORRT
✔ CPU_SSE
✔ CPU_SSE2
✔ CPU_SSE3
✔ CPU_SSE4_1
✔ CPU_SSE4_2
✖ CPU_SSE4A
✔ CPU_AVX
✖ CPU_AVX2
✔ OPENMP
✖ SSE
✔ F16C
✖ JEMALLOC
✔ BLAS_OPEN
✖ BLAS_ATLAS
✖ BLAS_MKL
✖ BLAS_APPLE
✔ LAPACK
✔ MKLDNN
✔ OPENCV
✖ CAFFE
✖ PROFILER
✔ DIST_KVSTORE
✖ CXX14
✖ INT64_TENSOR_SIZE
✔ SIGNAL_HANDLER
✖ DEBUG
✖ TVM_OP
----------System Info----------
Platform     : Linux-4.14.198-152.320.amzn2.x86_64-x86_64-with-glibc2.9
system       : Linux
node         : ip-172-31-46-69.us-west-2.compute.internal
release      : 4.14.198-152.320.amzn2.x86_64
version      : #1 SMP Wed Sep 23 23:57:28 UTC 2020
----------Hardware Info----------
machine      : x86_64
processor    : x86_64
Architecture:        x86_64
CPU op-mode(s):      32-bit, 64-bit
Byte Order:          Little Endian
CPU(s):              32
On-line CPU(s) list: 0-31
Thread(s) per core:  2
Core(s) per socket:  16
Socket(s):           1
NUMA node(s):        1
Vendor ID:           GenuineIntel
CPU family:          6
Model:               79
Model name:          Intel(R) Xeon(R) CPU E5-2686 v4 @ 2.30GHz
Stepping:            1
CPU MHz:             2700.082
CPU max MHz:         3000.0000
CPU min MHz:         1200.0000
BogoMIPS:            4600.04
Hypervisor vendor:   Xen
Virtualization type: full
L1d cache:           32K
L1i cache:           32K
L2 cache:            256K
L3 cache:            46080K
NUMA node0 CPU(s):   0-31
Flags:               fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single pti fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm rdseed adx xsaveopt
----------Network Test----------
Setting timeout: 10
Timing for MXNet: https://github.com/apache/incubator-mxnet, DNS: 0.0028 sec, LOAD: 0.6319 sec.
Timing for Gluon Tutorial(en): http://gluon.mxnet.io, DNS: 0.1442 sec, LOAD: 0.0671 sec.
Error open Gluon Tutorial(cn): https://zh.gluon.ai, <urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:833)>, DNS finished in 0.07044124603271484 sec.
Timing for FashionMNIST: https://apache-mxnet.s3-accelerate.dualstack.amazonaws.com/gluon/dataset/fashion-mnist/train-labels-idx1-ubyte.gz, DNS: 0.0106 sec, LOAD: 0.1394 sec.
Timing for PYPI: https://pypi.python.org/pypi/pip, DNS: 0.0031 sec, LOAD: 0.3234 sec.
Error open Conda: https://repo.continuum.io/pkgs/free/, HTTP Error 403: Forbidden, DNS finished in 0.001977682113647461 sec.
----------Environment----------
KMP_DUPLICATE_LIB_OK="True"
KMP_INIT_AT_FORK="FALSE"
KMP_AFFINITY="granularity=fine,compact,1,0"
OMP_NUM_THREADS="16"
@szha
Member

szha commented Oct 9, 2020

cc @anko-intel @TaoLv

@grygielski
Contributor

Hi @zjost, currently the CPU backend does not support use_sequence_length=True in RNN layers. This check seems to be missing from the MKLDNN execution path: if you run this code with the environment variable MXNET_USE_MKLDNN_RNN=0 set (export MXNET_USE_MKLDNN_RNN=0), you get the following error: MXNetError: RNN use_sequence_length option is only available for cuDNN version >= 7.2.

Your solution of applying F.SequenceMask(x, sequence_length=x_len, use_sequence_length=True) is equivalent to setting pad_val = 0 instead of -1. However, it is not a proper solution; it only happened to work by accident. Padding with 0s yields the correct result for bidirectional RNN layers only if all biases are 0, which is the case here (the default bias initializer is zero). You can check this by changing the LSTM layer initialization in your model to:

self.rnn = gluon.rnn.LSTM(hidden_size=1, bidirectional=bidirectional, input_size=1, use_sequence_length=True,
                          h2h_bias_initializer='one', i2h_bias_initializer='one')
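To see why zero padding is only harmless with zero biases, here is a plain-Python sketch of a single scalar LSTM cell (hidden_size=1, input_size=1) fed a run of zero "padding" inputs from a zero state. The gate layout, weight values and helper names are illustrative assumptions, not MXNet's internal parameter layout; `bias` stands in for the sum of the i2h and h2h biases.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def lstm_step(x, h, c, w, u, b):
    """One step of a scalar LSTM. w, u, b hold the input weight, recurrent
    weight and combined bias for the gates 'i', 'f', 'g', 'o'."""
    i = sigmoid(w['i'] * x + u['i'] * h + b['i'])
    f = sigmoid(w['f'] * x + u['f'] * h + b['f'])
    g = math.tanh(w['g'] * x + u['g'] * h + b['g'])
    o = sigmoid(w['o'] * x + u['o'] * h + b['o'])
    c = f * c + i * g
    h = o * math.tanh(c)
    return h, c

def run_pads(n_pads, bias):
    """Feed n_pads zero inputs starting from a zero state; return the final (h, c)."""
    w = {k: 0.7 for k in 'ifgo'}   # arbitrary, hypothetical weights
    u = {k: -0.3 for k in 'ifgo'}
    b = {k: bias for k in 'ifgo'}
    h, c = 0.0, 0.0
    for _ in range(n_pads):
        h, c = lstm_step(0.0, h, c, w, u, b)
    return h, c

print(run_pads(3, bias=0.0))  # (0.0, 0.0): zero pads leave the state untouched
print(run_pads(3, bias=1.0))  # non-zero state: the pads leak into the result
```

With zero input, zero state and zero bias, the candidate g is tanh(0) = 0, so c and h stay exactly 0 no matter how many pads precede the real tokens. Any non-zero bias breaks this fixed point.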

For now, my suggestion would be to either use batch_size=1 or group sentences into batches of equal length.
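A minimal plain-Python sketch of the second suggestion, grouping sequences into batches of equal length. `equal_length_batches` is a hypothetical helper; it keeps the original record indices so that per-record scores can be mapped back after inference.

```python
from collections import defaultdict

def equal_length_batches(sequences, batch_size):
    """Group sequences so that every batch contains sequences of one length only.
    Yields (length, batch) pairs, where batch is a list of (index, sequence)."""
    by_len = defaultdict(list)
    for idx, seq in enumerate(sequences):
        by_len[len(seq)].append((idx, seq))
    for length in sorted(by_len):
        items = by_len[length]
        for start in range(0, len(items), batch_size):
            yield length, items[start:start + batch_size]

records = [[1, 2], [1], [3, 4], [5], [6, 7, 8]]
for length, batch in equal_length_batches(records, batch_size=2):
    print(length, [idx for idx, _ in batch])
# 1 [1, 3]
# 2 [0, 2]
# 3 [4]
```

Since no batch contains padding, the result no longer depends on the batch composition; the cost is that rare lengths produce small batches.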

@zjost
Contributor Author

zjost commented Oct 14, 2020

To be clear, this is a problem for us because we use SageMaker, where the same record gets different scores when it is invoked through the endpoint as a single record versus in a Batch Transform job. Running Batch Transform one record at a time takes far too long, and we can't control how SageMaker splits the batches.

Also, any comment on point 1, that this only seems to work with LSTM and not with RNN/GRU, even when using a supported cuDNN version?

@grygielski
Contributor

grygielski commented Oct 15, 2020

Regarding point 1, it seems that for now only LSTM supports the use_sequence_length parameter. It was introduced by PR #14208, and it causes an argument-ordering problem when used with other RNN types. However, I'm not sure whether this is a bug or just a missing implementation on the GPU side, because I'm only familiar with the CPU code, and, as I said before, use_sequence_length is not supported there.

I understand your concerns about the performance of batches of size 1. If you want correct results from bidirectional RNN layers while running bigger batches, you can build the bidirectional layer yourself from two unidirectional RNN layers and concatenate their outputs. An example of such a layer:

import mxnet as mx
from mxnet import gluon

ctx = [mx.cpu()]

class CustomBidirectionalRNNLayer(gluon.nn.HybridBlock):
    def __init__(self, hidden_size):
        super(CustomBidirectionalRNNLayer, self).__init__(prefix="bidir_rnn_")
        with self.name_scope():
            self.rnn_l2r = gluon.rnn.LSTM(hidden_size=hidden_size, bidirectional=False, prefix='l2r',
                                         h2h_bias_initializer='one', i2h_bias_initializer='one')
            self.rnn_r2l = gluon.rnn.LSTM(hidden_size=hidden_size, bidirectional=False, prefix='r2l',
                                         h2h_bias_initializer='one', i2h_bias_initializer='one')
    
    def hybrid_forward(self, F, x, x_len):
        l2r_out = self.rnn_l2r(x)
        r2l_out = self.rnn_r2l(F.SequenceReverse(x, sequence_length=x_len, use_sequence_length=True))
        out = F.concat(l2r_out, r2l_out, dim=2)
        return out

    
class TestModel(gluon.nn.HybridBlock):
    def __init__(self):
        super(TestModel, self).__init__(prefix="TestModel_")
        with self.name_scope():
            self.bidir_rnn = CustomBidirectionalRNNLayer(hidden_size=1)
            
            
    def hybrid_forward(self, F, x, x_len):
        x = x.expand_dims(2) # add a feature dimension
        x = x.transpose((1, 0, 2))  # to make it (max_sequence_length, batch_size, other_feature_dims)
        out = self.bidir_rnn(x, x_len)
        out = F.SequenceLast(out, sequence_length=x_len, use_sequence_length=True)
        return out
    
net = TestModel()
net.initialize(mx.init.Xavier(), ctx=ctx, force_reinit=True)

pad_val = 0
example_codes = [[1,2], [1,pad_val]]
example_len = [2,1]
x_input = mx.nd.array(example_codes).as_in_context(ctx[0])
x_len_input = mx.nd.array(example_len).as_in_context(ctx[0])
mx.random.seed(0)

# Original
out1 = net(x_input, x_len_input)

# Extra padding on first token
x_input2 = mx.nd.array([k+[pad_val] for k in example_codes]).as_in_context(ctx[0])
out2 = net(x_input2, x_len_input)

This solution also addresses point 1, because use_sequence_length is used only in F.SequenceReverse/F.SequenceLast and not in the RNN operator, so no error is raised. Let me know whether such a workaround works for you.
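Only the right-to-left branch needs SequenceReverse because a left-to-right recurrence is causal: the output at step t depends only on inputs 0..t, so trailing padding cannot change the outputs at valid positions, whatever the biases are. The plain-Python sketch below checks this with a scalar tanh RNN; the weights and the helper name `rnn_forward` are arbitrary assumptions for illustration.

```python
import math

def rnn_forward(xs, w=0.8, u=0.5, bias=1.0):
    """Scalar Elman RNN: h_t = tanh(w * x_t + u * h_{t-1} + bias). Returns all h_t."""
    h, out = 0.0, []
    for x in xs:
        h = math.tanh(w * x + u * h + bias)
        out.append(h)
    return out

PAD = 0.0
length = 2
out_short = rnn_forward([1.0, 2.0, PAD])
out_long = rnn_forward([1.0, 2.0, PAD, PAD, PAD])
# Outputs at the valid positions 0 .. length-1 are identical, because the
# padding only comes after them in left-to-right order.
print(out_short[:length] == out_long[:length])  # True
```

SequenceLast with use_sequence_length=True then reads position x_len - 1, so it never touches the padded outputs of the left-to-right branch.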

@zjost
Contributor Author

zjost commented Oct 15, 2020

Thanks for taking the time to show an implementation of this.

Do you think a warning should be added to the documentation about bidirectional layers? I don't think it's clear that they will not behave as intended unless all sequences in a batch have the same length. This is particularly confusing because, if you trace through the base classes, it looks as though e.g. GRU passes along the use_sequence_length kwarg.

Regarding #14208, I'm not sure why this fails for other RNN types, since the code changes appear to be primarily in the _RNN base class rather than in the LSTMCell.

@grygielski
Contributor

Absolutely, the documentation should note that use_sequence_length works only on GPU. With bidirectional=False it does not harm users, but using it with bidirectional=True leads to incorrect results. It should also mention that it currently works only for LSTM.

I'll try to explain why changing the _RNN base class works for LSTM but not for GRU and RNN. The problem lies at the graph framework level (NNVM). Here is where the arguments of RNN are registered:
https://github.com/apache/incubator-mxnet/blob/d0ceecbb3e4f2154a7783cba8f6e152b8c9003b1/src/operator/rnn.cc#L414-L418
As you can see, these two arguments are optional: cell_state exists only for LSTM, and sequence_length is present only when the user sets use_sequence_length to True. The listing of input arguments is defined by:
https://github.com/apache/incubator-mxnet/blob/d0ceecbb3e4f2154a7783cba8f6e152b8c9003b1/src/operator/rnn.cc#L37-L53
as well as in the enum:
https://github.com/apache/incubator-mxnet/blob/d0ceecbb3e4f2154a7783cba8f6e152b8c9003b1/src/operator/rnn-inl.h#L55-L56
The problem is that an optional argument cannot exist without the preceding ones being present. So if we use sequence_length, the operator also expects the argument before it, which is cell_state and exists only for LSTM. We could swap the order of these two arguments in the code. Then RNN/GRU would no longer crash when used with use_sequence_length, but LSTM would crash when used without use_sequence_length, because it needs cell_state, which would then come after sequence_length. I'm not aware of any workaround for that, but I don't think we should focus on it now, since GRU/RNN don't have any kernel supporting use_sequence_length anyway.
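The ordering constraint can be modeled in a few lines of plain Python. This is a hypothetical simplification, not MXNet's code: it assumes the kernel reads each input from a fixed slot index (as the enum suggests), while the frontend passes only the optional inputs that are present, packed from the left. The slot names follow the wording above rather than the exact identifiers in rnn.cc.

```python
# Hypothetical, simplified model (not MXNet code).
ORIGINAL_ORDER = ["data", "parameters", "state", "cell_state", "sequence_length"]
SWAPPED_ORDER = ["data", "parameters", "state", "sequence_length", "cell_state"]

def passed_inputs(order, is_lstm, use_sequence_length):
    """Inputs actually passed, in `order`, skipping the optional ones that are absent."""
    present = {"data", "parameters", "state"}
    if is_lstm:
        present.add("cell_state")
    if use_sequence_length:
        present.add("sequence_length")
    return [name for name in order if name in present]

def mismatches(order, is_lstm, use_sequence_length):
    """Map slot -> input that lands in it, keeping only the wrongly bound slots."""
    inputs = passed_inputs(order, is_lstm, use_sequence_length)
    return {order[i]: name for i, name in enumerate(inputs) if order[i] != name}

print(mismatches(ORIGINAL_ORDER, is_lstm=True, use_sequence_length=True))   # {}
print(mismatches(ORIGINAL_ORDER, is_lstm=False, use_sequence_length=True))  # {'cell_state': 'sequence_length'}
print(mismatches(SWAPPED_ORDER, is_lstm=True, use_sequence_length=False))   # {'sequence_length': 'cell_state'}
```

With the original order, GRU/RNN with use_sequence_length places sequence_length in the cell_state slot; swapping the order fixes that case but misplaces cell_state for a plain LSTM, which is the trade-off described above.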

By the way, sorry: I made a small mistake in the CustomBidirectionalRNNLayer in my previous post. To make it work correctly, you have to reverse the r2l output back before the concat:

from mxnet import gluon

class CustomBidirectionalRNNLayer(gluon.nn.HybridBlock):
    def __init__(self, hidden_size):
        super(CustomBidirectionalRNNLayer, self).__init__(prefix="bidir_rnn_")
        with self.name_scope():
            self.rnn_l2r = gluon.rnn.LSTM(hidden_size=hidden_size, bidirectional=False, prefix='l2r',
                                         h2h_bias_initializer='one', i2h_bias_initializer='one')
            self.rnn_r2l = gluon.rnn.LSTM(hidden_size=hidden_size, bidirectional=False, prefix='r2l',
                                         h2h_bias_initializer='one', i2h_bias_initializer='one')
    
    def hybrid_forward(self, F, x, x_len):
        l2r_out = self.rnn_l2r(x)
        r2l_out = self.rnn_r2l(F.SequenceReverse(x, sequence_length=x_len, use_sequence_length=True))
        out = F.concat(l2r_out, F.SequenceReverse(r2l_out, sequence_length=x_len, use_sequence_length=True), dim=2)
        return out
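The reason for the second SequenceReverse can be seen with a plain-Python stand-in for the r2l LSTM that simply tags each output step with the token it was computed at. `sequence_reverse` is a hypothetical helper assumed to mirror SequenceReverse with use_sequence_length=True; the tagging is only for illustration.

```python
def sequence_reverse(seq, length):
    """Reverse the first `length` steps and leave the padding in place."""
    return seq[:length][::-1] + seq[length:]

tokens = ["a", "b", "c", "<pad>"]
length = 3

# Stand-in for the r2l LSTM: output step t is tagged with the token read at step t.
r2l_out = [f"h({tok})" for tok in sequence_reverse(tokens, length)]
print(r2l_out)                            # ['h(c)', 'h(b)', 'h(a)', 'h(<pad>)']
print(sequence_reverse(r2l_out, length))  # ['h(a)', 'h(b)', 'h(c)', 'h(<pad>)']
```

Without the second reverse, the concat would pair the l2r output at "a" with the r2l output at "c"; reversing back restores position-wise alignment while the padding stays at the end.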

For now, I'll prepare a PR that makes the code fail when run with use_sequence_length on CPU, so that users are notified.

leezu pushed a commit that referenced this issue Nov 4, 2020
…ength=True (#19466)

This PR addresses #19323. I've added an additional check for the use_sequence_length parameter when choosing which kernel to run. oneDNN does not support variable sequence lengths, so the code now raises an error.
@grygielski
Contributor

@leezu I think we can close this, since #19466 has been merged.

szha closed this as completed on Nov 6, 2020