Skip to content

Commit

Permalink
Added Print for operators (#211)
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffburdick authored Jul 12, 2022
1 parent cb5f69c commit dd0c145
Show file tree
Hide file tree
Showing 4 changed files with 283 additions and 26 deletions.
31 changes: 5 additions & 26 deletions include/matx_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,6 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
using stride_container = typename Desc::stride_container;
using desc_type = Desc; ///< Descriptor type trait
using self_type = tensor_t<T, RANK, Storage, Desc>;
static constexpr bool PRINT_ON_DEVICE = false; ///< Print() uses printf on device

/**
* @brief Construct a new 0-D tensor t object
Expand Down Expand Up @@ -1654,26 +1653,8 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
std::enable_if_t<((std::is_integral_v<Args>)&&...) &&
(RANK == 0 || sizeof...(Args) > 0),
bool> = true>
void Print(Args... dims) const {
#ifdef __CUDACC__
auto kind = GetPointerKind(this->ldata_);
cudaDeviceSynchronize();
if (HostPrintable(kind)) {
InternalPrint(dims...);
}
else if (DevicePrintable(kind) || kind == MATX_INVALID_MEMORY) {
if constexpr (PRINT_ON_DEVICE) {
PrintKernel<<<1, 1>>>(*this, dims...);
}
else {
auto tmpv = make_tensor<T>(this->Shape());
(tmpv = *this).run();
tmpv.Print(dims...);
}
}
#else
InternalPrint(dims...);
#endif
[[deprecated("Use non-member function Print() instead")]] void Print(Args... dims) const {
matx::Print(*this, dims...);
}

/**
Expand All @@ -1689,9 +1670,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
template <typename... Args,
std::enable_if_t<(RANK > 0 && sizeof...(Args) == 0), bool> = true>
void Print(Args... dims) const {
std::array<int, RANK> arr = {0};
auto tp = std::tuple_cat(arr);
std::apply([&](auto &&...args) { this->Print(args...); }, tp);
matx::Print(*this, dims...);
}

/**
Expand All @@ -1718,7 +1697,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
auto tup = std::tuple_cat(arr);
std::apply(
[&](auto&&... args) {
s.InternalPrint(args...);
detail::InternalPrint(s, args...);
}, tup);
}

Expand Down Expand Up @@ -1747,7 +1726,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
auto tup = std::tuple_cat(arr);
std::apply(
[&](auto&&... args) {
s.InternalPrint(args...);
detail::InternalPrint(s, args...);
}, tup);
}

Expand Down
258 changes: 258 additions & 0 deletions include/matx_tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@

#include <cuda/std/tuple>
#include <functional>
#include "matx_make.h"

namespace matx
{
Expand Down Expand Up @@ -359,5 +360,262 @@ namespace detail {
}
}


/**
* Print a value
*
* Type-agnostic function to print a value to stdout
*
* @param val
*/
template <typename T>
__MATX_INLINE__ __MATX_HOST__ void PrintVal(const T &val)
{
if constexpr (is_complex_v<T>) {
printf("%.4e%+.4ej ", static_cast<float>(val.real()),
static_cast<float>(val.imag()));
}
else if constexpr (is_matx_half_v<T> || is_half_v<T>) {
printf("%.4e ", static_cast<float>(val));
}
else if constexpr (std::is_floating_point_v<T>) {
printf("%.4e ", val);
}
else if constexpr (std::is_same_v<T, long long int>) {
printf("%lld ", val);
}
else if constexpr (std::is_same_v<T, int64_t>) {
printf("%" PRId64 " ", val);
}
else if constexpr (std::is_same_v<T, int32_t>) {
printf("%" PRId32 " ", val);
}
else if constexpr (std::is_same_v<T, int16_t>) {
printf("%" PRId16 " ", val);
}
else if constexpr (std::is_same_v<T, int8_t>) {
printf("%" PRId8 " ", val);
}
else if constexpr (std::is_same_v<T, uint64_t>) {
printf("%" PRIu64 " ", val);
}
else if constexpr (std::is_same_v<T, uint32_t>) {
printf("%" PRIu32 " ", val);
}
else if constexpr (std::is_same_v<T, uint16_t>) {
printf("%" PRIu16 " ", val);
}
else if constexpr (std::is_same_v<T, uint8_t>) {
printf("%" PRIu8 " ", val);
}
else if constexpr (std::is_same_v<T, bool>) {
printf("%d ", val);
}
}

/**
* Print a tensor
*
* Type-agnostic function to print a tensor to stdout
*
*/
template <typename Op, typename ... Args>
__MATX_HOST__ void InternalPrint(const Op &op, Args ...dims)
{
MATX_STATIC_ASSERT(op.Rank() == sizeof...(Args), "Number of dimensions to print must match tensor rank");
MATX_STATIC_ASSERT(op.Rank() <= 4, "Printing is only supported on tensors of rank 4 or lower currently");
if constexpr (sizeof...(Args) == 0) {
PrintVal(op.operator()());
printf("\n");
}
else if constexpr (sizeof...(Args) == 1) {
auto& k =detail:: pp_get<0>(dims...);
for (index_t _k = 0; _k < ((k == 0) ? op.Size(0) : k); _k++) {
printf("%06lld: ", _k);
PrintVal(op.operator()(_k));
printf("\n");
}
}
else if constexpr (sizeof...(Args) == 2) {
auto& k = detail::pp_get<0>(dims...);
auto& l = detail::pp_get<1>(dims...);
for (index_t _k = 0; _k < ((k == 0) ? op.Size(0) : k); _k++) {
for (index_t _l = 0; _l < ((l == 0) ? op.Size(1) : l); _l++) {
if (_l == 0)
printf("%06lld: ", _k);

PrintVal(op.operator()(_k, _l));
}
printf("\n");
}
}
else if constexpr (sizeof...(Args) == 3) {
auto& j = detail::pp_get<0>(dims...);
auto& k = detail::pp_get<1>(dims...);
auto& l = detail::pp_get<2>(dims...);
for (index_t _j = 0; _j < ((j == 0) ? op.Size(0) : j); _j++) {
printf("[%06lld,:,:]\n", _j);
for (index_t _k = 0; _k < ((k == 0) ? op.Size(1) : k); _k++) {
for (index_t _l = 0; _l < ((l == 0) ? op.Size(2) : l); _l++) {
if (_l == 0)
printf("%06lld: ", _k);

PrintVal(op.operator()(_j, _k, _l));
}
printf("\n");
}
printf("\n");
}
}
else if constexpr (sizeof...(Args) == 4) {
auto& i = detail::pp_get<0>(dims...);
auto& j = detail::pp_get<1>(dims...);
auto& k = detail::pp_get<2>(dims...);
auto& l = detail::pp_get<3>(dims...);
for (index_t _i = 0; _i < ((i == 0) ? op.Size(0) : i); _i++) {
for (index_t _j = 0; _j < ((j == 0) ? op.Size(1) : j); _j++) {
printf("[%06lld,%06lld,:,:]\n", _i, _j);
for (index_t _k = 0; _k < ((k == 0) ? op.Size(2) : k); _k++) {
for (index_t _l = 0; _l < ((l == 0) ? op.Size(3) : l); _l++) {
if (_l == 0)
printf("%06lld: ", _k);

PrintVal(op.operator()(_i, _j, _k, _l));
}
printf("\n");
}
printf("\n");
}
}
}
}
}

