Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

转换规则 No.299/305/316/317/322 #174

Merged
merged 2 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,15 @@
},
"torch.Tensor.copysign": {},
"torch.Tensor.copysign_": {},
"torch.Tensor.corrcoef": {},
"torch.Tensor.corrcoef": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.corrcoef",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

paddle.linalg.corrcoef 和 paddle.Tensor.corrcoef哪个是正确的,和文档写的不一样。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已有 tests/test_corrcoef.py torch.corrcoef 转为 paddle.linalg.corrcoef

torch.Tensor.corrcoef 转为 paddle.Tensor.corrcoef

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok

"args_list": [],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

    "args_list": [],
    "kwargs_change": {},
    "paddle_default_kwargs": {
      "rowvar": true
    }

这个应该这三个部分可以去掉, 简化代码是不是好一点

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

"kwargs_change": {},
"paddle_default_kwargs": {
"rowvar": true
}
},
"torch.Tensor.cos": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.cos"
Expand Down Expand Up @@ -3503,6 +3511,19 @@
"out"
]
},
"torch.cumulative_trapezoid": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.cumulative_trapezoid",
"args_list": [
"y",
"x",
"dx",
"dim"
],
"kwargs_change": {
"dim": "axis"
}
},
"torch.deg2rad": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.deg2rad",
Expand Down Expand Up @@ -9003,6 +9024,28 @@
"input": "x"
}
},
"torch.special.i1": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.i1",
"args_list": [
"input",
"out"
],
"kwargs_change": {
"input": "x"
}
},
"torch.special.i1e": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.i1e",
"args_list": [
"input",
"out"
],
"kwargs_change": {
"input": "x"
}
},
"torch.special.log1p": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.log1p",
Expand Down
51 changes: 51 additions & 0 deletions tests/test_Tensor_corrcoef.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.Tensor.corrcoef")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[ 0.7308, 1.0060, 0.5270, 1.4516],
[-0.1383, 1.5706, 0.4724, 0.4141],
[ 0.1193, 0.2829, 0.9037, 0.3957],
[-0.8202, -0.6474, -0.1631, -0.6543]])
result = x.corrcoef()
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([[-0.1533, 2.3020, -0.1771, 0.5928],
[ 0.4338, -0.6537, 0.2296, 0.5946],
[-0.4932, 1.8386, -0.1039, 1.0440],
[ 0.1735, -0.8303, -0.3821, -0.4384],
[-0.1533, 2.3020, -0.1771, 0.5928],
[ 0.4338, -0.6537, 0.2296, 0.5946],
[-0.4932, 1.8386, -0.1039, 1.0440],
[ 0.1735, -0.8303, -0.3821, -0.4384]])
result = x.corrcoef()
"""
)
obj.run(pytorch_code, ["result"])
74 changes: 74 additions & 0 deletions tests/test_cumulative_trapezoid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import textwrap

from apibase import APIBase

obj = APIBase("torch.cumulative_trapezoid")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.cumulative_trapezoid(torch.tensor([1.0, 1, 1, 0, 1]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([1, 1, 1, 0, 1]).type(torch.float32)
x = torch.tensor([1, 2, 3, 0, 1]).type(torch.float32)
result = torch.cumulative_trapezoid(y, x)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.tensor([[0.6, 0.0, 0.0, 0.0]])
result = torch.cumulative_trapezoid(y, dx=2)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_4():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.arange(9).reshape(3, 3).type(torch.float32)
result = torch.cumulative_trapezoid(y, dim=0)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_5():
pytorch_code = textwrap.dedent(
"""
import torch
y = torch.arange(9).reshape(3, 3).type(torch.float32)
result = torch.cumulative_trapezoid(y, dim=1)
"""
)
obj.run(pytorch_code, ["result"])
53 changes: 53 additions & 0 deletions tests/test_special_i0.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.special.i0")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.special.i0(torch.tensor([1.0, 2.0, 3.0]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.0, 2.0, 3.0])
result = torch.special.i0(x)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = [1.0, 2.0, 3.0]
out = torch.tensor([])
result = torch.special.i0(torch.tensor(x), out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])
53 changes: 53 additions & 0 deletions tests/test_special_i1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.special.i1")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.special.i1(torch.tensor([1.0, 2.0, 3.0]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.0, 2.0, 3.0])
result = torch.special.i1(x)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = [1.0, 2.0, 3.0]
out = torch.tensor([])
result = torch.special.i1(torch.tensor(x), out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])
53 changes: 53 additions & 0 deletions tests/test_special_i1e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import textwrap

from apibase import APIBase

obj = APIBase("torch.special.i1e")


def test_case_1():
pytorch_code = textwrap.dedent(
"""
import torch
result = torch.special.i1e(torch.tensor([1.0, 2.0, 3.0]))
"""
)
obj.run(pytorch_code, ["result"])


def test_case_2():
pytorch_code = textwrap.dedent(
"""
import torch
x = torch.tensor([1.0, 2.0, 3.0])
result = torch.special.i1e(x)
"""
)
obj.run(pytorch_code, ["result"])


def test_case_3():
pytorch_code = textwrap.dedent(
"""
import torch
x = [1.0, 2.0, 3.0]
out = torch.tensor([])
result = torch.special.i1e(torch.tensor(x), out=out)
"""
)
obj.run(pytorch_code, ["result", "out"])