Skip to content

Commit

Permalink
GPU implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Critsium-xy committed Jan 13, 2025
1 parent 911a80d commit 7646974
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 9 deletions.
4 changes: 2 additions & 2 deletions source/module_base/blas_connector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ void vector_mul_vector(const int& dim, T* result, const T* vector1, const T* vec
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
vector_mul_vector_complex_wrapper(d, dim, result, vector1, vector2);
vector_mul_vector_gpu(dim, result, vector1, vector2);
#endif
}
}
Expand All @@ -691,7 +691,7 @@ void vector_div_vector(const int& dim, T* result, const T* vector1, const T* vec
}
else if (device_type == base_device::AbacusDevice_t::GpuDevice){
#ifdef __CUDA
vector_div_vector_complex_wrapper(d, dim, result, vector1, vector2);
vector_mul_vector_gpu(dim, result, vector1, vector2);
#endif
}
}
79 changes: 72 additions & 7 deletions source/module_base/kernels/cuda/math_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -192,25 +192,22 @@ __global__ void vector_div_vector_kernel(
}

template <typename FPTYPE>
inline void vector_div_constant_complex_wrapper(const base_device::DEVICE_GPU* d,
const int dim,
inline void vector_div_vector_complex_wrapper(const int dim,
std::complex<FPTYPE>* result,
const std::complex<FPTYPE>* vector,
const FPTYPE constant)
{
thrust::complex<FPTYPE>* result_tmp = reinterpret_cast<thrust::complex<FPTYPE>*>(result);
const thrust::complex<FPTYPE>* vector_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector);

const thrust::complex<FPTYPE>* vector1_tmp = reinterpret_cast<const thrust::complex<FPTYPE>*>(vector1);
int thread = THREADS_PER_BLOCK;
int block = (dim + thread - 1) / thread;
vector_div_constant_kernel<thrust::complex<FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector_tmp, constant);
vector_div_vector_kernel<thrust::complex<FPTYPE>> <<<block, thread >>> (dim, result_tmp, vector1_tmp, vector2);

cudaCheckOnDebug();
}

template <typename FPTYPE>
inline void vector_mul_vector_complex_wrapper(const base_device::DEVICE_GPU* d,
const int& dim,
inline void vector_mul_vector_complex_wrapper(const int& dim,
std::complex<FPTYPE>* result,
const std::complex<FPTYPE>* vector1,
const FPTYPE* vector2)
Expand All @@ -224,4 +221,72 @@ inline void vector_mul_vector_complex_wrapper(const base_device::DEVICE_GPU* d,
cudaCheckOnDebug();
}

void vector_div_vector_gpu(const int& dim,
double* result,
const double* vector1,
const double* vector2)
{
int thread = THREADS_PER_BLOCK;
int block = (dim + thread - 1) / thread;
vector_div_vector_kernel<double> <<<block, thread >>> (dim, result, vector1, vector2);

cudaCheckOnDebug();
}

void vector_div_vector_gpu(const int& dim,
float* result,
const float* vector1,
const float* vector2)
{
int thread = THREADS_PER_BLOCK;
int block = (dim + thread - 1) / thread;
vector_div_vector_kernel<float> <<<block, thread >>> (dim, result, vector1, vector2);

cudaCheckOnDebug();
}

void vector_div_vector_gpu(const int& dim, std::complex<float>* result, const std::complex<float>* vector1, const float* vector2)
{
vector_div_vector_complex_wrapper(dim, result, vector1, vector2);
}

void vector_div_vector_gpu(const int& dim, std::complex<double>* result, const std::complex<double>* vector1, const double* vector2)
{
vector_div_vector_complex_wrapper(dim, result, vector1, vector2);
}

void vector_mul_vector_gpu(const int& dim,
double* result,
const double* vector1,
const double* vector2)
{
int thread = THREADS_PER_BLOCK;
int block = (dim + thread - 1) / thread;
vector_mul_vector_kernel<double> <<<block, thread >>> (dim, result, vector1, vector2);

cudaCheckOnDebug();
}

void vector_mul_vector_gpu(const int& dim,
float* result,
const float* vector1,
const float* vector2)
{
int thread = THREADS_PER_BLOCK;
int block = (dim + thread - 1) / thread;
vector_mul_vector_kernel<float> <<<block, thread >>> (dim, result, vector1, vector2);

cudaCheckOnDebug();
}

void vector_mul_vector_gpu(const int& dim, std::complex<float>* result, const std::complex<float>* vector1, const float* vector2)
{
vector_mul_vector_complex_wrapper(dim, result, vector1, vector2);
}

void vector_mul_vector_gpu(const int& dim, std::complex<double>* result, const std::complex<double>* vector1, const double* vector2)
{
vector_mul_vector_complex_wrapper(dim, result, vector1, vector2);
}

} // namespace ModuleBase

0 comments on commit 7646974

Please sign in to comment.