static constexpr bool PRINT_ON_DEVICE = false; ///< Print() uses printf on device

/**
* @brief Print a tensor's values to stdout
*
* This form of `Print()` takes integral values for each index, and prints that as many values
* in each dimension as the arguments specify. For example:
*
* `a.Print(2, 3, 2);`
*
* Will print 2 values of the first, 3 values of the second, and 2 values of the third dimension
* of a 3D tensor. The number of parameters must match the rank of the tensor. A special value of
* 0 can be used if the entire tensor should be printed:
*
* `a.Print(0, 0, 0);` // Prints the whole tensor
*
* For more fine-grained printing, see the over `Print()` overloads.
*
* @tparam Args Integral argument types
* @param dims Number of values to print for each dimension
*/
template <typename Op, typename... Args,
std::enable_if_t<((std::is_integral_v<Args>)&&...) &&
(Op::Rank() == 0 || sizeof...(Args) > 0),
bool> = true>
void Print(const Op &op, Args... dims) {
#ifdef __CUDACC__
if constexpr (is_tensor_view_v<Op>) {
auto kind = GetPointerKind(op.Data());
cudaDeviceSynchronize();
if (HostPrintable(kind)) {
detail::InternalPrint(op, dims...);
}
else if (DevicePrintable(kind) || kind == MATX_INVALID_MEMORY) {
if constexpr (PRINT_ON_DEVICE) {
PrintKernel<<<1, 1>>>(op, dims...);
}
else {
auto tmpv = make_tensor<typename Op::scalar_type>(op.Shape());
(tmpv = op).run();
tmpv.Print(dims...);
}
}
}
else {
InternalPrint(op, dims...);
}
#else
InternalPrint(op, dims...);
#endif
}

