16-bit float reductions + updated softmax
cliffburdick committed Apr 1, 2023
1 parent 13ee40f commit 3809f72
Showing 10 changed files with 428 additions and 148 deletions.
17 changes: 7 additions & 10 deletions bench/00_operators/reduction.cu
@@ -10,20 +10,17 @@ template <typename ValueType>
void softmax(nvbench::state &state, nvbench::type_list<ValueType>)
{
// Get current parameters:
auto t4 = make_tensor<ValueType>({1,10845,8,16});
auto t4out = make_tensor<ValueType>({1,10845,8,16});
t4.PrefetchDevice(0);
t4out.PrefetchDevice(0);


auto t2 = make_tensor<ValueType>({86760, 16});
auto t2out = make_tensor<ValueType>({86760, 16});
t2.PrefetchDevice(0);
t2out.PrefetchDevice(0);

softmax(t2out, t2, {1});
softmax(t4out, t4, {3});

state.exec(
[&t2, &t2out](nvbench::launch &launch) {
matx::softmax(t2out, t2, (cudaStream_t)launch.get_stream());
[&t4, &t4out](nvbench::launch &launch) {
matx::softmax(t4out, t4, (cudaStream_t)launch.get_stream());
});

}
NVBENCH_BENCH_TYPES(softmax, NVBENCH_TYPE_AXES(softmax_types));

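For reference outside the nvbench harness, a minimal standalone sketch of the call pattern the updated benchmark exercises, assuming the single matx.h convenience header; the shape and element type here are illustrative, not the benchmark's.

#include "matx.h"

int main() {
  using namespace matx;

  // Illustrative 4-D input/output pair; the benchmark itself uses {1,10845,8,16}.
  auto in  = make_tensor<matxFp16>({1, 32, 8, 16});
  auto out = make_tensor<matxFp16>({1, 32, 8, 16});
  in.PrefetchDevice(0);
  out.PrefetchDevice(0);

  // Softmax over the innermost axis, as in the updated benchmark.
  softmax(out, in, {3});
  cudaDeviceSynchronize();
  return 0;
}
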
16 changes: 8 additions & 8 deletions include/matx/core/iterator.h
@@ -46,8 +46,8 @@ namespace matx {
template <typename OperatorType>
struct RandomOperatorIterator {
using self_type = RandomOperatorIterator<OperatorType>;
using value_type = typename OperatorType::scalar_type;
using scalar_type = typename OperatorType::scalar_type;
using value_type = detail::convert_matx_type_t<typename OperatorType::scalar_type>;
using scalar_type = value_type;
// using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
// index_t>;
using stride_type = index_t;
@@ -67,12 +67,12 @@ struct RandomOperatorIterator {
[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ value_type operator*() const
{
if constexpr (OperatorType::Rank() == 0) {
return t_.operator()();
return static_cast<value_type>(t_.operator()());
}
else {
auto arrs = detail::GetIdxFromAbs(t_, offset_);
return detail::mapply([&](auto &&...args) {
return t_.operator()(args...);
return static_cast<value_type>(t_.operator()(args...));
}, arrs);
}
}
@@ -145,8 +145,8 @@
template <typename OperatorType>
struct RandomOperatorOutputIterator {
using self_type = RandomOperatorOutputIterator<OperatorType>;
using value_type = typename OperatorType::scalar_type;
using scalar_type = typename OperatorType::scalar_type;
using value_type = detail::convert_matx_type_t<typename OperatorType::scalar_type>;
using scalar_type = value_type;
// using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
// index_t>;
using stride_type = index_t;
Expand All @@ -161,13 +161,13 @@ struct RandomOperatorOutputIterator {
[[nodiscard]] __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ reference operator*()
{
if constexpr (OperatorType::Rank() == 0) {
return t_.operator()();
return (reference)(t_.operator()());
}
else {
auto arrs = detail::GetIdxFromAbs(t_, offset_);

return std::apply([&](auto &&...args) -> reference {
return t_.operator()(args...);
return (reference)(t_.operator()(args...));
}, arrs);
}
}
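
The point of switching value_type to detail::convert_matx_type_t<...> is that CUB sees a native CUDA 16-bit type rather than the MatX wrapper when it consumes these iterators. For comparison, here is a self-contained CUB reduction over native __half using the usual two-phase temp-storage idiom; this is a generic CUB sketch, not MatX code.

#include <cub/cub.cuh>
#include <cuda_fp16.h>

// Sum n half-precision values on the device. Requires an architecture with
// native __half arithmetic (sm_53+). Error handling elided for brevity.
cudaError_t sum_halves(const __half *d_in, __half *d_out, int n, cudaStream_t stream) {
  void  *d_temp     = nullptr;
  size_t temp_bytes = 0;

  // First call only sizes the temporary storage; second call runs the reduction.
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream);
  cudaMallocAsync(&d_temp, temp_bytes, stream);
  cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream);
  return cudaFreeAsync(d_temp, stream);
}
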
28 changes: 14 additions & 14 deletions include/matx/core/tensor_utils.h
@@ -161,14 +161,14 @@ namespace detail {
template <typename Func, typename Tuple, size_t... S>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ decltype(auto) apply_impl(Func &&f, Tuple&& tuple, std::index_sequence<S...>) {

if constexpr (is_std_tuple<std::remove_reference_t<Tuple>>::value || is_std_array<std::remove_reference_t<Tuple>>::value) {
if constexpr (is_std_tuple<remove_cvref_t<Tuple>>::value || is_std_array<remove_cvref_t<Tuple>>::value) {
return cuda::std::invoke(std::forward<Func>(f), std::get<S>(std::forward<Tuple>(tuple))...);
}
else {
return cuda::std::invoke(std::forward<Func>(f), cuda::std::get<S>(std::forward<Tuple>(tuple))...);
}

if constexpr (!(is_std_tuple<std::remove_reference_t<Tuple>>::value || is_std_array<std::remove_reference_t<Tuple>>::value)) {
if constexpr (!(is_std_tuple<remove_cvref_t<Tuple>>::value || is_std_array<remove_cvref_t<Tuple>>::value)) {
return cuda::std::invoke(std::forward<Func>(f), cuda::std::get<S>(std::forward<Tuple>(tuple))...);
}
else {
@@ -179,52 +179,52 @@ namespace detail {
template <class Func, class Tuple>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ constexpr decltype(auto) mapply(Func&& f, Tuple&& t)
{
if constexpr (is_std_tuple<std::remove_reference_t<Tuple>>::value || is_std_array<std::remove_reference_t<Tuple>>::value) {
if constexpr (is_std_tuple<remove_cvref_t<Tuple>>::value || is_std_array<remove_cvref_t<Tuple>>::value) {
return apply_impl(
std::forward<Func>(f), std::forward<Tuple>(t),
std::make_index_sequence<std::tuple_size_v<std::remove_reference_t<Tuple>>>{});
std::make_index_sequence<std::tuple_size_v<remove_cvref_t<Tuple>>>{});
}
else {
return apply_impl(
std::forward<Func>(f), std::forward<Tuple>(t),
std::make_index_sequence<cuda::std::tuple_size_v<std::remove_reference_t<Tuple>>>{});
std::make_index_sequence<cuda::std::tuple_size_v<remove_cvref_t<Tuple>>>{});
}

if constexpr (!(is_std_tuple<std::remove_reference_t<Tuple>>::value || is_std_array<std::remove_reference_t<Tuple>>::value)) {
if constexpr (!(is_std_tuple<remove_cvref_t<Tuple>>::value || is_std_array<remove_cvref_t<Tuple>>::value)) {
return apply_impl(
std::forward<Func>(f), std::forward<Tuple>(t),
std::make_index_sequence<cuda::std::tuple_size_v<std::remove_reference_t<Tuple>>>{});
std::make_index_sequence<cuda::std::tuple_size_v<remove_cvref_t<Tuple>>>{});
}
else {
return apply_impl(
std::forward<Func>(f), std::forward<Tuple>(t),
std::make_index_sequence<std::tuple_size_v<std::remove_reference_t<Tuple>>>{});
std::make_index_sequence<std::tuple_size_v<remove_cvref_t<Tuple>>>{});
}
}

template <class Func, class Tuple>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ constexpr decltype(auto) mapply_reverse(Func&& f, Tuple&& t)
{
if constexpr (is_std_tuple<std::remove_reference_t<Tuple>>::value || is_std_array<std::remove_reference_t<Tuple>>::value) {
if constexpr (is_std_tuple<remove_cvref_t<Tuple>>::value || is_std_array<remove_cvref_t<Tuple>>::value) {
return apply_impl(
std::forward<Func>(f), std::forward<Tuple>(t),
make_index_sequence_rev<std::tuple_size_v<std::remove_reference_t<Tuple>>>{});
make_index_sequence_rev<std::tuple_size_v<remove_cvref_t<Tuple>>>{});
}
else {
return apply_impl(
std::forward<Func>(f), std::forward<Tuple>(t),
make_index_sequence_rev<cuda::std::tuple_size_v<std::remove_reference_t<Tuple>>>{});
make_index_sequence_rev<cuda::std::tuple_size_v<remove_cvref_t<Tuple>>>{});
}

if constexpr (!(is_std_tuple<std::remove_reference_t<Tuple>>::value || is_std_array<std::remove_reference_t<Tuple>>::value)) {
if constexpr (!(is_std_tuple<remove_cvref_t<Tuple>>::value || is_std_array<remove_cvref_t<Tuple>>::value)) {
return apply_impl(
std::forward<Func>(f), std::forward<Tuple>(t),
make_index_sequence_rev<cuda::std::tuple_size_v<std::remove_reference_t<Tuple>>>{});
make_index_sequence_rev<cuda::std::tuple_size_v<remove_cvref_t<Tuple>>>{});
}
else {
return apply_impl(
std::forward<Func>(f), std::forward<Tuple>(t),
make_index_sequence_rev<std::tuple_size_v<std::remove_reference_t<Tuple>>>{});
make_index_sequence_rev<std::tuple_size_v<remove_cvref_t<Tuple>>>{});
}
}

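
The switch from std::remove_reference_t to remove_cvref_t matters because these helpers often receive the index tuple as a const reference, and remove_reference alone leaves the const qualifier attached, which would send the is_std_tuple/is_std_array dispatch down the wrong branch. A small illustration in plain C++, independent of MatX:

#include <array>
#include <cstddef>
#include <type_traits>

template <typename T> struct is_std_array : std::false_type {};
template <typename T, std::size_t N> struct is_std_array<std::array<T, N>> : std::true_type {};

using Arg = const std::array<int, 3> &;  // how an index array is typically passed in

// remove_reference keeps the const, so the specialization does not match...
static_assert(!is_std_array<std::remove_reference_t<Arg>>::value);
// ...while stripping cv and reference (what remove_cvref_t does) matches as intended.
static_assert(is_std_array<std::remove_cv_t<std::remove_reference_t<Arg>>>::value);
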
28 changes: 26 additions & 2 deletions include/matx/core/type_utils.h
@@ -72,6 +72,9 @@ struct remove_cvref {
using type = std::remove_cv_t<std::remove_reference_t<T>>; ///< Type after removal
};

template <typename T>
using remove_cvref_t = typename remove_cvref<T>::type;

template <typename T, int RANK, typename Desc> class tensor_impl_t;
template <typename T, int RANK, typename Storage, typename Desc> class tensor_t;

@@ -492,6 +495,24 @@ using promote_half_t = typename std::conditional_t<is_half_v<T>, float, T>;


namespace detail {

template <typename T>
struct convert_matx_type {
using type = T;
};

template <>
struct convert_matx_type<matxFp16> {
using type = __half;
};

template <>
struct convert_matx_type<matxBf16> {
using type = __nv_bfloat16;
};

template <typename T>
using convert_matx_type_t = typename convert_matx_type<T>::type;

template <class T, std::size_t N, std::size_t... I>
constexpr std::array<std::remove_cv_t<T>, N>
@@ -554,8 +575,11 @@ template <typename T> using value_promote_t = promote_half_t<value_type_t<T>>;

template <typename> struct is_std_tuple: std::false_type {};
template <typename ...T> struct is_std_tuple<std::tuple<T...>>: std::true_type {};
template <typename T> struct is_std_array : std::false_type {};
template <typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type {};

template<typename T> struct is_std_array : std::false_type {};
template<typename T, size_t N> struct is_std_array<std::array<T, N>> : std::true_type {};
template <typename T> inline constexpr bool is_std_array_v = detail::is_std_array<remove_cvref_t<T>>::value;



// Get the n-th element from a parameter pack
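
A compile-time sketch of what the new convert_matx_type_t trait resolves to, using the mappings defined in the diff above; this assumes matx.h (or the CUDA half-type headers) is included.

#include <type_traits>

static_assert(std::is_same_v<matx::detail::convert_matx_type_t<matx::matxFp16>, __half>);
static_assert(std::is_same_v<matx::detail::convert_matx_type_t<matx::matxBf16>, __nv_bfloat16>);
// Non-half types pass through unchanged.
static_assert(std::is_same_v<matx::detail::convert_matx_type_t<float>, float>);
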
45 changes: 27 additions & 18 deletions include/matx/operators/binary_operators.h
@@ -102,6 +102,7 @@ namespace matx
// dummy type to signal this is a matxop
using matxop = bool;
using scalar_type = typename Op::scalar_type;
using self_type = matxBinaryOp<I1, I2, Op>;

__MATX_INLINE__ const std::string str() const {
return op_.str(get_type_str(in1_), get_type_str(in2_));
@@ -116,26 +117,34 @@ namespace matx
}
}

template <typename... Is>
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto operator()(Is... indices) const
{
// Rank 0
auto i1 = get_value(in1_, indices...);
auto i2 = get_value(in2_, indices...);
return op_(i1, i2);
}
template <typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...>, bool> = true>
__MATX_DEVICE__ __MATX_HOST__ __MATX_INLINE__ auto operator()(Is... indices) const
{
auto i1 = get_value(in1_, indices...);
auto i2 = get_value(in2_, indices...);
return op_(i1, i2);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
return detail::matx_max(detail::get_rank<I1>(), detail::get_rank<I2>());
}
template <typename ArrayType, std::enable_if_t<is_std_array_v<ArrayType>, bool> = true>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ const auto operator()(const ArrayType &idx) const noexcept
{
return mapply([&](auto &&...args) {
return this->operator()(args...);
}, idx);
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto Size(int dim) const noexcept
{
index_t size1 = detail::get_expanded_size<Rank()>(in1_, dim);
index_t size2 = detail::get_expanded_size<Rank()>(in2_, dim);
return detail::matx_max(size1,size2);
}

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
return detail::matx_max(detail::get_rank<I1>(), detail::get_rank<I2>());
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto Size(int dim) const noexcept
{
index_t size1 = detail::get_expanded_size<Rank()>(in1_, dim);
index_t size2 = detail::get_expanded_size<Rank()>(in2_, dim);
return detail::matx_max(size1,size2);
}
};
}

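
The added overload lets element access on a binary expression be driven by a runtime-built std::array of indices instead of a compile-time parameter pack, while the enable_if on the variadic overload keeps the two from colliding. A minimal usage sketch, assuming managed-memory tensors so host-side element access is valid:

auto a = matx::make_tensor<float>({4, 8});
auto b = matx::make_tensor<float>({4, 8});
auto expr = a + b;                         // a matxBinaryOp expression

std::array<matx::index_t, 2> idx{2, 5};    // indices assembled at run time
auto v1 = expr(idx);                       // new overload: array of indices
auto v2 = expr(2, 5);                      // existing overload: integral parameter pack
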
2 changes: 1 addition & 1 deletion include/matx/operators/index.h
@@ -58,7 +58,7 @@ namespace matx
template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
{
std::array<index_t, int(sizeof...(Is))> inds{indices...};
std::array<index_t, sizeof...(Is)> inds{indices...};
return inds[dim_];
}

10 changes: 9 additions & 1 deletion include/matx/operators/unary_operators.h
@@ -65,6 +65,7 @@ namespace matx
// dummy type to signal this is a matxop
using matxop = bool;
using scalar_type = typename Op::scalar_type;
using self_type = matxUnaryOp<I1, Op>;

__MATX_INLINE__ const std::string str() const {
return op_.str() + "(" + get_type_str(in1_) + ")";
@@ -78,7 +79,14 @@ namespace matx
}
}

template <typename... Is>
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ const auto operator()(const std::array<index_t, detail::get_rank<I1>()> &idx) const noexcept
{
return mapply([&](auto &&...args) {
return this->operator()(args...);
}, idx);
}

template <typename... Is, std::enable_if_t<std::conjunction_v<std::is_integral<Is>...>, bool> = true>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
{
auto i1 = get_value(in1_, indices...);
17 changes: 12 additions & 5 deletions include/matx/transforms/cub.h
@@ -113,14 +113,15 @@ struct EmptyParams_t {};
template <typename OperatorType>
struct BeginOffset {
using self_type = BeginOffset<OperatorType>;
using value_type = typename OperatorType::scalar_type;
using value_type = detail::convert_matx_type_t<typename OperatorType::scalar_type>;
// using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
// index_t>;
using stride_type = index_t;
using pointer = value_type*;
using reference = value_type;
using iterator_category = std::random_access_iterator_tag;
using difference_type = index_t;


__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ BeginOffset(const OperatorType &t) : size_(t.Size(t.Rank() - 1)), offset_(0) { }
__MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ BeginOffset(const OperatorType &t, stride_type offset) : size_(t.Size(t.Rank() - 1)), offset_(offset) {}
@@ -172,7 +173,7 @@ struct BeginOffset {
template <typename OperatorType>
struct EndOffset {
using self_type = BeginOffset<OperatorType>;
using value_type = typename OperatorType::scalar_type;
using value_type = detail::convert_matx_type_t<typename OperatorType::scalar_type>;
// using stride_type = std::conditional_t<is_tensor_view_v<OperatorType>, typename OperatorType::desc_type::stride_type,
// index_t>;
using stride_type = index_t;
@@ -290,11 +291,14 @@ class matxCubPlan_t {
}


/* Convert the output type into the optimized type for the reduction, and run the reduction function */
template <typename Func, typename OutputOp, typename InputOp, typename BeginIter, typename EndIter>
static inline void ReduceOutput(Func &&func, OutputOp &&out, InputOp &&in, BeginIter &&bi, EndIter &&ei) {
using dtype = detail::convert_matx_type_t<typename remove_cvref_t<OutputOp>::scalar_type>;

if constexpr (out.Rank() <= 1 && is_tensor_view_v<OutputOp>) {
if (out.IsContiguous()) {
auto res = func(in, out.Data(), bi, ei);
auto res = func(in, reinterpret_cast<dtype*>(out.Data()), bi, ei);
MATX_ASSERT_STR_EXP(res, cudaSuccess, matxCudaError, "Error when calling CUB reduction function");
return;
}
@@ -306,12 +310,15 @@ class matxCubPlan_t {
MATX_ASSERT_STR_EXP(res, cudaSuccess, matxCudaError, "Error when calling CUB reduction function");
}

/* Convert the input type to the optimal input type for the reduction, and call the output reduce stage */
template <typename Func, typename OutputOp, typename InputOp>
static inline void ReduceInput(Func &&func, OutputOp &&out, InputOp &&in) {
using dtype = detail::convert_matx_type_t<typename remove_cvref_t<InputOp>::scalar_type>;
typename detail::base_type_t<InputOp> in_base = in;

if constexpr (in_base.Rank() <= 2 && is_tensor_view_v<InputOp>) {
if (in_base.IsContiguous()) {
ReduceOutput(std::forward<Func>(func), std::forward<OutputOp>(out), in_base.Data(), BeginOffset{in_base}, EndOffset{in_base});
ReduceOutput(std::forward<Func>(func), std::forward<OutputOp>(out), reinterpret_cast<dtype*>(in_base.Data()), BeginOffset{in_base}, EndOffset{in_base});
return;
}
}
@@ -980,7 +987,7 @@ inline void ExecSort(OutputTensor &a_out,
#ifdef __CUDACC__
MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
typename detail::base_type_t<InputOperator> in_base = a;
typename detail::base_type_t<OutputTensor> out_base = a_out;
typename detail::base_type_t<OutputTensor> out_base = a_out;

// Check whether this is a segmented reduction or single-value output. Segmented reductions are any
// type of reduction where there's not a single output, since any type of reduction can be generalized
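
The ReduceInput/ReduceOutput changes come down to viewing a matxFp16/matxBf16 buffer as its native CUDA counterpart before the pointer reaches CUB. A stripped-down, hypothetical helper mirroring that cast, assuming the MatX half wrappers are layout-compatible with __half/__nv_bfloat16 (which is what the reinterpret_cast in the diff relies on):

template <typename T>
auto as_cub_pointer(T *p) {
  // matxFp16 -> __half, matxBf16 -> __nv_bfloat16, everything else unchanged.
  using dtype = matx::detail::convert_matx_type_t<T>;
  return reinterpret_cast<dtype *>(p);   // same bits, CUB-friendly element type
}
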