Pytorch Conv Transpose Padding Fix (#7958)
* fix conv transpose import from TF

* fix String::fromwe() to String::from()

* fixing pytorch converter to take into account the output_padding parameter for conv transpose operations
* updating pytorch converter to correctly convert conv1d to conv1d in tvm instead of a flattened conv2d, except in the case of grouped convolution
* updating pytorch converter to correctly convert conv1d transpose to conv1d transpose in tvm instead of a flattened conv2d transpose
* added tests to cover these additions

* removing print statements used for debugging

* fixing typos and formatting

* fixing formatting

* fixing grammar

* formatting fixes

* updated formatting after running pylint and python_format checks

Co-authored-by: Mikael Sevenier <mikael.sevenier@sima.ai>
Jeffrey-Sima and Mikael Sevenier authored May 14, 2021
1 parent 3bf65b7 commit aa7bfe7
Showing 2 changed files with 90 additions and 20 deletions.
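
For context on the first converter change listed in the commit message: output_padding (exposed to the converter as inputs[7] in the TorchScript graph) changes the spatial size of a transposed convolution, so ignoring it can make the converted model's output shape disagree with PyTorch's. A minimal PyTorch-only sketch of that behaviour, with illustrative shapes and parameter values that are not taken from the commit:

import torch

x = torch.rand(1, 3, 10)  # (batch, channels, width)

for output_padding in (0, 1):
    # output_padding must be smaller than the stride (or dilation), hence stride=2 here
    op = torch.nn.ConvTranspose1d(
        in_channels=3,
        out_channels=5,
        kernel_size=3,
        stride=2,
        output_padding=output_padding,
    ).eval()
    with torch.no_grad():
        y = op(x)
    # W_out = (W_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
    print(output_padding, tuple(y.shape))  # (1, 5, 21) for output_padding=0, (1, 5, 22) for 1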
45 changes: 34 additions & 11 deletions python/tvm/relay/frontend/pytorch.py
@@ -977,32 +977,52 @@ def convolution(self, inputs, input_types):
         kernel_size = weight_shape[2:]
         use_bias = isinstance(bias, _expr.Expr)
 
-        if len(kernel_size) == 1:
-            strides = (1,) + strides
-            padding = (0,) + padding
-            dilation = (1,) + dilation
+        # We are trying to invoke various relay operations through a single conv_op variable.
+        # However the function signatures for some operations have additional attributes so we
+        # pass these in along with the standard ones.
+        additional_arguments = dict()
 
         if use_transpose:
             if len(kernel_size) == 3:
                 conv_op = _op.nn.conv3d_transpose
-            else:
+            elif len(kernel_size) == 2:
                 conv_op = _op.nn.conv2d_transpose
+            else:
+                conv_op = _op.nn.conv1d_transpose
+            output_padding = tuple(inputs[7])
+            additional_arguments["output_padding"] = output_padding
 
         else:
             if len(kernel_size) == 3:
                 conv_op = _op.nn.conv3d
-            else:
+            elif len(kernel_size) == 2:
                 conv_op = _op.nn.conv2d
+            else:
+                conv_op = _op.nn.conv1d
 
         if len(kernel_size) == 3:
             data_layout = "NCDHW"
             kernel_layout = "OIDHW"
-        else:
+        elif len(kernel_size) == 2:
             data_layout = "NCHW"
             kernel_layout = "OIHW"
-
-        if len(kernel_size) == 1:
+        else:
+            data_layout = "NCW"
+            kernel_layout = "OIW"
+
+        # Conv1d does not currently support grouped convolution so we convert it to conv2d
+        is_grouped_conv1d = False
+        if groups > 1 and len(kernel_size) == 1 and not use_transpose:
+            is_grouped_conv1d = True
+            conv_op = _op.nn.conv2d
+            kernel_size = [1] + kernel_size
+            strides = (1,) + strides
+            padding = (0,) + padding
+            dilation = (1,) + dilation
             data = _op.expand_dims(data, axis=2)
             weight = _op.expand_dims(weight, axis=2)
+            data_layout = "NCHW"
+            kernel_layout = "OIHW"
 
         conv_out = conv_op(
             data,
@@ -1012,17 +1032,20 @@ def convolution(self, inputs, input_types):
             dilation=dilation,
             groups=groups,
             channels=channels,
-            kernel_size=[1] + kernel_size if len(kernel_size) == 1 else kernel_size,
+            kernel_size=kernel_size,
             data_layout=data_layout,
             kernel_layout=kernel_layout,
             out_layout="",
             out_dtype="",
+            **additional_arguments,
         )
         if use_bias:
             res = _op.nn.bias_add(conv_out, bias)
         else:
             res = conv_out
-        if len(kernel_size) == 1:
+        if is_grouped_conv1d:
+            # Because we conducted grouped conv1d convolution through conv2d we must
+            # squeeze the output to get the correct result.
             res = _op.squeeze(res, axis=[2])
         return res

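The is_grouped_conv1d branch above falls back to conv2d only because relay's conv1d does not support grouped convolution. A rough sketch of the equivalence that fallback relies on, written in plain PyTorch with made-up shapes: a grouped 1-D convolution matches a grouped 2-D convolution applied after inserting a dummy height axis and squeezing it back out.

import torch
import torch.nn.functional as F

groups = 3
x = torch.rand(1, 6, 20)           # (N, C_in, W), C_in divisible by groups
w = torch.rand(9, 6 // groups, 5)  # (C_out, C_in // groups, kW)

direct = F.conv1d(x, w, groups=groups)

# Same computation routed through conv2d, mirroring the converter's workaround:
# expand a dummy H axis on data and weight, convolve in 2-D, then squeeze it out.
via_2d = F.conv2d(x.unsqueeze(2), w.unsqueeze(2), groups=groups).squeeze(2)

assert torch.allclose(direct, via_2d, atol=1e-5)
print(direct.shape)  # torch.Size([1, 9, 16])
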
65 changes: 56 additions & 9 deletions tests/python/frontend/pytorch/test_forward.py
@@ -32,6 +32,7 @@
 from tvm import relay
 from tvm.contrib import graph_executor
 from tvm.contrib.nvcc import have_fp16
+import pytest
 
 sys.setrecursionlimit(10000)

@@ -965,17 +966,63 @@ def forward(self, *args):
 
 
 @tvm.testing.uses_gpu
-def test_forward_conv_transpose():
-    torch.set_grad_enabled(False)
-    conv2d_input_shape = [1, 3, 10, 10]
+@pytest.mark.parametrize("in_channels", [3], ids=lambda x: "in_channels=" + str(x))
+@pytest.mark.parametrize("out_channels", [5], ids=lambda x: "out_channels=" + str(x))
+@pytest.mark.parametrize("kernel_size", [3], ids=lambda x: "kernel_size=" + str(x))
+@pytest.mark.parametrize("output_padding", [0, 1, 2], ids=lambda x: "output_padding=" + str(x))
+@pytest.mark.parametrize("groups", [1], ids=lambda x: "groups=" + str(x))
+@pytest.mark.parametrize("bias", [True, False], ids=lambda x: "bias=" + str(x))
+def test_forward_conv_transpose(
+    in_channels, out_channels, kernel_size, output_padding, bias, groups
+):
+    # Note we do not test with groups > 1 because that is not supported
+    # in tvm for conv transpose operations
+
+    # Output padding must be smaller than either stride or dilation so we
+    # opt to make the stride 1 + output padding
+    stride = output_padding + 1
+
+    # Conv 3D Transpose Tests
+    conv3d_input_shape = [1, in_channels, 16, 16, 16]
+    conv3d_input_data = torch.rand(conv3d_input_shape).float()
+    conv3d_transpose = torch.nn.ConvTranspose3d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        output_padding=output_padding,
+        groups=groups,
+        bias=bias,
+    ).eval()
+    verify_model(conv3d_transpose, conv3d_input_data)
+
+    # Conv 2D Transpose Tests
+    conv2d_input_shape = [1, in_channels, 128, 256]
     conv2d_input_data = torch.rand(conv2d_input_shape).float()
-    verify_model(torch.nn.ConvTranspose2d(3, 6, 7, bias=True), input_data=conv2d_input_data)
-    verify_model(torch.nn.ConvTranspose2d(3, 12, 3, bias=False), input_data=conv2d_input_data)
 
-    conv1d_input_shape = [1, 3, 10]
+    conv2d_transpose = torch.nn.ConvTranspose2d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        output_padding=output_padding,
+        groups=groups,
+        bias=bias,
+    ).eval()
+    verify_model(conv2d_transpose, conv2d_input_data)
+
+    # Conv 1D Transpose Tests
+    conv1d_input_shape = [1, in_channels, 10]
     conv1d_input_data = torch.rand(conv1d_input_shape).float()
-    verify_model(torch.nn.ConvTranspose1d(3, 6, 7, bias=True), input_data=conv1d_input_data)
-    verify_model(torch.nn.ConvTranspose1d(3, 12, 3, bias=False), input_data=conv1d_input_data)
+    conv1d_transpose = torch.nn.ConvTranspose1d(
+        in_channels=in_channels,
+        out_channels=out_channels,
+        kernel_size=kernel_size,
+        stride=stride,
+        output_padding=output_padding,
+        groups=groups,
+        bias=bias,
+    ).eval()
+    verify_model(conv1d_transpose, conv1d_input_data)
 
 
 def test_forward_deform_conv():
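The new parametrized cases all go through the shared verify_model helper from this test file, which traces the module and round-trips it through the PyTorch frontend. The sketch below is a simplified stand-in for that flow for the 1-D transpose case; the CPU llvm target, the fixed tolerances, and the input name "input0" are assumptions of this sketch, not details of the helper. The tests pin stride to output_padding + 1 because PyTorch requires output_padding to be smaller than the stride or the dilation.

import numpy as np
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor

model = torch.nn.ConvTranspose1d(3, 5, kernel_size=3, stride=2, output_padding=1).eval()
input_data = torch.rand(1, 3, 10)
scripted = torch.jit.trace(model, input_data).eval()

# Convert the traced module through the PyTorch frontend; this exercises the
# new conv1d_transpose path together with its output_padding attribute.
mod, params = relay.frontend.from_pytorch(scripted, [("input0", list(input_data.shape))])

with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="llvm", params=params)

runtime = graph_executor.GraphModule(lib["default"](tvm.cpu()))
runtime.set_input("input0", input_data.numpy())
runtime.run()
tvm_out = runtime.get_output(0).numpy()

with torch.no_grad():
    torch_out = model(input_data).numpy()
np.testing.assert_allclose(tvm_out, torch_out, rtol=1e-5, atol=1e-5)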
