
Reenable CUB tests (#188)
cliffburdick authored Jun 2, 2022
1 parent 4eec811 commit d194e7e
Showing 1 changed file with 14 additions and 16 deletions.
30 changes: 14 additions & 16 deletions include/matx_reduce.h
@@ -1064,15 +1064,14 @@ template <typename TensorType, typename InType, typename ReduceOp>
 void inline reduce(TensorType &dest, const InType &in, ReduceOp op,
                    cudaStream_t stream = 0, [[maybe_unused]] bool init = true)
 {
-  // Disable CUB until bug using 1D outputs is resolved
-  // constexpr bool use_cub = TensorType::Rank() == 0 || (TensorType::Rank() == 1 && InType::Rank() == 2);
-  // // Use CUB implementation if we have a tensor on the RHS and it's not blocked from using CUB
-  // if constexpr (!is_matx_no_cub_reduction_v<ReduceOp> && use_cub) {
-  //   cub_reduce<TensorType, InType, ReduceOp>(dest, in, op.Init(), stream);
-  // }
-  // else { // Fall back to the slow path of custom implementation
+  constexpr bool use_cub = TensorType::Rank() == 0 || (TensorType::Rank() == 1 && InType::Rank() == 2);
+  // Use CUB implementation if we have a tensor on the RHS and it's not blocked from using CUB
+  if constexpr (!is_matx_no_cub_reduction_v<ReduceOp> && use_cub) {
+    cub_reduce<TensorType, InType, ReduceOp>(dest, in, op.Init(), stream);
+  }
+  else { // Fall back to the slow path of custom implementation
     reduce(dest, std::nullopt, in, op, stream, init);
-  //}
+  }
 }
 
 /**
@@ -1214,15 +1213,14 @@ template <typename TensorType, typename InType>
 void inline sum(TensorType &dest, const InType &in, cudaStream_t stream = 0)
 {
 #ifdef __CUDACC__
-  // Disable CUB until bug using 1D outputs is resolved
-  // constexpr bool use_cub = TensorType::Rank() == 0 || (TensorType::Rank() == 1 && InType::Rank() == 2);
-  // // Use CUB implementation if we have a tensor on the RHS
-  // if constexpr (use_cub) {
-  //   cub_sum<TensorType, InType>(dest, in, stream);
-  // }
-  // else { // Fall back to the slow path of custom implementation
+  constexpr bool use_cub = TensorType::Rank() == 0 || (TensorType::Rank() == 1 && InType::Rank() == 2);
+  // Use CUB implementation if we have a tensor on the RHS
+  if constexpr (use_cub) {
+    cub_sum<TensorType, InType>(dest, in, stream);
+  }
+  else { // Fall back to the slow path of custom implementation
     reduce(dest, in, detail::reduceOpSum<typename TensorType::scalar_type>(), stream, true);
-  //}
+  }
 #endif
 }
 
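Note on the re-enabled fast path: CUB is used only when the output tensor is rank 0 (a full reduction) or rank 1 with a rank-2 input (a per-row reduction); every other rank combination still falls back to the custom reduce() kernels. Below is a minimal sketch of the two eligible cases using the sum() signature from this diff; the tensor_t constructors are assumed from this era of the MatX API and may differ in later releases.

#include "matx.h"

int main() {
  cudaStream_t stream = 0;

  matx::tensor_t<float, 2> in({16, 1024}); // rank-2 input (data left uninitialized; sketch only shows dispatch)
  matx::tensor_t<float, 1> rows({16});     // rank-1 output: one sum per row
  matx::tensor_t<float, 0> total;          // rank-0 output: full reduction

  // Rank-1 output with a rank-2 input satisfies use_cub, so this call now
  // dispatches to cub_sum() instead of the custom reduction kernel.
  matx::sum(rows, in, stream);

  // A rank-0 output is always eligible for the CUB path.
  matx::sum(total, in, stream);

  // Any other rank combination still takes the reduce() fallback with reduceOpSum.

  cudaStreamSynchronize(stream);
  return 0;
}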
