Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix Duplicate Print and remove member prints #364

Merged
merged 3 commits into from
Jan 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 0 additions & 41 deletions include/matx/core/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -1746,48 +1746,7 @@ class tensor_t : public detail::tensor_impl_t<T,RANK,Desc> {
}
}

/**
* @brief Print a tensor's values to stdout
*
* This form of `Print()` takes integral values for each index, and prints as many values
* in each dimension as the arguments specify. For example:
*
* `a.Print(2, 3, 2);`
*
* Will print 2 values of the first, 3 values of the second, and 2 values of the third dimension
* of a 3D tensor. The number of parameters must match the rank of the tensor. A special value of
* 0 can be used if the entire tensor should be printed:
*
* `a.Print(0, 0, 0);` // Prints the whole tensor
*
* For more fine-grained printing, see the other `Print()` overloads.
*
* @tparam Args Integral argument types
* @param dims Number of values to print for each dimension
*/
// Enabled only when every argument is an integral type and either the tensor
// is rank 0 or at least one argument was supplied. All the work is delegated
// to the free function matx::Print.
template <typename... Args,
          std::enable_if_t<std::conjunction_v<std::is_integral<Args>...> &&
                               (sizeof...(Args) > 0 || RANK == 0),
                           bool> = true>
void Print(Args... dims) const
{
  matx::Print(*this, dims...);
}

/**
* @brief Print all of a tensor's values to stdout
*
* This form of `Print()` is an alias of `Print(0)`, `Print(0, 0)`,
* `Print(0, 0, 0)` and `Print(0, 0, 0, 0)` for 1D, 2D, 3D and 4D tensors
* respectively. It passes the proper number of zeros to `Print(...)`
* automatically according to the rank of this tensor. The user only has to
* invoke `.Print()` to print the whole tensor, instead of passing zeros
* manually.
*/
// Zero-argument overload: enabled only for tensors of rank > 0 when no
// arguments are passed. The (empty) pack is forwarded to matx::Print.
template <typename... Args,
          std::enable_if_t<(sizeof...(Args) == 0) && (RANK > 0), bool> = true>
void Print(Args... dims) const
{
  matx::Print(*this, dims...);
}

/**
* @brief Print a tensor's values to stdout using start/end parameters
Expand Down
137 changes: 54 additions & 83 deletions include/matx/core/tensor_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -627,7 +627,58 @@ static constexpr bool PRINT_ON_DEVICE = false; ///< Print() uses printf on
/**
* @brief Print a tensor's values to stdout
*
* This form of `Print()` takes integral values for each index, and prints that as many values
* This is a wrapper utility function that prints the type, sizes and strides of a tensor;
* see `PrintData` for details of the internal tensor printing options.
*
* @tparam Args Integral argument types
* @param op input Operator
* @param dims Number of values to print for each dimension
*/
template <typename Op, typename... Args,
std::enable_if_t<((std::is_integral_v<Args>)&&...) &&
(Op::Rank() == 0 || sizeof...(Args) > 0),
bool> = true>
void Print(const Op &op, Args... dims)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

// print tensor size info first
std::string type = (is_tensor_view_v<Op>) ? "Tensor" : "Operator";

printf("%s{%s} Rank: %d, Sizes:[", type.c_str(), detail::GetTensorType<typename Op::scalar_type>().c_str(), op.Rank());

for (index_t dimIdx = 0; dimIdx < (op.Rank() ); dimIdx++ )
{
printf("%lld", op.Size(static_cast<int>(dimIdx)) );
if( dimIdx < (op.Rank() - 1) )
printf(", ");
}

if constexpr (is_tensor_view_v<Op>)
{
printf("], Strides:[");
if constexpr (Op::Rank() > 0)
{
for (index_t dimIdx = 0; dimIdx < (op.Rank() ); dimIdx++ )
{
printf("%lld", op.Stride(static_cast<int>(dimIdx)) );
if( dimIdx < (op.Rank() - 1) )
{
printf(",");
}
}
}
}

printf("]\n");
PrintData(op, dims...);

}

