-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Batch norm latest #5103
Batch norm latest #5103
Conversation
x_val, scale_val, bias_val, epsilon, "NHWC") | ||
|
||
# | ||
mean_out = saved_mean * (1. - momentum) + momentum * mean |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these three line can be moved to _reference_training
@@ -200,12 +307,14 @@ def test_with_place(place): | |||
self.__assert_close(x_grad_tensor, x_grad_ref, "x_grad") | |||
self.__assert_close(scale_grad_tensor, scale_grad_ref, "scale_grad") | |||
self.__assert_close(bias_grad_tensor, bias_grad_ref, "bias_grad") | |||
print "op test backward passed: ", tensor_format | |||
|
|||
places = [core.CPUPlace()] | |||
if core.is_compile_gpu() and core.op_support_gpu("batch_norm"): | |||
places.append(core.GPUPlace(0)) | |||
for place in places: | |||
test_with_place(place) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
可以加一层循环,把TensorFormat作为参数传给 test_with_place_and_dataformat
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
把四种case都测试到
CPU NCHW
CPU NWHC
GPU NCHW
GPU NWHC
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great Job!
Test with both NCHW and NHWC passed.