Skip to content

Commit

Permalink
fix compile
Browse files Browse the repository at this point in the history
  • Loading branch information
Qianruipku committed Jan 19, 2025
1 parent e58bfe0 commit 0fde07e
Show file tree
Hide file tree
Showing 6 changed files with 463 additions and 467 deletions.
14 changes: 6 additions & 8 deletions source/module_base/para_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@ template <typename T, typename Device>
void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T beta, T* C_global)
{
const Device* ctx = {};
char transC = 'C';
char transN = 'N';
#ifdef __MPI
if (col_nproc > 1)
{
Expand Down Expand Up @@ -122,8 +120,8 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
if (col_rank == ip)
{
ModuleBase::gemm_op<T, Device>()(ctx,
transC,
transN,
'C',
'N',
ncolA,
ncolB,
nrow,
Expand All @@ -145,8 +143,8 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
Parallel_Common::recv_dev<T, Device>(Atmp_device, size, ip, 0, col_world, &status, A_tmp.data());
MPI_Wait(&requests[ip], &status);
ModuleBase::gemm_op<T, Device>()(ctx,
transC,
transN,
'C',
'N',
m,
ncolB,
nrow,
Expand Down Expand Up @@ -195,8 +193,8 @@ void PGemmCN<T, Device>::multiply(const T alpha, const T* A, const T* B, const T
{
T real_beta = row_rank == 0 ? beta : 0;
ModuleBase::gemm_op<T, Device>()(ctx,
transC,
transN,
'C',
'N',
ncolA,
ncolB,
nrow,
Expand Down
2 changes: 1 addition & 1 deletion source/module_hamilt_pw/hamilt_pwdft/fs_nonlocal_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ void FS_Nonlocal_tools<FPTYPE, Device>::reduce_pool_becp(const int& npm)
#ifdef __MPI
if (GlobalV::NPROC_IN_POOL > 1)
{
Parallel_Common::reduce_dev<FPTYPE,Device>(this->becp, size_becp_act, POOL_WORLD);
Parallel_Common::reduce_dev<std::complex<FPTYPE>, Device>(this->becp, size_becp_act, POOL_WORLD);
}
#endif
}
Expand Down
202 changes: 101 additions & 101 deletions source/module_hsolver/diago_cg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ void DiagoCG<T, Device>::calc_grad(const ct::Tensor& prec,
// }
// denghui replaced this on 2022-11-06
// TODO: use GPU precondition to initialize CG class
vector_div_vector_op<T, Device>()(ctx_, this->n_basis_, grad.data<T>(), hphi.data<T>(), prec.data<Real>());
vector_div_vector_op<T, Device>()(ctx_, this->n_basis_, pphi.data<T>(), sphi.data<T>(), prec.data<Real>());
ModuleBase::vector_div_vector_op<T, Device>()(ctx_, this->n_basis_, grad.data<T>(), hphi.data<T>(), prec.data<Real>());
ModuleBase::vector_div_vector_op<T, Device>()(ctx_, this->n_basis_, pphi.data<T>(), sphi.data<T>(), prec.data<Real>());

// Update lambda !
// (4) <psi|SPH|psi >
Expand All @@ -247,13 +247,13 @@ void DiagoCG<T, Device>::calc_grad(const ct::Tensor& prec,
// grad.data<T>()[i] -= lambda * this->pphi[i];
// }
// haozhihan replace this 2022-10-6
constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
grad.data<T>(),
grad.data<T>(),
1.0,
pphi.data<T>(),
(-lambda));
ModuleBase::constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
grad.data<T>(),
grad.data<T>(),
1.0,
pphi.data<T>(),
(-lambda));
}

template <typename T, typename Device>
Expand All @@ -264,49 +264,49 @@ void DiagoCG<T, Device>::orth_grad(const ct::Tensor& psi,
ct::Tensor& lagrange)
{
this->spsi_func_(grad, scg); // scg = S|grad>
gemv_op<T, Device>()(ctx_,
'C',
this->n_basis_,
m,
this->one_,
psi.data<T>(),
this->n_basis_,
scg.data<T>(),
1,
this->zero_,
lagrange.data<T>(),
1);
ModuleBase::gemv_op<T, Device>()(ctx_,
'C',
this->n_basis_,
m,
this->one_,
psi.data<T>(),
this->n_basis_,
scg.data<T>(),
1,
this->zero_,
lagrange.data<T>(),
1);

Parallel_Reduce::reduce_pool(lagrange.data<T>(), m);

// (3) orthogonal |g> and |scg> to all states (0~m-1)
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// haozhihan replace 2022-10-07
gemv_op<T, Device>()(ctx_,
'N',
this->n_basis_,
m,
this->neg_one_,
psi.data<T>(),
this->n_basis_,
lagrange.data<T>(),
1,
this->one_,
grad.data<T>(),
1);

gemv_op<T, Device>()(ctx_,
'N',
this->n_basis_,
m,
this->neg_one_,
psi.data<T>(),
this->n_basis_,
lagrange.data<T>(),
1,
this->one_,
scg.data<T>(),
1);
ModuleBase::gemv_op<T, Device>()(ctx_,
'N',
this->n_basis_,
m,
this->neg_one_,
psi.data<T>(),
this->n_basis_,
lagrange.data<T>(),
1,
this->one_,
grad.data<T>(),
1);

ModuleBase::gemv_op<T, Device>()(ctx_,
'N',
this->n_basis_,
m,
this->neg_one_,
psi.data<T>(),
this->n_basis_,
lagrange.data<T>(),
1,
this->one_,
scg.data<T>(),
1);
}

template <typename T, typename Device>
Expand Down Expand Up @@ -342,7 +342,7 @@ void DiagoCG<T, Device>::calc_gamma_cg(const int& iter,
// }
// denghui replace this 20221106
// TODO: use GPU precondition instead
vector_mul_vector_op<T, Device>()(ctx_, this->n_basis_, g0.data<T>(), scg.data<T>(), prec.data<Real>());
ModuleBase::vector_mul_vector_op<T, Device>()(ctx_, this->n_basis_, g0.data<T>(), scg.data<T>(), prec.data<Real>());

// (3) Update gg_now!
// gg_now = < g|P|scg > = < g|g0 >
Expand Down Expand Up @@ -370,13 +370,13 @@ void DiagoCG<T, Device>::calc_gamma_cg(const int& iter,
// pcg[i] = gamma * pcg[i] + grad.data<T>()[i];
// }
// haozhihan replace this 2022-10-6
constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
cg.data<T>(),
cg.data<T>(),
gamma,
grad.data<T>(),
1.0);
ModuleBase::constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
cg.data<T>(),
cg.data<T>(),
gamma,
grad.data<T>(),
1.0);

const Real norma = gamma * cg_norm * sin(theta);
T znorma = static_cast<T>(norma * -1);
Expand All @@ -388,7 +388,7 @@ void DiagoCG<T, Device>::calc_gamma_cg(const int& iter,
{
pcg[i] -= norma * pphi_m[i];
}*/
axpy_op<T, Device>()(ctx_, this->n_basis_, &znorma, phi_m.data<T>(), 1, cg.data<T>(), 1);
ModuleBase::axpy_op<T, Device>()(ctx_, this->n_basis_, &znorma, phi_m.data<T>(), 1, cg.data<T>(), 1);
}
}

Expand Down Expand Up @@ -438,13 +438,13 @@ bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
// }

// haozhihan replace this 2022-10-6
constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
phi_m.data<T>(),
phi_m.data<T>(),
cost,
cg.data<T>(),
sint_norm);
ModuleBase::constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
phi_m.data<T>(),
phi_m.data<T>(),
cost,
cg.data<T>(),
sint_norm);

if (std::abs(eigen - e0) < ethreshold)
{
Expand All @@ -460,20 +460,20 @@ bool DiagoCG<T, Device>::update_psi(const ct::Tensor& pphi,
// }

// haozhihan replace this 2022-10-6
constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
sphi.data<T>(),
sphi.data<T>(),
cost,
scg.data<T>(),
sint_norm);
constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
hphi.data<T>(),
hphi.data<T>(),
cost,
pphi.data<T>(),
sint_norm);
ModuleBase::constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
sphi.data<T>(),
sphi.data<T>(),
cost,
scg.data<T>(),
sint_norm);
ModuleBase::constantvector_addORsub_constantVector_op<T, Device>()(ctx_,
this->n_basis_,
hphi.data<T>(),
hphi.data<T>(),
cost,
pphi.data<T>(),
sint_norm);
return false;
}
}
Expand All @@ -496,36 +496,36 @@ void DiagoCG<T, Device>::schmit_orth(const int& m, const ct::Tensor& psi, const
//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// haozhihan replace 2022-10-6
int inc = 1;
gemv_op<T, Device>()(ctx_,
'C',
this->n_basis_,
m + 1,
this->one_,
psi.data<T>(),
this->n_basis_,
sphi.data<T>(),
inc,
this->zero_,
lagrange_so.data<T>(),
inc);
ModuleBase::gemv_op<T, Device>()(ctx_,
'C',
this->n_basis_,
m + 1,
this->one_,
psi.data<T>(),
this->n_basis_,
sphi.data<T>(),
inc,
this->zero_,
lagrange_so.data<T>(),
inc);

// be careful: the reduction here covers m + 1 elements, not m
Parallel_Reduce::reduce_pool(lagrange_so.data<T>(), m + 1);

//<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
// haozhihan replace 2022-10-6
gemv_op<T, Device>()(ctx_,
'N',
this->n_basis_,
m,
this->neg_one_,
psi.data<T>(),
this->n_basis_,
lagrange_so.data<T>(),
inc,
this->one_,
phi_m.data<T>(),
inc);
ModuleBase::gemv_op<T, Device>()(ctx_,
'N',
this->n_basis_,
m,
this->neg_one_,
psi.data<T>(),
this->n_basis_,
lagrange_so.data<T>(),
inc,
this->one_,
phi_m.data<T>(),
inc);

//======================================================================
/*for (int j = 0; j < m; j++)
Expand Down Expand Up @@ -563,7 +563,7 @@ void DiagoCG<T, Device>::schmit_orth(const int& m, const ct::Tensor& psi, const
// {
// pphi_m[ig] /= psi_norm;
// }
vector_div_constant_op<T, Device>()(ctx_, this->n_basis_, phi_m.data<T>(), phi_m.data<T>(), psi_norm);
ModuleBase::vector_div_constant_op<T, Device>()(ctx_, this->n_basis_, phi_m.data<T>(), phi_m.data<T>(), psi_norm);

// ModuleBase::timer::tick("DiagoCG","schmit_orth");
}
Expand Down
Loading

0 comments on commit 0fde07e

Please sign in to comment.