
[AMP] add fp16&bf16 support for flatten op #52035

Merged · 5 commits · Mar 28, 2023
Changes from 2 commits
@@ -15,9 +15,10 @@
import unittest

import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16

import paddle
from paddle.fluid import core


class TestFlattenOp(OpTest):
@@ -31,7 +32,8 @@ def setUp(self):
self.stop_axis = -1
self.skip_cinn()
self.init_test_case()
self.inputs = {"X": np.random.random(self.in_shape).astype("float64")}
self.init_test_dtype()
self.init_input_data()
self.init_attrs()
self.outputs = {
"Out": self.inputs["X"].reshape(self.new_shape),
Contributor (review comment):
The output also needs special handling for uint16; it should go through convert_float_to_uint16 as well.

Contributor Author (reply):
self.inputs["X"] is already converted to uint16 when the data is generated, so no further conversion is needed here.
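For context, here is a minimal sketch of the point being discussed (assuming convert_float_to_uint16 behaves as Paddle's eager_op_test helper does, packing float32 data into the uint16 bit pattern of bfloat16):

import numpy as np
from eager_op_test import convert_float_to_uint16  # Paddle test utility

x = np.random.random((3, 2, 5, 4)).astype("float32")
x_bf16 = convert_float_to_uint16(x)  # bfloat16 values stored as uint16
out = x_bf16.reshape((3, 2, 20))     # reshape preserves the uint16 dtype
assert out.dtype == np.uint16        # so the reference output needs no second conversion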

@@ -42,10 +44,20 @@ def skip_cinn(self):
self.enable_cinn = True

def test_check_output(self):
self.check_output(no_check_set=["XShape"], check_prim=True)
if str(self.dtype) in {"float16", "uint16"}:
self.check_output_with_place(
core.CUDAPlace(0), no_check_set=["XShape"], check_prim=True
)
else:
self.check_output(no_check_set=["XShape"], check_prim=True)

def test_check_grad(self):
self.check_grad(["X"], "Out", check_prim=True)
if str(self.dtype) in {"float16", "uint16"}:
self.check_grad_with_place(
core.CUDAPlace(0), ["X"], "Out", check_prim=True
)
else:
self.check_grad(["X"], "Out", check_prim=True)

def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -59,6 +71,39 @@ def init_attrs(self):
"stop_axis": self.stop_axis,
}

def init_test_dtype(self):
self.dtype = "float64"

def init_input_data(self):
x = np.random.random(self.in_shape).astype("float32")
Contributor @ZzSean (Mar 27, 2023):
Here the data is initialized as float32 for every dtype. It should be initialized with self.dtype instead, and only uint16 needs to be handled as a special case.

Contributor Author (reply):
Done

if str(self.dtype) == "uint16":
x = convert_float_to_uint16(x)
self.inputs = {"X": x}
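The diff shown here is from the first two commits, before the comment above was addressed. A hypothetical revision along the lines the reviewer suggests could look like this (sketch only; the actual follow-up commit is not shown here):

def init_input_data(self):
    # Hypothetical revision: initialize directly in self.dtype and only
    # special-case bfloat16, which the tests store as uint16.
    if str(self.dtype) == "uint16":
        x = np.random.random(self.in_shape).astype("float32")
        x = convert_float_to_uint16(x)
    else:
        x = np.random.random(self.in_shape).astype(self.dtype)
    self.inputs = {"X": x}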


class TestFlattenFP32Op(TestFlattenOp):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op(TestFlattenOp):
def init_test_dtype(self):
self.dtype = "float16"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op(TestFlattenOp):
def init_test_dtype(self):
self.dtype = "uint16"


class TestFlattenOp_1(TestFlattenOp):
def init_test_case(self):
@@ -74,6 +119,30 @@ def init_attrs(self):
}


class TestFlattenFP32Op_1(TestFlattenOp_1):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_1(TestFlattenOp_1):
def init_test_dtype(self):
self.dtype = "float16"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_1(TestFlattenOp_1):
def init_test_dtype(self):
self.dtype = "uint16"


class TestFlattenOp_2(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -88,6 +157,30 @@ def init_attrs(self):
}


class TestFlattenFP32Op_2(TestFlattenOp_2):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_2(TestFlattenOp_2):
def init_test_dtype(self):
self.dtype = "float16"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_2(TestFlattenOp_2):
def init_test_dtype(self):
self.dtype = "uint16"


class TestFlattenOp_3(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -102,6 +195,30 @@ def init_attrs(self):
}


class TestFlattenFP32Op_3(TestFlattenOp_3):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_3(TestFlattenOp_3):
def init_test_dtype(self):
self.dtype = "float16"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_3(TestFlattenOp_3):
def init_test_dtype(self):
self.dtype = "uint16"


class TestFlattenOp_4(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -116,6 +233,30 @@ def init_attrs(self):
}


class TestFlattenFP32Op_4(TestFlattenOp_4):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_4(TestFlattenOp_4):
def init_test_dtype(self):
self.dtype = "float16"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_4(TestFlattenOp_4):
def init_test_dtype(self):
self.dtype = "uint16"


class TestFlattenOp_5(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 5, 4)
@@ -130,6 +271,30 @@ def init_attrs(self):
}


class TestFlattenFP32Op_5(TestFlattenOp_5):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_5(TestFlattenOp_5):
def init_test_dtype(self):
self.dtype = "float16"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_5(TestFlattenOp_5):
def init_test_dtype(self):
self.dtype = "uint16"


class TestFlattenOp_6(TestFlattenOp):
def init_test_case(self):
self.in_shape = tuple()
@@ -147,6 +312,30 @@ def init_attrs(self):
}


class TestFlattenFP32Op_6(TestFlattenOp_6):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16Op_6(TestFlattenOp_6):
def init_test_dtype(self):
self.dtype = "float16"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16Op_6(TestFlattenOp_6):
def init_test_dtype(self):
self.dtype = "uint16"


class TestFlattenOpSixDims(TestFlattenOp):
def init_test_case(self):
self.in_shape = (3, 2, 3, 2, 4, 4)
@@ -161,6 +350,30 @@ def init_attrs(self):
}


class TestFlattenFP32OpSixDims(TestFlattenOpSixDims):
def init_test_dtype(self):
self.dtype = "float32"


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"core is not complied with CUDA",
)
class TestFlattenFP16OpSixDims(TestFlattenOpSixDims):
def init_test_dtype(self):
self.dtype = "float16"


@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestFlattenBF16OpSixDims(TestFlattenOpSixDims):
def init_test_dtype(self):
self.dtype = "uint16"


class TestFlatten2OpError(unittest.TestCase):
def test_errors(self):
image_shape = (2, 3, 4, 4)
@@ -31,6 +31,7 @@
'depthwise_conv2d',
'depthwise_conv2d_transpose',
'dropout',
'flatten_contiguous_range',
'fused_elemwise_activation',
'hinge_loss',
'huber_loss',
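For context, this op list feeds Paddle's AMP machinery, so flatten can now participate in mixed-precision execution. A hedged usage sketch (public API names only; whether the output is actually cast to float16 depends on which AMP list flatten_contiguous_range is registered in and on the hardware):

import paddle

x = paddle.randn([4, 8, 8])
with paddle.amp.auto_cast():             # dynamic-graph AMP context
    y = paddle.flatten(x, start_axis=1)  # flatten now runs under AMP
print(y.dtype)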
python/paddle/tensor/manipulation.py (1 addition, 0 deletions)
@@ -1591,6 +1591,7 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
'int32',
'int64',
'uint8',
'uint16',
],
'flatten',
)
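With 'uint16' (the storage dtype Paddle uses for bfloat16) added to the accepted dtypes, a minimal usage sketch (assumes a CUDA build of Paddle; shapes are arbitrary):

import paddle

x = paddle.randn([3, 2, 5, 4]).astype("float16")  # or "bfloat16" on GPUs that support it
y = paddle.flatten(x, start_axis=1, stop_axis=2)  # shape [3, 10, 4], dtype preserved
print(y.shape, y.dtype)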