/**
* @brief Print a tensor's values to stdout
*
* This is the internal `Print()` that takes integral values for each index, and prints as many values
* in each dimension as the arguments specify. For example:
*
* `a.Print(2, 3, 2);`
Expand All @@ -648,33 +699,12 @@ template <typename Op, typename... Args,
std::enable_if_t<((std::is_integral_v<Args>)&&...) &&
(Op::Rank() == 0 || sizeof...(Args) > 0),
bool> = true>
void Print(const Op &op, Args... dims) {
void PrintData(const Op &op, Args... dims) {
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

// print tensor size info first
std::string type = (is_tensor_view_v<Op>) ? "Tensor" : "Operator";

printf("%s{%s} Rank: %d, Sizes:[", type.c_str(), detail::GetTensorType<typename Op::scalar_type>().c_str(), op.Rank());

for (index_t dimIdx = 0; dimIdx < (op.Rank() ); dimIdx++ ){
printf("%lld", op.Size(static_cast<int>(dimIdx)) );
if( dimIdx < (op.Rank() - 1) )
printf(", ");
}

#ifdef __CUDACC__
if constexpr (is_tensor_view_v<Op>) {

printf("], Strides:[");
if constexpr (Op::Rank() > 0) {
for (index_t dimIdx = 0; dimIdx < (op.Rank() ); dimIdx++ ) {
printf("%lld", op.Stride(static_cast<int>(dimIdx)) );
if( dimIdx < (op.Rank() - 1) )
printf(",");
}
}
printf("]\n");

auto kind = GetPointerKind(op.Data());
cudaDeviceSynchronize();
if (HostPrintable(kind)) {
Expand All @@ -687,16 +717,14 @@ void Print(const Op &op, Args... dims) {
else {
auto tmpv = make_tensor<typename Op::scalar_type>(op.Shape());
(tmpv = op).run();
tmpv.Print(dims...);
PrintData(tmpv, dims...);
}
}
}
else {
printf("]\n");
InternalPrint(op, dims...);
}
#else
printf("]\n");
InternalPrint(op, dims...);
#endif
}
Expand All @@ -719,61 +747,4 @@ void Print(const Op &op, Args... dims) {
std::apply([&](auto &&...args) { Print(op, args...); }, tp);
}

/**
* @brief Print a tensor's values to stdout using start/end parameters
*
* This form of `Print()` takes two array-like lists for the start and end indices, respectively. For
* example:
*
* `a.Print({2, 3}, {matxEnd, 5});`
*
* Will print the 2D tensor `a` with the first dimension starting at index 2 and going to the end, and
* the second index starting at 3 and ending at 5 (exclusive). The format is identical to calling
* `Slice()` to get a sliced view, followed by `Print()` with the indices.
*
* @tparam NRANK Automatically-deduced rank of tensor
* @param start Start indices to print from
* @param end End indices to stop
*/
// template <typename Op, int NRANK>
// void Print(const Op &op, const index_t (&start)[NRANK], const index_t (&end)[NRANK]) const
// {
// auto s = this->Slice(start, end);
// std::array<index_t, NRANK> arr = {0};
// auto tup = std::tuple_cat(arr);
// std::apply(
// [&](auto&&... args) {
// s.InternalPrint(args...);
// }, tup);
// }

/**
* @brief Print a tensor's values to stdout using start/end/stride
*
* This form of `Print()` takes three array-like lists for the start, end, and stride indices, respectively. For
* example:
*
* `a.Print({2, 3}, {matxEnd, 5}, {1, 2});`
*
* Will print the 2D tensor `a` with the first dimension starting at index 2 and going to the end with a
* stride of 1, and the second index starting at 3 and ending at 5 (exclusive) with a stride of 2. The format is
* identical to calling `Slice()` to get a sliced view, followed by `Print()` with the indices.
*
* @tparam NRANK Automatically-deduced rank of tensor
* @param start Start indices to print from
* @param end End indices to stop
* @param strides Strides of each dimension
*/
// template <int NRANK>
// void Print(const index_t (&start)[NRANK], const index_t (&end)[NRANK], const index_t (&strides)[NRANK]) const
// {
// auto s = this->Slice(start, end, strides);
// std::array<index_t, NRANK> arr = {0};
// auto tup = std::tuple_cat(arr);
// std::apply(
// [&](auto&&... args) {
// s.InternalPrint(args...);
// }, tup);
// }

}