diff --git a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py index 368b98d662b1..dd8a2b4b5116 100644 --- a/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py +++ b/python/mxnet/contrib/onnx/onnx2mx/_op_translations.py @@ -398,8 +398,7 @@ def linalg_gemm(attrs, inputs, proto_obj): alpha = attrs['alpha'] if 'beta' in attrs: beta = attrs['beta'] - flatten_a = symbol.flatten(inputs[0]) - matmul_op = symbol.linalg_gemm2(A=flatten_a, B=inputs[1], + matmul_op = symbol.linalg_gemm2(A=inputs[0], B=inputs[1], transpose_a=trans_a, transpose_b=trans_b, alpha=alpha) gemm_op = symbol.broadcast_add(matmul_op, beta*inputs[2])