Skip to content

Commit

Permalink
update & -> ptr
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG committed Feb 21, 2022
1 parent 20a3ab1 commit 4107246
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions paddle/phi/kernels/primitive/compute_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,31 +134,31 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {

// Swap data
template <typename T>
__device__ void Swap(T* first_value, T* second_value) {
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
T t_value;
t_value = first_value[0];
first_value[0] = second_value[0];
second_value[0] = t_value;
t_value = (*first_value);
(*first_value) = (*second_value);
(*second_value) = t_value;
}

// swap with monotonic_type
template <typename T>
__device__ inline void Comparator(T* first_value,
T* second_value,
int monotonic_type) {
if ((first_value > second_value) == monotonic_type) {
__device__ __forceinline__ void Comparator(T* first_value,
T* second_value,
int monotonic_type) {
if (((*first_value) > (*second_value)) == monotonic_type) {
Swap<T>(first_value, second_value);
}
}

template <typename T, typename IndexType>
__device__ inline void ComparatorWithIndex(T* first_value,
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,

T* second_value,
IndexType* first_index,
IndexType* second_index,
int monotonic_type) {
if ((first_value > second_value) == monotonic_type) {
T* second_value,
IndexType* first_index,
IndexType* second_index,
int monotonic_type) {
if ((*first_value > (*second_value)) == monotonic_type) {
// swap value
Swap<T>(first_value, second_value);
// swap index
Expand Down Expand Up @@ -522,7 +522,10 @@ __device__ __forceinline__ void Cumsum(OutT* out,
// if gridDim.x > 1 please set monotonic_type = blockIdx.x & 1; blockIdx.x % 2
// == 1 the increase
template <typename T>
__device__ void Sort(T* dst, const T* src_data, int num, int monotonic_type) {
__device__ __forceinline__ void Sort(T* dst,
const T* src_data,
int num,
int monotonic_type) {
// todo: set num = Pow2(num)
// shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2
__shared__ T value[SHARED_SIZE_LIMIT]; // shareMem's size must larger than
Expand Down Expand Up @@ -552,12 +555,12 @@ __device__ void Sort(T* dst, const T* src_data, int num, int monotonic_type) {
}

template <typename T, typename IndexType>
__device__ void Sort(T* dst,
IndexType* dst_index,
const T* src_data,
IndexType* src_index,
int num,
int monotonic_type) {
__device__ __forceinline__ void Sort(T* dst,
IndexType* dst_index,
const T* src_data,
IndexType* src_index,
int num,
int monotonic_type) {
// todo: set num = Pow2(num)
// shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2
__shared__ T value[SHARED_SIZE_LIMIT]; // shareMem's size must larger than
Expand Down

1 comment on commit 4107246

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.