diff --git a/examples/fft_conv.cu b/examples/fft_conv.cu
index 7f871119..9833ea83 100644
--- a/examples/fft_conv.cu
+++ b/examples/fft_conv.cu
@@ -73,7 +73,6 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
 {
   MATX_ENTER_HANDLER();
   using complex = cuda::std::complex<float>;
 
-  cudaExecutor exec{};
   index_t signal_size = 1ULL << 16;
   index_t filter_size = 16;
@@ -87,6 +86,7 @@ int main([[maybe_unused]] int argc, [[maybe_unused]] char **argv)
   cudaEvent_t start, stop;
   cudaEventCreate(&start);
   cudaEventCreate(&stop);
+  cudaExecutor exec{stream};
 
   // Create time domain buffers
   auto sig_time = make_tensor<complex>({batches, signal_size});
diff --git a/include/matx/core/cache.h b/include/matx/core/cache.h
index 795d441b..741f87b6 100644
--- a/include/matx/core/cache.h
+++ b/include/matx/core/cache.h
@@ -36,6 +36,7 @@
 #include <any>
 #include <functional>
 #include <memory>
+#include <mutex>
 #include <optional>
 #include <unordered_map>
@@ -50,6 +51,7 @@ using CacheId = uint64_t;
 __attribute__ ((visibility ("default")))
 #endif
 inline cuda::std::atomic<CacheId> CacheIdCounter{0};
+inline std::recursive_mutex cache_mtx; ///< Mutex protecting updates to the map
 
 template <typename CacheType>
 __attribute__ ((visibility ("default")))
@@ -83,6 +85,8 @@ class matxCache_t {
    */
   template <typename CacheType>
   void Clear(const CacheId &id) {
+    [[maybe_unused]] std::lock_guard lock(cache_mtx);
+
     auto el = cache.find(id);
     MATX_ASSERT_STR(el != cache.end(), matxInvalidType, "Cache type not found");
@@ -91,6 +95,9 @@ class matxCache_t {
   template <typename CacheType, typename InParams, typename MakeFun, typename ExecFun>
   void LookupAndExec(const CacheId &id, const InParams &params, const MakeFun &mfun,
                      const ExecFun &efun) {
+    // This mutex should eventually be finer-grained so each transform doesn't get blocked by others
+    [[maybe_unused]] std::lock_guard lock(cache_mtx);
+
     // Create named cache if it doesn't exist
     auto el = cache.find(id);
     if (el == cache.end()) {
diff --git a/include/matx/transforms/chol/chol_cuda.h b/include/matx/transforms/chol/chol_cuda.h
index 01e92c42..e2478acd 100644
--- a/include/matx/transforms/chol/chol_cuda.h
+++ b/include/matx/transforms/chol/chol_cuda.h
@@ -58,6 +58,7 @@ struct DnCholCUDAParams_t {
   size_t batch_size;
   cublasFillMode_t uplo;
   MatXDataType_t dtype;
+  cudaExecutor exec;
 };
 
 template <typename OutputTensor, typename ATensor>
@@ -89,8 +90,9 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {
    * Use upper or lower triangle for computation
    *
    */
-  matxDnCholCUDAPlan_t(const ATensor &a,
-                       cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER)
+  matxDnCholCUDAPlan_t( const ATensor &a,
+                        const cudaExecutor &exec,
+                        cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER)
   {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
@@ -101,9 +103,10 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {
     MATX_STATIC_ASSERT_STR(!is_half_v<T1>, matxInvalidType, "Cholesky solver does not support half precision");
     MATX_STATIC_ASSERT_STR((std::is_same_v<T1, T2>), matxInavlidType, "Input and Output types must match");
 
-    params = GetCholParams(a, uplo);
+    params = GetCholParams(a, uplo, exec);
+
     this->GetWorkspaceSize();
-    this->AllocateWorkspace(params.batch_size, false);
+    this->AllocateWorkspace(params.batch_size, false, exec);
   }
 
   void GetWorkspaceSize() override
@@ -117,13 +120,15 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {
   }
 
   static DnCholCUDAParams_t GetCholParams(const ATensor &a,
-                                          cublasFillMode_t uplo)
+                                          cublasFillMode_t uplo,
+                                          const cudaExecutor &exec)
   {
     DnCholCUDAParams_t params;
     params.batch_size = GetNumBatches(a);
     params.n = a.Size(RANK - 1);
     params.A = a.Data();
     params.uplo = uplo;
+    params.exec = exec;
     params.dtype = TypeToInt<T1>();
 
     return params;
@@ -201,7 +206,9 @@ class matxDnCholCUDAPlan_t : matxDnCUDASolver_t {
 struct DnCholCUDAParamsKeyHash {
   std::size_t operator()(const DnCholCUDAParams_t &k) const noexcept
   {
-    return (std::hash<index_t>()(k.n)) + (std::hash<size_t>()(k.batch_size));
+    return (std::hash<index_t>()(k.n)) +
+           (std::hash<size_t>()(k.batch_size)) +
+           (std::hash<uint64_t>()((uint64_t)(k.exec.getStream())));
   }
 };
@@ -213,7 +220,10 @@ struct DnCholCUDAParamsKeyEq {
   bool operator()(const DnCholCUDAParams_t &l, const DnCholCUDAParams_t &t) const noexcept
   {
-    return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype;
+    return l.n == t.n &&
+           l.batch_size == t.batch_size &&
+           l.dtype == t.dtype &&
+           l.exec.getStream() == t.exec.getStream();
   }
 };
@@ -290,14 +300,14 @@ void chol_impl(OutputTensor &&out, const ATensor &a,
   cublasFillMode_t uplo_cusolver = (uplo == SolverFillMode::UPPER)? CUBLAS_FILL_MODE_UPPER : CUBLAS_FILL_MODE_LOWER;
 
   // Get parameters required by these tensors
-  auto params = detail::matxDnCholCUDAPlan_t<OutputTensor, decltype(tmp_out)>::GetCholParams(tmp_out, uplo_cusolver);
+  auto params = detail::matxDnCholCUDAPlan_t<OutputTensor, decltype(tmp_out)>::GetCholParams(tmp_out, uplo_cusolver, exec);
 
   using cache_val_type = detail::matxDnCholCUDAPlan_t<OutputTensor, decltype(tmp_out)>;
   detail::GetCache().LookupAndExec<cache_val_type>(
     detail::GetCacheIdFromType<cache_val_type>(),
     params,
     [&]() {
-      return std::make_shared<cache_val_type>(tmp_out, uplo_cusolver);
+      return std::make_shared<cache_val_type>(tmp_out, exec, uplo_cusolver);
     },
     [&](std::shared_ptr<cache_val_type> ctype) {
       ctype->Exec(tmp_out, tmp_out, exec, uplo_cusolver);
diff --git a/include/matx/transforms/eig/eig_cuda.h b/include/matx/transforms/eig/eig_cuda.h
index 8e43a753..c2ae3079 100644
--- a/include/matx/transforms/eig/eig_cuda.h
+++ b/include/matx/transforms/eig/eig_cuda.h
@@ -61,6 +61,7 @@ struct DnEigCUDAParams_t {
   void *W;
   size_t batch_size;
   MatXDataType_t dtype;
+  cudaExecutor exec;
 };
 
 template <typename OutputTensor, typename WTensor, typename ATensor>
@@ -98,6 +99,7 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
    */
   matxDnEigCUDAPlan_t(WTensor &w,
                       const ATensor &a,
+                      const cudaExecutor &exec,
                       cusolverEigMode_t jobz = CUSOLVER_EIG_MODE_VECTOR,
                       cublasFillMode_t uplo = CUBLAS_FILL_MODE_UPPER)
   {
@@ -113,12 +115,12 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
     MATX_STATIC_ASSERT_STR(!is_complex_v<T2>, matxInvalidType, "W type must be real");
     MATX_STATIC_ASSERT_STR((std::is_same_v<typename inner_op_type_t<T1>::type, T2>), matxInvalidType, "Out and W inner types must match");
 
-    params = GetEigParams(w, a, jobz, uplo);
+    params = GetEigParams(w, a, jobz, uplo, exec);
     this->GetWorkspaceSize();
-#if CUSOLVER_VERSION > 11701 || (CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >=2)
-    this->AllocateWorkspace(params.batch_size, true);
+#if CUSOLVER_VERSION > 11701 || (CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >= 2)
+    this->AllocateWorkspace(params.batch_size, true, exec);
 #else
-    this->AllocateWorkspace(params.batch_size, false);
+    this->AllocateWorkspace(params.batch_size, false, exec);
 #endif
   }
@@ -147,7 +149,8 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
   static DnEigCUDAParams_t GetEigParams(WTensor &w,
                                         const ATensor &a,
                                         cusolverEigMode_t jobz,
-                                        cublasFillMode_t uplo)
+                                        cublasFillMode_t uplo,
+                                        const cudaExecutor &exec)
   {
     DnEigCUDAParams_t params;
     params.batch_size = GetNumBatches(a);
@@ -156,6 +159,8 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
     params.W = w.Data();
     params.jobz = jobz;
     params.uplo = uplo;
+    params.exec = exec;
+
     params.dtype = TypeToInt<T1>();
 
     return params;
@@ -258,7 +263,7 @@ class matxDnEigCUDAPlan_t : matxDnCUDASolver_t {
 struct DnEigCUDAParamsKeyHash {
   std::size_t operator()(const DnEigCUDAParams_t &k) const noexcept
   {
-    return (std::hash<index_t>()(k.n)) + (std::hash<size_t>()(k.batch_size));
+    return (std::hash<index_t>()(k.n)) + (std::hash<size_t>()(k.batch_size)) + (std::hash<uint64_t>()((uint64_t)(k.exec.getStream())));
   }
 };
@@ -269,7 +274,7 @@ struct DnEigCUDAParamsKeyEq {
   bool operator()(const DnEigCUDAParams_t &l, const DnEigCUDAParams_t &t) const noexcept
   {
-    return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype;
+    return l.n == t.n && l.batch_size == t.batch_size && l.dtype == t.dtype && l.exec.getStream() == t.exec.getStream();
   }
 };
@@ -339,7 +344,7 @@ void eig_impl(OutputTensor &&out, WTensor &&w,
 
   // Get parameters required by these tensors
   auto params = detail::matxDnEigCUDAPlan_t<OutputTensor, decltype(w_new), decltype(tv)>::
-    GetEigParams(w_new, tv, jobz_cusolver, uplo_cusolver);
+    GetEigParams(w_new, tv, jobz_cusolver, uplo_cusolver, exec);
 
   // Get cache or new eigen plan if it doesn't exist
   using cache_val_type = detail::matxDnEigCUDAPlan_t<OutputTensor, decltype(w_new), decltype(tv)>;
@@ -347,7 +352,7 @@ void eig_impl(OutputTensor &&out, WTensor &&w,
     detail::GetCacheIdFromType<cache_val_type>(),
     params,
     [&]() {
-      return std::make_shared<cache_val_type>(w_new, tv, jobz_cusolver, uplo_cusolver);
+      return std::make_shared<cache_val_type>(w_new, tv, exec, jobz_cusolver, uplo_cusolver);
     },
     [&](std::shared_ptr<cache_val_type> ctype) {
       ctype->Exec(tv, w_new, tv, exec, jobz_cusolver, uplo_cusolver);
diff --git a/include/matx/transforms/eig/eig_lapack.h b/include/matx/transforms/eig/eig_lapack.h
index 534ca55a..0b56948a 100644
--- a/include/matx/transforms/eig/eig_lapack.h
+++ b/include/matx/transforms/eig/eig_lapack.h
@@ -115,7 +115,7 @@ class matxDnEigHostPlan_t : matxDnHostSolver_t {
     params = GetEigParams(w, a, jobz, uplo);
 
     this->GetWorkspaceSize();
-    this->AllocateWorkspace(params.batch_size, false);
+    this->AllocateWorkspace(params.batch_size);
   }
 
   void GetWorkspaceSize() override
diff --git a/include/matx/transforms/lu/lu_cuda.h b/include/matx/transforms/lu/lu_cuda.h
index cc65998e..861f8aaf 100644
--- a/include/matx/transforms/lu/lu_cuda.h
+++ b/include/matx/transforms/lu/lu_cuda.h
@@ -59,6 +59,7 @@ struct DnLUCUDAParams_t {
   void *piv;
   size_t batch_size;
   MatXDataType_t dtype;
+  cudaExecutor exec;
 };
 
 template <typename OutputTensor, typename PivotTensor, typename ATensor>
@@ -91,7 +92,8 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {
    *
    */
   matxDnLUCUDAPlan_t(PivotTensor &piv,
-                     const ATensor &a)
+                     const ATensor &a,
+                     const cudaExecutor &exec)
   {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
@@ -104,9 +106,9 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {
     MATX_STATIC_ASSERT_STR((std::is_same_v<T1, T2>), matxInavlidType, "Input and Output types must match");
     MATX_STATIC_ASSERT_STR((std::is_same_v<typename PivotTensor::value_type, int64_t>), matxInavlidType, "Pivot tensor type must be int64_t");
 
-    params = GetLUParams(piv, a);
+    params = GetLUParams(piv, a, exec);
     this->GetWorkspaceSize();
-    this->AllocateWorkspace(params.batch_size, false);
+    this->AllocateWorkspace(params.batch_size, false, exec);
   }
 
   void GetWorkspaceSize() override
@@ -120,7 +122,8 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {
   }
 
   static DnLUCUDAParams_t GetLUParams(PivotTensor &piv,
-                                      const ATensor &a) noexcept
+                                      const ATensor &a,
+                                      const cudaExecutor &exec) noexcept
   {
     DnLUCUDAParams_t params;
     params.batch_size = GetNumBatches(a);
@@ -129,7 +132,7 @@ class matxDnLUCUDAPlan_t : matxDnCUDASolver_t {
     params.A = a.Data();
     params.piv = piv.Data();
     params.dtype = TypeToInt<T1>();
-
+    params.exec = exec;
     return params;
   }
@@ -212,7 +215,7 @@ struct DnLUCUDAParamsKeyHash {
   std::size_t operator()(const DnLUCUDAParams_t &k) const noexcept
   {
     return (std::hash<index_t>()(k.m)) + (std::hash<index_t>()(k.n)) +
-           (std::hash<size_t>()(k.batch_size));
+           (std::hash<size_t>()(k.batch_size)) + (std::hash<uint64_t>()((uint64_t)(k.exec.getStream())));
   }
 };
@@ -223,7 +226,7 @@ struct DnLUCUDAParamsKeyEq {
   bool operator()(const DnLUCUDAParams_t &l, const DnLUCUDAParams_t &t) const noexcept
   {
     return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size &&
-           l.dtype == t.dtype;
+           l.dtype == t.dtype && l.exec.getStream() == t.exec.getStream();
   }
 };
@@ -284,7 +287,7 @@ void lu_impl(OutputTensor &&out, PivotTensor &&piv,
   auto tvt = tv.PermuteMatrix();
 
   // Get parameters required by these tensors
-  auto params = detail::matxDnLUCUDAPlan_t<OutputTensor, decltype(piv_new), decltype(tvt)>::GetLUParams(piv_new, tvt);
+  auto params = detail::matxDnLUCUDAPlan_t<OutputTensor, decltype(piv_new), decltype(tvt)>::GetLUParams(piv_new, tvt, exec);
 
   // Get cache or new LU plan if it doesn't exist
   using cache_val_type = detail::matxDnLUCUDAPlan_t<OutputTensor, decltype(piv_new), decltype(tvt)>;
@@ -292,7 +295,7 @@ void lu_impl(OutputTensor &&out, PivotTensor &&piv,
     detail::GetCacheIdFromType<cache_val_type>(),
     params,
     [&]() {
-      return std::make_shared<cache_val_type>(piv_new, tvt);
+      return std::make_shared<cache_val_type>(piv_new, tvt, exec);
     },
     [&](std::shared_ptr<cache_val_type> ctype) {
       ctype->Exec(tvt, piv_new, tvt, exec);
diff --git a/include/matx/transforms/qr/qr_cuda.h b/include/matx/transforms/qr/qr_cuda.h
index 9377e5bf..076a539f 100644
--- a/include/matx/transforms/qr/qr_cuda.h
+++ b/include/matx/transforms/qr/qr_cuda.h
@@ -245,6 +245,7 @@ struct DnQRCUDAParams_t {
   void *tau;
   size_t batch_size;
   MatXDataType_t dtype;
+  cudaExecutor exec;
 };
 
 template <typename OutTensor, typename TauTensor, typename ATensor>
@@ -280,7 +281,8 @@ class matxDnQRCUDAPlan_t : matxDnCUDASolver_t {
    *
    */
   matxDnQRCUDAPlan_t(TauTensor &tau,
-                     const ATensor &a)
+                     const ATensor &a,
+                     const cudaExecutor &exec)
   {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
@@ -293,9 +295,9 @@ class matxDnQRCUDAPlan_t : matxDnCUDASolver_t {
     MATX_STATIC_ASSERT_STR((std::is_same_v<T1, T2>), matxInavlidType, "Input and Output types must match");
     MATX_STATIC_ASSERT_STR((std::is_same_v<T2, T3>), matxInavlidType, "A and Tau types must match");
 
-    params = GetQRParams(tau, a);
+    params = GetQRParams(tau, a, exec);
     this->GetWorkspaceSize();
-    this->AllocateWorkspace(params.batch_size, false);
+    this->AllocateWorkspace(params.batch_size, false, exec);
   }
 
   void GetWorkspaceSize() override
@@ -308,7 +310,8 @@ class matxDnQRCUDAPlan_t : matxDnCUDASolver_t {
   }
 
   static DnQRCUDAParams_t GetQRParams(TauTensor &tau,
-                                      const ATensor &a)
+                                      const ATensor &a,
+                                      const cudaExecutor &exec)
   {
     DnQRCUDAParams_t params;
@@ -318,7 +321,7 @@ class matxDnQRCUDAPlan_t : matxDnCUDASolver_t {
     params.A = a.Data();
     params.tau = tau.Data();
     params.dtype = TypeToInt<T1>();
-
+    params.exec = exec;
     return params;
   }
@@ -396,7 +399,7 @@ struct DnQRCUDAParamsKeyHash {
   std::size_t operator()(const DnQRCUDAParams_t &k) const noexcept
   {
     return (std::hash<index_t>()(k.m)) + (std::hash<index_t>()(k.n)) +
-           (std::hash<size_t>()(k.batch_size));
+           (std::hash<size_t>()(k.batch_size)) + (std::hash<uint64_t>()((uint64_t)(k.exec.getStream())));
   }
 };
@@ -407,7 +410,7 @@ struct DnQRCUDAParamsKeyEq {
   bool operator()(const DnQRCUDAParams_t &l, const DnQRCUDAParams_t &t) const noexcept
   {
     return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size &&
-           l.dtype == t.dtype;
+           l.dtype == t.dtype && l.exec.getStream() == t.exec.getStream();
   }
 };
@@ -465,7 +468,7 @@ void qr_solver_impl(OutTensor &&out, TauTensor &&tau,
   auto tvt = tv.PermuteMatrix();
 
   // Get parameters required by these tensors
-  auto params = detail::matxDnQRCUDAPlan_t<OutTensor, decltype(tau_new), decltype(tvt)>::GetQRParams(tau_new, tvt);
+  auto params = detail::matxDnQRCUDAPlan_t<OutTensor, decltype(tau_new), decltype(tvt)>::GetQRParams(tau_new, tvt, exec);
 
   // Get cache or new QR plan if it doesn't exist
   using cache_val_type = detail::matxDnQRCUDAPlan_t<OutTensor, decltype(tau_new), decltype(tvt)>;
@@ -473,7 +476,7 @@ void qr_solver_impl(OutTensor &&out, TauTensor &&tau,
     detail::GetCacheIdFromType<cache_val_type>(),
     params,
     [&]() {
-      return std::make_shared<cache_val_type>(tau_new, tvt);
+      return std::make_shared<cache_val_type>(tau_new, tvt, exec);
     },
     [&](std::shared_ptr<cache_val_type> ctype) {
       ctype->Exec(tvt, tau_new, tvt, exec);
diff --git a/include/matx/transforms/qr/qr_lapack.h b/include/matx/transforms/qr/qr_lapack.h
index caf8e531..f1a337c7 100644
--- a/include/matx/transforms/qr/qr_lapack.h
+++ b/include/matx/transforms/qr/qr_lapack.h
@@ -110,7 +110,7 @@ class matxDnQRHostPlan_t : matxDnHostSolver_t {
     params = GetQRParams(tau, a);
 
     this->GetWorkspaceSize();
-    this->AllocateWorkspace(params.batch_size, false);
+    this->AllocateWorkspace(params.batch_size);
   }
 
   void GetWorkspaceSize() override
diff --git a/include/matx/transforms/solver_common.h b/include/matx/transforms/solver_common.h
index 67d8f46a..3a78225a 100644
--- a/include/matx/transforms/solver_common.h
+++ b/include/matx/transforms/solver_common.h
@@ -255,38 +255,35 @@ class matxDnCUDASolver_t {
     cusolverDnDestroy(handle);
   }
 
-  void AllocateWorkspace([[maybe_unused]] size_t batches, [[maybe_unused]] bool batched_api)
+  void AllocateWorkspace([[maybe_unused]] size_t batches, [[maybe_unused]] bool batched_api, const cudaExecutor &exec)
   {
-#if CUSOLVER_VERSION > 11701 || ( CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >=2)
+    const auto stream = exec.getStream();
     if (batched_api) {
       // Newer cuSolver
       if (dspace > 0) {
-        matxAlloc(&d_workspace, dspace, MATX_DEVICE_MEMORY);
+        matxAlloc(&d_workspace, dspace, MATX_ASYNC_DEVICE_MEMORY, stream);
       }
 
       // cuSolver has a bug where the workspace needs to be zeroed before using it when the type is complex.
       // Zero it out for all types for now.
-      cudaMemset(d_workspace, 0, dspace);
-      matxAlloc((void **)&d_info, sizeof(*d_info) * batches, MATX_DEVICE_MEMORY);
+      cudaMemsetAsync(d_workspace, 0, dspace, stream);
+      matxAlloc((void **)&d_info, sizeof(*d_info) * batches, MATX_ASYNC_DEVICE_MEMORY, stream);
 
       if (hspace > 0) {
         matxAlloc(&h_workspace, hspace, MATX_HOST_MEMORY);
       }
     }
     else {
-#endif
      if (dspace > 0) {
-        matxAlloc(&d_workspace, batches * dspace, MATX_DEVICE_MEMORY);
+        matxAlloc(&d_workspace, batches * dspace, MATX_ASYNC_DEVICE_MEMORY, stream);
      }
 
-      matxAlloc((void **)&d_info, batches * sizeof(*d_info), MATX_DEVICE_MEMORY);
+      matxAlloc((void **)&d_info, batches * sizeof(*d_info), MATX_ASYNC_DEVICE_MEMORY, stream);
 
      if (hspace > 0) {
        matxAlloc(&h_workspace, batches * hspace, MATX_HOST_MEMORY);
      }
-#if CUSOLVER_VERSION > 11701 || ( CUSOLVER_VERSION == 11701 && CUSOLVER_VER_BUILD >=2)
     }
-#endif
   }
 
   virtual void GetWorkspaceSize() = 0;
@@ -328,7 +325,7 @@ class matxDnHostSolver_t {
     matxFree(iwork);
   }
 
-  void AllocateWorkspace([[maybe_unused]] size_t batches, [[maybe_unused]] bool batched_api)
+  void AllocateWorkspace([[maybe_unused]] size_t batches)
   {
     if (lwork > 0) {
       matxAlloc(&work, lwork * sizeof(ValueType), MATX_HOST_MALLOC_MEMORY);
diff --git a/include/matx/transforms/svd/svd_cuda.h b/include/matx/transforms/svd/svd_cuda.h
index 0ef4f8b5..9148b336 100644
--- a/include/matx/transforms/svd/svd_cuda.h
+++ b/include/matx/transforms/svd/svd_cuda.h
@@ -543,6 +543,7 @@ struct DnSVDCUDAParams_t {
   size_t batch_size;
   MatXDataType_t dtype;
   SVDMethod method;
+  cudaExecutor exec;
 };
 
 template <typename UTensor, typename STensor, typename VtTensor, typename ATensor>
@@ -626,6 +627,7 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
                       VtTensor &vt,
                       const ATensor &a,
                       SVDMethod method,
+                      const cudaExecutor &exec,
                       const char jobz = 'A')
   {
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
@@ -642,7 +644,7 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
     MATX_STATIC_ASSERT_STR(!is_complex_v<T3>, matxInvalidType, "S type must be real");
     MATX_STATIC_ASSERT_STR((std::is_same_v<typename inner_op_type_t<T1>::type, T3>), matxInvalidType, "A and S inner types must match");
 
-    params = GetSVDParams(u, s, vt, a, jobz);
+    params = GetSVDParams(u, s, vt, a, jobz, exec);
     params.method = method;
 
     if (params.method == SVDMethod::GESVDJ_BATCHED) {
@@ -658,7 +660,7 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
     }
 
     this->GetWorkspaceSize();
-    this->AllocateWorkspace(params.batch_size, params.method == SVDMethod::GESVDJ_BATCHED);
+    this->AllocateWorkspace(params.batch_size, params.method == SVDMethod::GESVDJ_BATCHED, exec);
   }
 
   void GetWorkspaceSize() override
@@ -722,8 +724,8 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
   static DnSVDCUDAParams_t GetSVDParams(UTensor &u,
                                         STensor &s,
-                                        VtTensor &vt, const ATensor &a,
-                                        const char jobz = 'A')
+                                        VtTensor &vt, const ATensor &a,
+                                        const char jobz, const cudaExecutor &exec)
   {
     DnSVDCUDAParams_t params;
     params.batch_size = GetNumBatches(a);
@@ -735,6 +737,7 @@ class matxDnSVDCUDAPlan_t : matxDnCUDASolver_t {
     params.S = s.Data();
     params.jobz = jobz;
     params.dtype = TypeToInt<T1>();
+    params.exec = exec;
 
     return params;
   }
@@ -882,7 +885,7 @@ struct DnSVDCUDAParamsKeyHash {
   std::size_t operator()(const DnSVDCUDAParams_t &k) const noexcept
   {
     return (std::hash<index_t>()(k.m)) + (std::hash<index_t>()(k.n)) +
-           (std::hash<size_t>()(k.batch_size));
+           (std::hash<size_t>()(k.batch_size)) + (std::hash<uint64_t>()((uint64_t)(k.exec.getStream())));
   }
 };
@@ -892,7 +895,11 @@ struct DnSVDCUDAParamsKeyEq {
   bool operator()(const DnSVDCUDAParams_t &l, const DnSVDCUDAParams_t &t) const noexcept
   {
-    return l.n == t.n && l.m == t.m && l.batch_size == t.batch_size && l.dtype == t.dtype;
+    return l.n == t.n &&
+           l.m == t.m &&
+           l.batch_size == t.batch_size &&
+           l.dtype == t.dtype &&
+           l.exec.getStream() == t.exec.getStream();
   }
 };
@@ -990,7 +997,7 @@ void svd_impl(UTensor &&u, STensor &&s,
 
     // Get parameters required by these tensors
     auto params = detail::matxDnSVDCUDAPlan_t<decltype(u_in), decltype(s_new), decltype(vt_in), decltype(at_col_maj)>::
-      GetSVDParams(u_in, s_new, vt_in, at_col_maj, job_cusolver);
+      GetSVDParams(u_in, s_new, vt_in, at_col_maj, job_cusolver, exec);
 
     // Get cache or new SVD plan if it doesn't exist
     using cache_val_type = detail::matxDnSVDCUDAPlan_t<decltype(u_in), decltype(s_new), decltype(vt_in), decltype(at_col_maj)>;
@@ -998,7 +1005,7 @@ void svd_impl(UTensor &&u, STensor &&s,
       detail::GetCacheIdFromType<cache_val_type>(),
       params,
       [&]() {
-        return std::make_shared<cache_val_type>(u_in, s_new, vt_in, at_col_maj, method, job_cusolver);
+        return std::make_shared<cache_val_type>(u_in, s_new, vt_in, at_col_maj, method, exec, job_cusolver);
       },
       [&](std::shared_ptr<cache_val_type> ctype) {
        ctype->Exec(u_in, s_new, vt_in, at_col_maj, exec, job_cusolver);
@@ -1027,7 +1034,7 @@ void svd_impl(UTensor &&u, STensor &&s,
 
     // Get parameters required by these tensors
     auto params = detail::matxDnSVDCUDAPlan_t<decltype(u_col_maj), decltype(s_new), decltype(vt_col_maj), decltype(tvt)>::
-      GetSVDParams(u_col_maj, s_new, vt_col_maj, tvt, job_cusolver);
+      GetSVDParams(u_col_maj, s_new, vt_col_maj, tvt, job_cusolver, exec);
 
     // Get cache or new SVD plan if it doesn't exist
     using cache_val_type = detail::matxDnSVDCUDAPlan_t<decltype(u_col_maj), decltype(s_new), decltype(vt_col_maj), decltype(tvt)>;
@@ -1035,7 +1042,7 @@ void svd_impl(UTensor &&u, STensor &&s,
       detail::GetCacheIdFromType<cache_val_type>(),
       params,
      [&]() {
-        return std::make_shared<cache_val_type>(u_col_maj, s_new, vt_col_maj, tvt, method, job_cusolver);
+        return std::make_shared<cache_val_type>(u_col_maj, s_new, vt_col_maj, tvt, method, exec, job_cusolver);
      },
      [&](std::shared_ptr<cache_val_type> ctype) {
        ctype->Exec(u_col_maj, s_new, vt_col_maj, tvt, exec, job_cusolver);
diff --git a/include/matx/transforms/svd/svd_lapack.h b/include/matx/transforms/svd/svd_lapack.h
index 25d5334a..0621bd1d 100644
--- a/include/matx/transforms/svd/svd_lapack.h
+++ b/include/matx/transforms/svd/svd_lapack.h
@@ -132,7 +132,7 @@ class matxDnSVDHostPlan_t : matxDnHostSolver_t {
     params = GetSVDParams(u, s, vt, a, jobz, algo);
 
     this->GetWorkspaceSize();
-    this->AllocateWorkspace(params.batch_size, false);
+    this->AllocateWorkspace(params.batch_size);
   }
 
   void GetWorkspaceSize() override
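
Taken together, these changes do two things: `cache_mtx` serializes `LookupAndExec`/`Clear`, making the plan cache safe to use from multiple host threads, and the executor is folded into every CUDA solver params key (both its hash and its equality test), so plans and their now stream-ordered (`MATX_ASYNC_DEVICE_MEMORY`) workspaces are created, zeroed, and reused per stream. A minimal sketch of the resulting behavior; the shapes and the SPD-fill step are illustrative only, not taken from this diff:

```cpp
#include "matx.h"

using namespace matx;

int main() {
  cudaStream_t s1, s2;
  cudaStreamCreate(&s1);
  cudaStreamCreate(&s2);

  // Executors now carry the stream that keys the plan cache; workspace
  // allocation and the cudaMemsetAsync() zeroing are ordered on that stream.
  cudaExecutor exec1{s1}, exec2{s2};

  auto a   = make_tensor<float>({4, 4});
  auto out = make_tensor<float>({4, 4});
  // ...fill `a` with a symmetric positive-definite matrix here...

  (out = chol(a)).run(exec1); // builds a plan keyed on (n, batch, dtype, s1)
  (out = chol(a)).run(exec2); // different stream, different key: second plan
  (out = chol(a)).run(exec1); // cache hit: reuses the s1 plan and workspace

  cudaStreamSynchronize(s1);
  cudaStreamSynchronize(s2);
  cudaStreamDestroy(s1);
  cudaStreamDestroy(s2);
  return 0;
}
```

Keying on the stream trades some memory (one workspace per stream) for correctness: two streams can no longer scribble on a shared workspace concurrently.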
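Because the lock covers the whole find-or-create sequence in `LookupAndExec`, concurrent submissions from different host threads no longer race on the map; as the TODO comment notes, it is a single global lock for now, so plan construction is serialized across transforms. A sketch of that pattern under the same assumptions (`worker` is a hypothetical helper, not part of this PR):

```cpp
#include <thread>
#include "matx.h"

using namespace matx;

// Hypothetical helper: each host thread drives its own stream. The global
// cache_mtx makes LookupAndExec()'s find-or-create step race-free.
static void worker(cudaStream_t stream) {
  cudaExecutor exec{stream};
  auto a   = make_tensor<float>({8, 8});
  auto out = make_tensor<float>({8, 8});
  // ...fill `a` with a symmetric positive-definite matrix here...
  (out = chol(a)).run(exec);     // plan creation serialized by cache_mtx
  cudaStreamSynchronize(stream); // work itself still runs per-stream
}

int main() {
  cudaStream_t s1, s2;
  cudaStreamCreate(&s1);
  cudaStreamCreate(&s2);

  std::thread t1(worker, s1);
  std::thread t2(worker, s2);
  t1.join();
  t2.join();

  cudaStreamDestroy(s1);
  cudaStreamDestroy(s2);
  return 0;
}
```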