Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG committed Nov 16, 2021
1 parent f90bef4 commit 681dd40
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,11 @@ __device__ __forceinline__ void ReadDataBc(
* the lowest dimension.
*/
template <typename Tx, typename Ty, int NX, int NY, int BlockSize, int Rank,
typename IndexCal, typename Transform, bool IsBoundary = false>
typename IndexCal, typename Functor, bool IsBoundary = false>
__device__ __forceinline__ void ReadDataReduce(
Ty* dst, const Tx* __restrict__ src, int block_offset,
const IndexCal& index_cal, int size_nx, int size_ny, int stride_nx,
int stride_ny, Transform transform, bool reduce_last_dim) {
int stride_ny, Functor func, bool reduce_last_dim) {
int thread_offset = 0;
int left_idx = 0;
if (reduce_last_dim) {
Expand All @@ -385,7 +385,7 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
dst[ny] = static_cast<Ty>(transform(src[index_src]));
dst[ny] = static_cast<Ty>(func(src[index_src]));
thread_offset += stride_ny;
}
} else {
Expand All @@ -400,7 +400,7 @@ __device__ __forceinline__ void ReadDataReduce(
}
}
uint32_t index_src = index_cal(thread_offset + block_offset);
dst[nx + ny * NX] = static_cast<Ty>(transform(src[index_src]));
dst[nx + ny * NX] = static_cast<Ty>(func(src[index_src]));
thread_offset += stride_ny;
}
}
Expand Down

1 comment on commit 681dd40

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.