From 89602b21037b0ad6fa5afd10b6ce66fa86a48a5e Mon Sep 17 00:00:00 2001 From: Cliff Burdick Date: Thu, 8 Jun 2023 13:56:33 -0700 Subject: [PATCH] Fixed print function to work on device in certain cases --- include/matx/core/tensor_utils.h | 40 +++++++++++++++++++++----------- 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/include/matx/core/tensor_utils.h b/include/matx/core/tensor_utils.h index 9ef42fb0..784b511b 100644 --- a/include/matx/core/tensor_utils.h +++ b/include/matx/core/tensor_utils.h @@ -40,6 +40,8 @@ #include "matx/core/make_tensor.h" #include "matx/kernels/utility.cuh" +static constexpr bool PRINT_ON_DEVICE = false; ///< print() uses printf on device + namespace matx { /** @@ -667,10 +669,23 @@ namespace detail { } } } -} // end namespace detail - -static constexpr bool PRINT_ON_DEVICE = false; ///< print() uses printf on device + template )&&...) && + (Op::Rank() == 0 || sizeof...(Args) > 0), + bool> = true> + void DevicePrint(const Op &op, Args... dims) { + if constexpr (PRINT_ON_DEVICE) { + PrintKernel<<<1, 1>>>(op, dims...); + } + else { + auto tmpv = make_tensor(op.Shape()); + (tmpv = op).run(); + PrintData(tmpv, dims...); + } + } +} // end namespace detail /** * @brief Print a tensor's values to stdout @@ -714,22 +729,21 @@ void PrintData(const Op &op, Args... dims) { data, reinterpret_cast(op.Data())); MATX_ASSERT_STR_EXP(ret, CUDA_SUCCESS, matxCudaError, "Failed to get memory type"); - MATX_ASSERT_STR(mtype == CU_MEMORYTYPE_HOST || mtype == 0, matxNotSupported, "Invalid memory type for printing"); + MATX_ASSERT_STR(mtype == CU_MEMORYTYPE_HOST || mtype == 0 || mtype == CU_MEMORYTYPE_DEVICE, + matxNotSupported, "Invalid memory type for printing"); - detail::InternalPrint(op, dims...); + if (mtype == CU_MEMORYTYPE_DEVICE) { + detail::DevicePrint(op, dims...); + } + else { + detail::InternalPrint(op, dims...); + } } else if (kind == MATX_INVALID_MEMORY || HostPrintable(kind)) { detail::InternalPrint(op, dims...); } else if (DevicePrintable(kind)) { - if constexpr (PRINT_ON_DEVICE) { - PrintKernel<<<1, 1>>>(op, dims...); - } - else { - auto tmpv = make_tensor(op.Shape()); - (tmpv = op).run(); - PrintData(tmpv, dims...); - } + detail::DevicePrint(op, dims...); } } else {