Add tests
co63oc committed Aug 23, 2023
1 parent af8658f commit b8a6863
Showing 4 changed files with 245 additions and 1 deletion.
28 changes: 28 additions & 0 deletions paconvert/api_mapping.json
@@ -2983,6 +2983,21 @@
"input": "x"
}
},
"torch.autocast": {
"Matcher": "AutocastMatcher",
"paddle_api": "paddle.amp.auto_cast",
"args_list": [
"device_type",
"enabled",
"dtype",
"cache_enabled"
],
"kwargs_change": {
"device_type": "",
"enabled": "enable",
"cache_enabled": ""
}
},
"torch.autograd.Function": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.autograd.PyLayer"
@@ -3579,6 +3594,19 @@
"correction": "ddof"
}
},
"torch.cpu.amp.autocast": {
"Matcher": "AutocastMatcher",
"paddle_api": "paddle.amp.auto_cast",
"args_list": [
"enabled",
"dtype",
"cache_enabled"
],
"kwargs_change": {
"enabled": "enable",
"cache_enabled": ""
}
},
"torch.cross": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.cross",
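For context, the two autocast entries added above tell the converter to rename `enabled` to `enable`, to drop `device_type` and `cache_enabled` (which `paddle.amp.auto_cast` has no counterpart for), and to pass `dtype` through unchanged. A rough before/after sketch of the conversion these entries imply (illustrative only, not the tool's verbatim output):

# PyTorch input, matching one of the patterns exercised by the new tests
import torch
x = torch.ones([2, 8])
with torch.autocast(device_type='cpu', enabled=False):
    result = x * x

# Expected shape of the converted Paddle code under this mapping:
# device_type and cache_enabled are dropped, enabled becomes enable
import paddle
x = paddle.ones([2, 8])
with paddle.amp.auto_cast(enable=False):
    result = x * x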
29 changes: 28 additions & 1 deletion paconvert/api_matcher.py
@@ -83,7 +83,7 @@ def generate_code(self, kwargs):
res = "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs))

if dtype_v:
res += ".astype({})".format(dtype_v)
res += ".cast({})".format(dtype_v)

if pin_memory_v:
res += ".pin_memory()"
@@ -169,6 +169,33 @@ def generate_code(self, kwargs):
return code


class AutocastMatcher(BaseMatcher):
def generate_code(self, kwargs):
kwargs_change = {}
if "kwargs_change" in self.api_mapping:
kwargs_change = self.api_mapping["kwargs_change"]
new_kwargs = {}
for k in list(kwargs.keys()):
if k in kwargs_change:
if kwargs_change[k]:
# rename/copy in new_kwargs
if isinstance(kwargs_change[k], list):
for v in kwargs_change[k]:
new_kwargs[v] = kwargs[k]
else:
new_kwargs[kwargs_change[k]] = kwargs[k]
else:
# remove in new_kwargs
kwargs.pop(k)
else:
# copy to new_kwargs
new_kwargs[k] = kwargs.pop(k)

new_kwargs = self.set_paddle_default_kwargs(new_kwargs)

return "{}({})".format(self.get_paddle_api(), self.kwargs_to_str(new_kwargs))


class TorchAddMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "alpha" in kwargs:
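To make the new matcher's argument handling concrete, here is a minimal standalone sketch of the same rename/drop idea, detached from the project's BaseMatcher infrastructure (the function name and the inline example are mine, for illustration only):

# Rename keys with a non-empty mapping value, drop keys mapped to "",
# and copy everything else through untouched.
def apply_kwargs_change(kwargs, kwargs_change):
    new_kwargs = {}
    for k, v in kwargs.items():
        if k in kwargs_change:
            target = kwargs_change[k]
            if target:
                # rename, or fan out when the target is a list of names
                names = target if isinstance(target, list) else [target]
                for name in names:
                    new_kwargs[name] = v
            # a falsy target ("" in api_mapping.json) means the argument is dropped
        else:
            new_kwargs[k] = v
    return new_kwargs

# Mirrors the torch.autocast entry added in api_mapping.json:
change = {"device_type": "", "enabled": "enable", "cache_enabled": ""}
print(apply_kwargs_change({"device_type": "'cpu'", "enabled": "False"}, change))
# -> {'enable': 'False'}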
95 changes: 95 additions & 0 deletions tests/test_autocast.py
@@ -0,0 +1,95 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.autocast")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.autocast(device_type='cpu', enabled=False):
result = x*x
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.autocast(device_type='cpu'):
result = x*x
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.autocast(device_type='cpu', enabled=True, dtype=torch.bfloat16):
result = x*x
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.autocast('cpu', dtype=torch.bfloat16, enabled=True, cache_enabled=None):
result = x*x
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.autocast(device_type='cpu', dtype=torch.bfloat16, enabled=True, cache_enabled=None):
result = x*x
"""
)
obj.run(pytorch_code, ["result"])
94 changes: 94 additions & 0 deletions tests/test_cpu_amp_autocast.py
@@ -0,0 +1,94 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.cpu.amp.autocast")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.cpu.amp.autocast(enabled=False):
result = x*x
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.cpu.amp.autocast():
result = x*x
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.cpu.amp.autocast(dtype=torch.bfloat16, enabled=True):
result = x*x
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.cpu.amp.autocast(True, torch.bfloat16, True):
result = x*x
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
import torch.nn as nn
x = torch.tensor([[[-1.3020, -0.1005, 0.5766, 0.6351, -0.8893, 0.0253, -0.1756, 1.2913],
[-0.8833, -0.1369, -0.0168, -0.5409, -0.1511, -0.1240, -1.1870, -1.8816]]])
with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16, cache_enabled=True):
result = x*x
"""
)
obj.run(pytorch_code, ["result"])
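
Note that test_case_4 above passes all arguments positionally. That is what the args_list fields in api_mapping.json are for: presumably the converter binds positional arguments to those names before the rename/drop rules run. A tiny illustrative sketch of that binding (standalone; the helper name is hypothetical and not part of paconvert):

# Bind positional call arguments to the names declared in args_list,
# then let explicit keyword arguments fill in or override the rest.
def bind_args(args, kwargs, args_list):
    bound = dict(zip(args_list, args))
    bound.update(kwargs)
    return bound

args_list = ["enabled", "dtype", "cache_enabled"]  # torch.cpu.amp.autocast entry
print(bind_args(["True", "torch.bfloat16", "True"], {}, args_list))
# -> {'enabled': 'True', 'dtype': 'torch.bfloat16', 'cache_enabled': 'True'}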
