Skip to content

Commit

Permalink
Expose tensor format (and lvl specs) to sparse tensor data (#833)
Browse files Browse the repository at this point in the history
  • Loading branch information
aartbik authored Jan 21, 2025
1 parent 3ec2306 commit 75e70bc
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 23 deletions.
8 changes: 4 additions & 4 deletions include/matx/core/sparse_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ template <typename VAL, typename CRD, typename POS, typename TF,
typename StorageC = DefaultStorage<CRD>,
typename StorageP = DefaultStorage<POS>,
typename DimDesc = DefaultDescriptor<TF::DIM>>
class sparse_tensor_t : public detail::tensor_impl_t<
VAL, TF::DIM, DimDesc,
detail::SparseTensorData<VAL, CRD, POS, TF::LVL>> {
class sparse_tensor_t
: public detail::tensor_impl_t<
VAL, TF::DIM, DimDesc, detail::SparseTensorData<VAL, CRD, POS, TF>> {
public:
using sparse_tensor = bool;
static constexpr int DIM = TF::DIM;
Expand All @@ -79,7 +79,7 @@ class sparse_tensor_t : public detail::tensor_impl_t<
sparse_tensor_t(const typename DimDesc::shape_type (&shape)[DIM],
StorageV &&vals, StorageC (&&crd)[LVL], StorageP (&&pos)[LVL])
: detail::tensor_impl_t<VAL, DIM, DimDesc,
detail::SparseTensorData<VAL, CRD, POS, LVL>>(
detail::SparseTensorData<VAL, CRD, POS, TF>>(
shape) {
// Initialize primary and secondary storage.
values_ = std::move(vals);
Expand Down
38 changes: 21 additions & 17 deletions include/matx/core/sparse_tensor_format.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ template <typename Expr, LvlType ltype> class LvlSpec {
//
template <int D, typename... LvlSpecs> class SparseTensorFormat {
public:
using LVLSPECS = std::tuple<LvlSpecs...>;
static constexpr int DIM = D;
static constexpr int LVL = sizeof...(LvlSpecs);

Expand All @@ -199,7 +200,7 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isDnVec() {
if constexpr (LVL == 1) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
return first_type::lvltype == LvlType::Dense &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 0;
}
Expand All @@ -208,7 +209,7 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isSpVec() {
if constexpr (LVL == 1) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
return first_type::lvltype == LvlType::Compressed &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 0;
}
Expand All @@ -217,8 +218,8 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isCOO() {
if constexpr (LVL == 2) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using second_type = std::tuple_element_t<1, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
using second_type = std::tuple_element_t<1, LVLSPECS>;
return first_type::lvltype == LvlType::CompressedNonUnique &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 0 &&
second_type::lvltype == LvlType::Singleton &&
Expand All @@ -229,8 +230,8 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isCSR() {
if constexpr (LVL == 2) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using second_type = std::tuple_element_t<1, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
using second_type = std::tuple_element_t<1, LVLSPECS>;
return first_type::lvltype == LvlType::Dense &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 0 &&
second_type::lvltype == LvlType::Compressed &&
Expand All @@ -241,8 +242,8 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {

static constexpr bool isCSC() {
if constexpr (LVL == 2) {
using first_type = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using second_type = std::tuple_element_t<1, std::tuple<LvlSpecs...>>;
using first_type = std::tuple_element_t<0, LVLSPECS>;
using second_type = std::tuple_element_t<1, LVLSPECS>;
return first_type::lvltype == LvlType::Dense &&
first_type::expr::op == LvlOp::Id && first_type::expr::di == 1 &&
second_type::lvltype == LvlType::Compressed &&
Expand All @@ -252,12 +253,13 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {
}

template <typename CRD>
static CRD *dim2lvl(const CRD *dims, CRD *lvls, bool asSize) {
static CRD __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ *
dim2lvl(const CRD *dims, CRD *lvls, bool asSize) {
// Lambda for dim2lvl translation.
auto loop_fun = [&dims, &lvls, &asSize](auto ic) {
constexpr int idx = decltype(ic)::value;
if constexpr (LVL >= (idx + 1)) {
using ftype = std::tuple_element_t<idx, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<idx, LVLSPECS>;
if constexpr (ftype::expr::op == LvlOp::Id) {
lvls[idx] = dims[ftype::expr::di];
} else if constexpr (ftype::expr::op == LvlOp::Div) {
Expand All @@ -278,12 +280,14 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {
return lvls;
}

template <typename CRD> static CRD *lvl2dim(const CRD *lvls, CRD *dims) {
template <typename CRD>
static CRD __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ *
lvl2dim(const CRD *lvls, CRD *dims) {
// Lambda for lvl2dim translation.
auto loop_fun = [&lvls, &dims](auto ic) {
constexpr int idx = decltype(ic)::value;
if constexpr (LVL >= (idx + 1)) {
using ftype = std::tuple_element_t<idx, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<idx, LVLSPECS>;
if constexpr (ftype::expr::op == LvlOp::Id) {
dims[ftype::expr::di] = lvls[idx];
} else if constexpr (ftype::expr::op == LvlOp::Div) {
Expand Down Expand Up @@ -314,35 +318,35 @@ template <int D, typename... LvlSpecs> class SparseTensorFormat {
// Assumes LVL <= 5.
static_assert(LVL <= 5);
if constexpr (LVL > 1) {
using ftype = std::tuple_element_t<0, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<0, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL != 1) {
std::cout << ",";
}
}
if constexpr (LVL >= 2) {
using ftype = std::tuple_element_t<1, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<1, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL > 2) {
std::cout << ",";
}
}
if constexpr (LVL >= 3) {
using ftype = std::tuple_element_t<2, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<2, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL > 3) {
std::cout << ",";
}
}
if constexpr (LVL >= 4) {
using ftype = std::tuple_element_t<3, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<3, LVLSPECS>;
std::cout << " " << ftype::toString();
if constexpr (LVL > 4) {
std::cout << ",";
}
}
if constexpr (LVL >= 5) {
using ftype = std::tuple_element_t<4, std::tuple<LvlSpecs...>>;
using ftype = std::tuple_element_t<4, LVLSPECS>;
std::cout << " " << ftype::toString();
}
std::cout << " )" << std::endl;
Expand Down
5 changes: 3 additions & 2 deletions include/matx/core/tensor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,13 @@ struct DenseTensorData {
T *ldata_;
};

template <typename T, typename CRD, typename POS, int L>
template <typename T, typename CRD, typename POS, typename TF>
struct SparseTensorData {
using sparse_data = bool;
using crd_type = CRD;
using pos_type = POS;
static constexpr int LVL = L;
using Format = TF;
static constexpr int LVL = TF::LVL;
T *ldata_;
CRD *crd_[LVL];
POS *pos_[LVL];
Expand Down

0 comments on commit 75e70bc

Please sign in to comment.