From 725cfe6aff672366a45bd94d1abe2dd9dbae55af Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Mon, 6 May 2019 14:26:45 +0800
Subject: [PATCH] [Bugfix] Fix layer norm for large input shape (#14870)

* fix layer norm for large input shape

* try to fix

* use a larger eps

* try to fix test

* try to fix
---
 src/operator/nn/layer_norm-inl.h       |  8 +--
 tests/python/unittest/test_operator.py | 91 +++++++++++++++++++++++---
 2 files changed, 86 insertions(+), 13 deletions(-)

diff --git a/src/operator/nn/layer_norm-inl.h b/src/operator/nn/layer_norm-inl.h
index c7de7d734521..3fa2e91681fe 100644
--- a/src/operator/nn/layer_norm-inl.h
+++ b/src/operator/nn/layer_norm-inl.h
@@ -203,14 +203,14 @@ void LayerNormGradCompute(const nnvm::NodeAttrs& attrs,
     BROADCAST_NDIM_SWITCH(red_dst_shape.ndim(), NDim, {
       reduce_workspace_size =
         std::max(reduce_workspace_size,
-                 broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_src_shape,
-                                                             kAddTo, red_dst_shape));
+                 broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_dst_shape,
+                                                             kAddTo, red_src_shape));
     });
     BROADCAST_NDIM_SWITCH(red_exclude_dst_shape.ndim(), NDim, {
       reduce_workspace_size =
         std::max(reduce_workspace_size,
-                 broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_src_shape, kAddTo,
-                                                             red_exclude_dst_shape));
+                 broadcast::ReduceWorkspaceSize<NDim, DType>(s, red_exclude_dst_shape, kAddTo,
+                                                             red_exclude_src_shape));
     });
   });
   workspace = ctx.requested[0].get_space_typed<xpu, 1, char>(
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2406a1c2f761..f2bd7d780ee4 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -3364,7 +3364,9 @@ def test_l2_normalization():
                 check_l2_normalization((nbatch, nchannel, height, width), mode, dtype)


-def check_layer_normalization(in_shape, axis, eps, dtype=np.float32, forward_check_eps=1E-3):
+def check_layer_normalization(in_shape, axis, eps, dtype=np.float32,
+                              forward_check_eps=1E-3, backward_check_eps=1E-3,
+                              npy_grad_check=True, finite_grad_check=True):
     def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
         if axis < 0:
             axis += data.ndim
@@ -3377,6 +3379,24 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
               np.reshape(beta, broadcast_shape)
         return out

+    def npy_layer_norm_grad(data, gamma, out_grad, axis, eps):
+        if axis < 0:
+            axis += data.ndim
+        exclude_axis = tuple([ele for ele in range(data.ndim) if ele != axis])
+        data_mean = data.mean(axis=axis, keepdims=True)
+        data_var = data.var(axis=axis, keepdims=True)
+        data_std = np.sqrt(data_var + eps)
+        centered_data = (data - data_mean) / data_std
+        gamma_grad = (centered_data * out_grad).sum(axis=exclude_axis, keepdims=True)
+        beta_grad = out_grad.sum(axis=exclude_axis, keepdims=True)
+        w = out_grad * gamma.reshape([1 if i != axis else data.shape[axis] for i in range(data.ndim)])\
+            / data_std
+        data_grad = w - w.mean(axis=axis, keepdims=True)\
+            - centered_data * (w * centered_data).mean(axis=axis, keepdims=True)
+        gamma_grad = gamma_grad.reshape((-1,))
+        beta_grad = beta_grad.reshape((-1,))
+        return data_grad, gamma_grad, beta_grad
+
     ctx = default_context()
     data = np.random.normal(0, 1, in_shape).astype(dtype)
     gamma = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
@@ -3392,10 +3412,51 @@ def npy_layer_norm(data, gamma, beta, axis=1, eps=1E-5):
     out_nd = exe.forward()[0]
     out = npy_layer_norm(data, gamma, beta, axis, eps)
     assert_almost_equal(out, out_nd.asnumpy(), forward_check_eps, forward_check_eps)
-    for req in ['write', 'add']:
-        check_numeric_gradient(out_s, {'data': data, 'gamma': gamma, 'beta': beta},
-                               grad_nodes={'data': req, 'gamma': req, 'beta': req},
-                               numeric_eps=1e-2, rtol=1e-2, atol=1e-2)
+
+    if finite_grad_check:
+        for req in ['write', 'add']:
+            check_numeric_gradient(out_s, {'data': data, 'gamma': gamma, 'beta': beta},
+                                   grad_nodes={'data': req, 'gamma': req, 'beta': req},
+                                   numeric_eps=1e-2, rtol=1e-2, atol=1e-2)
+
+    if npy_grad_check:
+        # Test for grad_req = write
+        out_grad = np.random.normal(0, 1, in_shape).astype(dtype)
+        exe = out_s.simple_bind(ctx, data=in_shape, grad_req='write')
+        exe.arg_dict['data'][:] = data
+        exe.arg_dict['gamma'][:] = gamma
+        exe.arg_dict['beta'][:] = beta
+        exe.forward()
+        exe.backward([mx.nd.array(out_grad, ctx=ctx)])
+        gt_data_grad, gt_gamma_grad, gt_beta_grad =\
+            npy_layer_norm_grad(data, gamma, out_grad, axis, eps)
+        assert_almost_equal(exe.grad_dict['data'].asnumpy(), gt_data_grad, backward_check_eps, backward_check_eps)
+        assert_almost_equal(exe.grad_dict['gamma'].asnumpy(), gt_gamma_grad, backward_check_eps, backward_check_eps)
+        assert_almost_equal(exe.grad_dict['beta'].asnumpy(), gt_beta_grad, backward_check_eps, backward_check_eps)
+
+        # Test for grad_req = add
+        out_grad = np.random.normal(0, 1, in_shape).astype(dtype)
+        init_data_grad = np.random.normal(0, 1, in_shape).astype(dtype)
+        init_gamma_grad = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+        init_beta_grad = np.random.normal(0, 1, (in_shape[axis],)).astype(dtype)
+        exe = out_s.simple_bind(ctx, data=in_shape, grad_req='add')
+        exe.arg_dict['data'][:] = data
+        exe.arg_dict['gamma'][:] = gamma
+        exe.arg_dict['beta'][:] = beta
+        exe.grad_dict['data'][:] = init_data_grad
+        exe.grad_dict['gamma'][:] = init_gamma_grad
+        exe.grad_dict['beta'][:] = init_beta_grad
+        exe.forward()
+        exe.backward([mx.nd.array(out_grad, ctx=ctx)])
+        gt_data_grad, gt_gamma_grad, gt_beta_grad = \
+            npy_layer_norm_grad(data, gamma, out_grad, axis, eps)
+        assert_almost_equal(exe.grad_dict['data'].asnumpy(),
+                            gt_data_grad + init_data_grad, backward_check_eps, backward_check_eps)
+        assert_almost_equal(exe.grad_dict['gamma'].asnumpy(),
+                            gt_gamma_grad + init_gamma_grad, backward_check_eps, backward_check_eps)
+        assert_almost_equal(exe.grad_dict['beta'].asnumpy(),
+                            gt_beta_grad + init_beta_grad, backward_check_eps, backward_check_eps)
+

 @with_seed()
 def test_norm():
@@ -3469,13 +3530,25 @@ def l2norm(input_data, axis=0, keepdims=True):


 def test_layer_norm():
-    for dtype, forward_check_eps in zip([np.float16, np.float32, np.float64],
-                                        [1E-2, 1E-3, 1E-4]):
-        for in_shape in [(10, 6, 5), (10, 10)]:
+    for dtype, forward_check_eps, backward_check_eps in zip([np.float16, np.float32, np.float64],
+                                                            [1E-2, 1E-3, 1E-4],
+                                                            [1E-2, 1E-3, 1E-4]):
+        if dtype != np.float16:
+            in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10), (128 * 32, 512)], [True, True, False]
+        else:
+            in_shape_l, finite_grad_check_l = [(10, 6, 5), (10, 10)], [True, True]  # large input + fp16 does not pass the forward check
+        for in_shape, finite_grad_check in zip(in_shape_l, finite_grad_check_l):
             for axis in range(-len(in_shape), len(in_shape)):
                 for eps in [1E-2, 1E-3]:
+                    if dtype == np.float16:
+                        npy_grad_check = False
+                    else:
+                        npy_grad_check = True
                     check_layer_normalization(in_shape, axis, eps, dtype=dtype,
-                                              forward_check_eps=forward_check_eps)
+                                              forward_check_eps=forward_check_eps,
+                                              backward_check_eps=backward_check_eps,
+                                              npy_grad_check=npy_grad_check,
+                                              finite_grad_check=finite_grad_check)


 # Numpy Implementation of Sequence Ops
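
Note (not part of the patch): the npy_layer_norm_grad reference added above encodes the
standard layer-norm backward formula. The following is a minimal, standalone NumPy sketch
of that formula, checked against a central-difference approximation of
d(sum(y * out_grad))/dx. The function names and the scalar loss used for the check are
illustrative assumptions only, not MXNet API.

import numpy as np

def layer_norm_fwd(x, gamma, beta, axis=-1, eps=1e-5):
    # Normalize x along `axis`, then scale by gamma and shift by beta.
    mean = x.mean(axis=axis, keepdims=True)
    std = np.sqrt(x.var(axis=axis, keepdims=True) + eps)
    x_hat = (x - mean) / std
    shape = [1] * x.ndim
    shape[axis] = x.shape[axis]
    return gamma.reshape(shape) * x_hat + beta.reshape(shape)

def layer_norm_bwd_data(x, gamma, out_grad, axis=-1, eps=1e-5):
    # Same data gradient as npy_layer_norm_grad in the test above:
    # dx = w - mean(w) - x_hat * mean(w * x_hat), with w = out_grad * gamma / std.
    mean = x.mean(axis=axis, keepdims=True)
    std = np.sqrt(x.var(axis=axis, keepdims=True) + eps)
    x_hat = (x - mean) / std
    shape = [1] * x.ndim
    shape[axis] = x.shape[axis]
    w = out_grad * gamma.reshape(shape) / std
    return w - w.mean(axis=axis, keepdims=True) \
             - x_hat * (w * x_hat).mean(axis=axis, keepdims=True)

if __name__ == '__main__':
    rng = np.random.RandomState(0)
    x = rng.normal(size=(4, 7))
    gamma = rng.normal(size=(7,))
    beta = rng.normal(size=(7,))
    out_grad = rng.normal(size=(4, 7))
    analytic = layer_norm_bwd_data(x, gamma, out_grad)
    # Central-difference check of the scalar loss sum(layer_norm(x) * out_grad).
    numeric = np.zeros_like(x)
    h = 1e-6
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += h
        xm[idx] -= h
        numeric[idx] = ((layer_norm_fwd(xp, gamma, beta) * out_grad).sum()
                        - (layer_norm_fwd(xm, gamma, beta) * out_grad).sum()) / (2 * h)
    assert np.allclose(analytic, numeric, rtol=1e-4, atol=1e-6)

Because data_std is constant along the normalized axis, the 1/std factor can be folded into
w before the two means are taken, which is exactly how the reference in the test computes
w = out_grad * gamma / data_std.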