-
Notifications
You must be signed in to change notification settings - Fork 6.8k
Conversation
I did hybridize( ) and pass "valid_length" to the unroll( ) function of BidirectionalCell, then returned AssertionError in line 79. Because symbol.split( ) return a symbol but not a symbol list. Result in the length of inputs dont equal parameter "length" when call unroll( ) to compute r_outputs and r_states.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Im confused, doesn't _as_list return a list?
def _as_list(obj):
"""A utility function that converts the argument to a list if it is not already.
Parameters
----------
obj : object
Returns
-------
If `obj` is a list or tuple, return it. Otherwise, return `[obj]` as a
single-element list.
"""
if isinstance(obj, (list, tuple)):
return obj
else:
return [obj]
I don't understand the fix. Can you paste the error?
@larroy _as_list would return a wrong list as it would wrap the single grouped symbol inside, instead of |
@BeyonderXX thanks for the fix. Could you add a test with the proper assertions? |
I‘m glad to! I will add a new test case as soon as possible. |
I did hybridize( ) and pass "valid_length" to the unroll( ) function of BidirectionalCell, then returned AssertionError in line 79. Because symbol.split( ) return a symbol but not a symbol list. Result in the length of inputs dont equal parameter "length" when call unroll( ) to compute r_outputs and r_states.
Fix the error of parameter.
I did hybridize( ) and pass "valid_length" to the unroll( ) function of BidirectionalCell, then returned AssertionError in line 79. Because symbol.split( ) return a symbol but not a symbol list. Result in the length of inputs dont equal parameter "length" when call unroll( ) to compute r_outputs and r_states.
@szha I reproduced the issue in windows, when pass valid_length to unroll( ) of BidirectionalCell, it would raise null pointer error after multiple calls. |
@szha I figured out the issue, because I passed a wrong parameter to unroll( ) which led to out of bounds. I have fixed the problem and rewrited my code in BidirectionalCell. So sorry to disturb you abruptly. Thanks! |
@BeyonderXX thanks for the fix! |
* upstream/master: (38 commits) Feature/mkldnn static (apache#13628) Fix the bug of BidirectionalCell (apache#13575) Set install path for libmxnet.so dynamic lib on Mac OS (apache#13629) add batch norm test (apache#13625) Scripts for building dependency libraries of MXNet (apache#13282) fix quantize pass error when the quantization supported Op are excluded in the model (apache#13596) Optimize C++ API (apache#13496) Fix warning in waitall doc (apache#13618) [MXNET-1225] Always use config.mk in make install instructions (apache#13364) [MXNET-1224]: improve scala maven jni build and packing. (apache#13493) [MXNET-1155] Add scala packageTest utility (apache#13046) fix the Float not showing correctly problem (apache#13617) apache#13385 [Clojure] - Turn examples into integration tests (apache#13554) Add Intel MKL blas to Jenkins (apache#13607) Revert "[MXNET-1198] MXNet Java API (apache#13162)" Reducing the length of setup tutorial (apache#13306) [MXNET-1182] Predictor example (apache#13237) [MXNET-1187] Added Java SSD Inference Tutorial for website (apache#13201) add defaults and clean up the tests (apache#13295) [MXNET-1181] Added command line alternative to IntelliJ in install instructions (apache#13267) ...
Description
I did hybridize( ) and pass "valid_length" to the unroll( ) function of BidirectionalCell, then returned AssertionError in line 79. Because symbol.split( ) return a symbol but not a symbol list. Result in the length of inputs is not equal to parameter "length" when call unroll( ) to compute r_outputs and r_states. Symbol support for converting to list through list( ) API, so that use list( ) function insted of _as_list( ).
Checklist
Essentials
Please feel free to remove inapplicable items for your PR.
Changes
python/mxnet/gluon/rnn/rnn_cell.py
tests/python/unittest/test_gluon_rnn.py
valid_length to it.