[AMP OP&Test] Norm bf16 #51083
Conversation
Your PR has been submitted successfully. Thank you for contributing to this open source project!
@@ -62,7 +62,7 @@ __global__ void Normalize(const T* x,
    MT reduce_result = BlockReduce(temp_storage).Sum(sum);

    if (threadIdx.x == 0) {
      norm = square_root(reduce_result + static_cast<MT>(eps));
This line doesn't need to be changed.
done
@@ -86,7 +86,7 @@ void NormKernel(const Context& ctx,

  auto xdim = in_x->dims();
  if (axis < 0) axis = xdim.size() + axis;
  T eps = static_cast<T>(epsilon);
  float eps = epsilon;
Just delete this variable.
done
@@ -157,6 +159,40 @@ def init_test_case(self):
        self.epsilon = 1e-8


@unittest.skipIf(
    not core.is_compiled_with_cuda(),
    "core is not compiled with CUDA and not support the bfloat16",
"and not support the bfloat16" 这句可以去掉
        self.check_output_with_place(core.CUDAPlace(0))

    def test_check_grad(self):
        pass
Why is this simply pass here?
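A fragment (not the PR's actual code) showing what a gradient check could look like in place of pass, reusing the check_grad_with_place call pattern that appears later in this review:

    def test_check_grad(self):
        # exercise the backward bf16 kernel instead of skipping it;
        # place and argument names follow the other grad checks in this file
        self.check_grad_with_place(core.CUDAPlace(0), ['X'], 'Out')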
LGTM
95074f7
@@ -119,11 +121,11 @@ def init_dtype(self):
        self.dtype = "float16"

    def test_check_output(self):
        self.check_output_with_place(fluid.core.CUDAPlace(0), atol=5e-2)
No need to set this explicitly; the default value is fine.

    def test_check_grad(self):
        self.check_grad_with_place(
            fluid.core.CUDAPlace(0), ['X'], 'Out', max_relative_error=0.05
Try using the default value.
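A sketch of the fp16 test with the explicit tolerances dropped so the OpTest defaults apply (fragment of the same test class shown in the diff above; assumes the defaults are tight enough for this op):

    def test_check_output(self):
        # no atol override; fall back to the framework's default tolerance
        self.check_output_with_place(fluid.core.CUDAPlace(0))

    def test_check_grad(self):
        # no max_relative_error override; fall back to the default
        self.check_grad_with_place(fluid.core.CUDAPlace(0), ['X'], 'Out')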
        self.python_out_sig = ["out"]

    def test_check_output(self):
        self.check_output_with_place(core.CUDAPlace(0), atol=1e-2)
No need to set this; the default is sufficient.
            core.CUDAPlace(0),
            ['X'],
            'Out',
            user_defined_grads=self.gradient,
The self.gradient argument is not actually computed anywhere; try removing it.
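A sketch of the same check_grad_with_place call with user_defined_grads removed, letting OpTest compute the numeric gradient itself (fragment only; argument names follow the diff above):

        self.check_grad_with_place(
            core.CUDAPlace(0),
            ['X'],
            'Out',
            check_eager=True,  # kept from the original call; user_defined_grads dropped
        )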
            'Out',
            user_defined_grads=self.gradient,
            check_eager=True,
            max_relative_error=1e-2,
The default does not need to be set explicitly.
LGTM
PR types
Others
PR changes
Others
Describe
Add a norm op test and add bf16 support for the norm op.