Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
sxjscience committed May 20, 2019
1 parent 1caefab commit 6918e7d
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/operator/nn/layer_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ void LayerNormGPUContig(const LayerNormParam param,
cudaStream_t stream = Stream<gpu>::GetStream(ctx.get_stream<gpu>());
const dim3 dimBlock(32, nthread_y);
MXNET_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
typedef std::conditional<safe_acc, AccType, DType>::type AType;
typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
int nshared = nthread_y > 1 ? nthread_y * 32 * sizeof(AType)
+ (nthread_y / 2) * 32 * sizeof(int) : 0;
CheckLaunchParam(dimGrid, dimBlock);
Expand Down Expand Up @@ -637,7 +637,7 @@ void LayerNormGradGPUContig(const LayerNormParam param,
&gb_block_dim, &gb_grid_dim, &npart);
if (gamma_grad_req != kNullOp || beta_grad_req != kNullOp) {
MXNET_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
typedef std::conditional<safe_acc, AccType, DType>::type AType;
typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
Tensor<gpu, 1, AType> workspace =
ctx.requested[0].get_space_typed<gpu, 1, AType>(Shape1(2 * npart * nchannel), s);
AType* part_gamma_grad_ptr = workspace.dptr_;
Expand Down Expand Up @@ -696,7 +696,7 @@ void LayerNormGradGPUContig(const LayerNormParam param,
const int LOAD_UNROLL = 4;
if (data_grad_req != kNullOp) {
MXNET_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
typedef std::conditional<safe_acc, AccType, DType>::type AType;
typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
int nshared = data_block_dim.y > 1 ? data_block_dim.y * data_block_dim.x * sizeof(AType) : 0;
CheckLaunchParam(data_grid_dim, data_block_dim);
if (data_grad_req == kAddTo) {
Expand Down

0 comments on commit 6918e7d

Please sign in to comment.