Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sxjscience committed May 20, 2019
1 parent 6918e7d commit dbcce28
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/operator/nn/layer_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ void LayerNormGPUContig(const LayerNormParam param,
}
cudaStream_t stream = Stream<gpu>::GetStream(ctx.get_stream<gpu>());
const dim3 dimBlock(32, nthread_y);
MXNET_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
int nshared = nthread_y > 1 ? nthread_y * 32 * sizeof(AType)
+ (nthread_y / 2) * 32 * sizeof(int) : 0;
Expand Down Expand Up @@ -636,7 +636,7 @@ void LayerNormGradGPUContig(const LayerNormParam param,
GetGammaBetaGradKernelParams(nbatch, nchannel, &part_grad_block_dim, &part_grad_grid_dim,
&gb_block_dim, &gb_grid_dim, &npart);
if (gamma_grad_req != kNullOp || beta_grad_req != kNullOp) {
MXNET_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
Tensor<gpu, 1, AType> workspace =
ctx.requested[0].get_space_typed<gpu, 1, AType>(Shape1(2 * npart * nchannel), s);
Expand Down Expand Up @@ -695,7 +695,7 @@ void LayerNormGradGPUContig(const LayerNormParam param,
const dim3 data_block_dim(32, nthread_y);
const int LOAD_UNROLL = 4;
if (data_grad_req != kNullOp) {
MXNET_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
MXNET_REAL_ACC_TYPE_SWITCH(in_data.type_flag_, DType, AccType, {
typedef typename std::conditional<safe_acc, AccType, DType>::type AType;
int nshared = data_block_dim.y > 1 ? data_block_dim.y * data_block_dim.x * sizeof(AType) : 0;
CheckLaunchParam(data_grid_dim, data_block_dim);
Expand Down

0 comments on commit dbcce28

Please sign in to comment.