Skip to content

Commit

Permalink
Add Sort API for Kernel Primitive API (#39734)
Browse files Browse the repository at this point in the history
* Add Sort API for Kernel Primitive API

* update & -> ptr
  • Loading branch information
AnnaTrainingG authored Feb 22, 2022
1 parent de760d2 commit f4e7488
Showing 1 changed file with 123 additions and 0 deletions.
123 changes: 123 additions & 0 deletions paddle/phi/kernels/primitive/compute_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,40 @@ __device__ __forceinline__ T BlockYReduce(T val, ReduceOp reducer) {
return shared_memory[threadIdx.x];
}

// Swap data
template <typename T>
__device__ __forceinline__ void Swap(T* first_value, T* second_value) {
T t_value;
t_value = (*first_value);
(*first_value) = (*second_value);
(*second_value) = t_value;
}

// swap with monotonic_type
template <typename T>
__device__ __forceinline__ void Comparator(T* first_value,
T* second_value,
int monotonic_type) {
if (((*first_value) > (*second_value)) == monotonic_type) {
Swap<T>(first_value, second_value);
}
}

template <typename T, typename IndexType>
__device__ __forceinline__ void ComparatorWithIndex(T* first_value,

T* second_value,
IndexType* first_index,
IndexType* second_index,
int monotonic_type) {
if ((*first_value > (*second_value)) == monotonic_type) {
// swap value
Swap<T>(first_value, second_value);
// swap index
Swap<IndexType>(first_index, second_index);
}
}

} // namespace details

/**
Expand Down Expand Up @@ -481,5 +515,94 @@ __device__ __forceinline__ void Cumsum(OutT* out,
static_cast<OutT>(temp[tidx + shared_size + (tidx + shared_size) / 32]);
}

#define SHARED_SIZE_LIMIT \
1024 // each thread load 2 data from global memory so SHARED_SIZE_LIMIT must
// larger than blockDim.x * 2
// if monotonic_type = 1 then increase
// if gridDim.x > 1 please set monotonic_type = blockIdx.x & 1; blockIdx.x % 2
// == 1 the increase
template <typename T>
__device__ __forceinline__ void Sort(T* dst,
const T* src_data,
int num,
int monotonic_type) {
// todo: set num = Pow2(num)
// shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2
__shared__ T value[SHARED_SIZE_LIMIT]; // shareMem's size must larger than
// blockDim * 2
// Copy value and index from src and src_index
value[threadIdx.x] = src_data[0];
value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
// make bitonicSort
for (int size = 2; size < num; size <<= 1) {
int bitonic_type = (threadIdx.x & (size / 2)) != 0;
for (int stride = size / 2; stride > 0; stride >>= 1) {
__syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
details::Comparator<T>(&value[pos], &value[pos + stride], bitonic_type);
}
}
// last sort
for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
__syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
// last sort when monotonic_type = 1 then increase
details::Comparator<T>(&value[pos], &value[pos + stride], monotonic_type);
}
__syncthreads();
dst[0] = value[threadIdx.x];
dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
}

template <typename T, typename IndexType>
__device__ __forceinline__ void Sort(T* dst,
IndexType* dst_index,
const T* src_data,
IndexType* src_index,
int num,
int monotonic_type) {
// todo: set num = Pow2(num)
// shareMem for value and index num must smaller than SHARED_SIZE_LIMIT / 2
__shared__ T value[SHARED_SIZE_LIMIT]; // shareMem's size must larger than
// blockDim * 2
__shared__ IndexType index[SHARED_SIZE_LIMIT];
// Copy value and index from src and src_index
value[threadIdx.x] = src_data[0];
value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_data[1];
// index
index[threadIdx.x] = src_index[0];
index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)] = src_index[1];
// make bitonicSort
for (int size = 2; size < num; size <<= 1) {
int bitonic_type = (threadIdx.x & (size / 2)) != 0;
for (int stride = size / 2; stride > 0; stride >>= 1) {
__syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
details::ComparatorWithIndex<T, IndexType>(&value[pos],
&value[pos + stride],
&index[pos],
&index[pos + stride],
bitonic_type);
}
}

for (int stride = SHARED_SIZE_LIMIT / 2; stride > 0; stride >>= 1) {
__syncthreads();
int pos = 2 * threadIdx.x - (threadIdx.x & (stride - 1));
// last sort when monotonic_type = 1 then increase
details::ComparatorWithIndex<T, IndexType>(&value[pos],
&value[pos + stride],
&index[pos],
&index[pos + stride],
monotonic_type);
}

__syncthreads();
dst[0] = value[threadIdx.x];
dst[1] = value[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
dst_index[0] = index[threadIdx.x];
dst_index[1] = index[threadIdx.x + (SHARED_SIZE_LIMIT / 2)];
}

} // namespace kps
} // namespace phi

0 comments on commit f4e7488

Please sign in to comment.