Skip to content

Commit

Permalink
Merge branch 'doc0716_3' into doc0716_2
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Jul 16, 2023
2 parents 08a2a7c + 1f33c27 commit d603767
Show file tree
Hide file tree
Showing 5 changed files with 268 additions and 0 deletions.
35 changes: 35 additions & 0 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -3511,6 +3511,19 @@
"out"
]
},
"torch.cumulative_trapezoid": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.cumulative_trapezoid",
"args_list": [
"y",
"x",
"dx",
"dim"
],
"kwargs_change": {
"dim": "axis"
}
},
"torch.deg2rad": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.deg2rad",
Expand Down Expand Up @@ -9011,6 +9024,28 @@
"input": "x"
}
},
"torch.special.i1": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.i1",
"args_list": [
"input",
"out"
],
"kwargs_change": {
"input": "x"
}
},
"torch.special.i1e": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.i1e",
"args_list": [
"input",
"out"
],
"kwargs_change": {
"input": "x"
}
},
"torch.special.log1p": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.log1p",
Expand Down
74 changes: 74 additions & 0 deletions tests/test_cumulative_trapezoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.cumulative_trapezoid")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.cumulative_trapezoid(torch.tensor([1.0, 1, 1, 0, 1]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
x = torch.tensor([1, 2, 3, 0, 1]).type(torch.float32)
result = torch.cumulative_trapezoid(y, x)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([[0.6, 0.0, 0.0, 0.0]])
result = torch.cumulative_trapezoid(y, dx=2)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.arange(9).reshape(3, 3).type(torch.float32)
result = torch.cumulative_trapezoid(y, dim=0)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.arange(9).reshape(3, 3).type(torch.float32)
result = torch.cumulative_trapezoid(y, dim=1)
"""
)
obj.run(pytorch_code, ["result"])
53 changes: 53 additions & 0 deletions tests/test_special_i0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.special.i0")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.special.i0(torch.tensor([1.0, 2.0, 3.0]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.0, 2.0, 3.0])
result = torch.special.i0(x)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = [1.0, 2.0, 3.0]
out = torch.tensor([])
result = torch.special.i0(torch.tensor(x), out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])
53 changes: 53 additions & 0 deletions tests/test_special_i1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.special.i1")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.special.i1(torch.tensor([1.0, 2.0, 3.0]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.0, 2.0, 3.0])
result = torch.special.i1(x)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = [1.0, 2.0, 3.0]
out = torch.tensor([])
result = torch.special.i1(torch.tensor(x), out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])
53 changes: 53 additions & 0 deletions tests/test_special_i1e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.special.i1e")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.special.i1e(torch.tensor([1.0, 2.0, 3.0]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.0, 2.0, 3.0])
result = torch.special.i1e(x)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = [1.0, 2.0, 3.0]
out = torch.tensor([])
result = torch.special.i1e(torch.tensor(x), out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])

0 comments on commit d603767

Please sign in to comment.