/**
* @brief Print a tensor's all values to stdout
*
* This form of `Print()` is an alias of `Print(0)`, `Print(0, 0)`,
* `Print(0, 0, 0)` and `Print(0, 0, 0, 0)` for 1D, 2D, 3D and 4D tensor
* respectively. It passes the proper number of zeros to `Print(...)`
* automatically according to the rank of this tensor. The user only have to
* invoke `.Print()` to print the whole tensor, instead of passing zeros
* manually.
*/
template <typename Op, typename... Args,
std::enable_if_t<(Op::Rank() > 0 && sizeof...(Args) == 0), bool> = true>
void Print(const Op &op, Args... dims) {
std::array<int, Op::Rank()> arr = {0};
auto tp = std::tuple_cat(arr);
std::apply([&](auto &&...args) { Print(op, args...); }, tp);
}

/**
* @brief Print a tensor's values to stdout using start/end parameters
*
* This form of `Print()` takes two array-like lists for the start and end indices, respectively. For
* example:
*
* `a.Print({2, 3}, {matxEnd, 5});`
*
* Will print the 2D tensor `a` with the first dimension starting at index 2 and going to the end, and
* the second index starting at 3 and ending at 5 (exlusive). The format is identical to calling
* `Slice()` to get a sliced view, followed by `Print()` with the indices.
*
* @tparam NRANK Automatically-deduced rank of tensor
* @param start Start indices to print from
* @param end End indices to stop
*/
// template <typename Op, int NRANK>
// void Print(const Op &op, const index_t (&start)[NRANK], const index_t (&end)[NRANK]) const
// {
// auto s = this->Slice(start, end);
// std::array<index_t, NRANK> arr = {0};
// auto tup = std::tuple_cat(arr);
// std::apply(
// [&](auto&&... args) {
// s.InternalPrint(args...);
// }, tup);
// }

/**
* @brief Print a tensor's values to stdout using start/end/stride
*
* This form of `Print()` takes three array-like lists for the start, end, and stride indices, respectively. For
* example:
*
* `a.Print({2, 3}, {matxEnd, 5}, {1, 2});`
*
* Will print the 2D tensor `a` with the first dimension starting at index 2 and going to the end with a
* stride of 1, and the second index starting at 3 and ending at 5 (exlusive) with a stride of 2. The format is
* identical to calling `Slice()` to get a sliced view, followed by `Print()` with the indices.
*
* @tparam NRANK Automatically-deduced rank of tensor
* @param start Start indices to print from
* @param end End indices to stop
* @param strides Strides of each dimension
*/
// template <int NRANK>
// void Print(const index_t (&start)[NRANK], const index_t (&end)[NRANK], const index_t (&strides)[NRANK]) const
// {
// auto s = this->Slice(start, end, strides);
// std::array<index_t, NRANK> arr = {0};
// auto tup = std::tuple_cat(arr);
// std::apply(
// [&](auto&&... args) {
// s.InternalPrint(args...);
// }, tup);
// }

}
9 changes: 9 additions & 0 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2062,3 +2062,12 @@ TEST(OperatorTests, Cast)
MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsFloat, Print)
{
MATX_ENTER_HANDLER();
auto t = make_tensor<TypeParam>({3});
auto r = ones(t.Shape());

Print(r);
MATX_EXIT_HANDLER();
}
11 changes: 11 additions & 0 deletions test/00_tensor/BasicTensorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -406,3 +406,14 @@ TYPED_TEST(BasicTensorTestsIntegral, InitAssign)
MATX_EXIT_HANDLER();
}


TYPED_TEST(BasicTensorTestsAll, Print)
{
MATX_ENTER_HANDLER();

auto t = make_tensor<TypeParam>({3});
(t = ones(t.Shape())).run();
t.Print();

MATX_EXIT_HANDLER();
}

0 comments on commit dd0c145

Please sign in to comment.