[Sparse] Support Diag sparse format in C++ #5432

Merged: 3 commits, Mar 9, 2023

22 changes: 21 additions & 1 deletion dgl_sparse/include/sparse/sparse_format.h
@@ -19,7 +19,7 @@ namespace dgl {
namespace sparse {

/** @brief SparseFormat enumeration. */
enum SparseFormat { kCOO, kCSR, kCSC };
enum SparseFormat { kCOO, kCSR, kCSC, kDiag };

/** @brief COO sparse structure. */
struct COO {
@@ -50,6 +50,11 @@ struct CSR {
bool sorted = false;
};

/** @brief Diag sparse structure. */
struct Diag {
/** @brief The dense shape of the matrix. */
int64_t num_rows = 0, num_cols = 0;
};

/** @brief Convert an old DGL COO format to a COO in the sparse library. */
std::shared_ptr<COO> COOFromOldDGLCOO(const aten::COOMatrix& dgl_coo);

@@ -90,6 +95,21 @@ std::shared_ptr<CSR> COOToCSC(const std::shared_ptr<COO>& coo);
/** @brief Convert a CSR format to CSC format. */
std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr);

/** @brief Convert a Diag format to COO format. */
std::shared_ptr<COO> DiagToCOO(
const std::shared_ptr<Diag>& diag,
const c10::TensorOptions& indices_options);

Collaborator
Why do we need these 3 conversions? In which cases will they be used?

Collaborator Author
They are used by operators that have no dedicated implementation for the Diag format, e.g., SpMM.

Collaborator
What is the performance implication if we convert Diag to COO/CSR/CSC for those operators? Will it strongly decrease SpMM performance?

This solution looks good to me, but let's make sure we understand the trade-off we are making here.

Collaborator Author
I simply followed our current implementation to ensure there is no performance regression. Currently, we also convert the DiagMatrix to a SparseMatrix for SpMM on the Python side.

/** @brief Convert a Diag format to CSR format. */
std::shared_ptr<CSR> DiagToCSR(
const std::shared_ptr<Diag>& diag,
const c10::TensorOptions& indices_options);

/** @brief Convert a Diag format to CSC format. */
std::shared_ptr<CSR> DiagToCSC(
const std::shared_ptr<Diag>& diag,
const c10::TensorOptions& indices_options);

/** @brief COO transposition. */
std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo);

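For intuition, here is a minimal standalone sketch (plain libtorch only, independent of the dgl_sparse build) of the index tensor a Diag-to-COO conversion such as the one declared above is expected to produce for a rectangular 3 x 5 diagonal matrix. The main function and tensor options are illustrative assumptions, not part of the PR.

#include <torch/torch.h>

#include <algorithm>
#include <iostream>

int main() {
  const int64_t num_rows = 3, num_cols = 5;
  // A diagonal matrix stores one entry per (i, i) pair, so nnz = min(rows, cols).
  const int64_t nnz = std::min(num_rows, num_cols);
  auto opts = torch::TensorOptions().dtype(torch::kInt64);
  // Mirrors the conversion: row ids and column ids are both 0..nnz-1.
  auto indices = torch::arange(nnz, opts).repeat({2, 1});
  std::cout << indices << std::endl;  // [[0, 1, 2], [0, 1, 2]], shape (2, 3)
  return 0;
}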
36 changes: 33 additions & 3 deletions dgl_sparse/include/sparse/sparse_matrix.h
@@ -38,8 +38,8 @@ class SparseMatrix : public torch::CustomClassHolder {
*/
SparseMatrix(
const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
const std::shared_ptr<CSR>& csc, torch::Tensor value,
const std::vector<int64_t>& shape);
const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
torch::Tensor value, const std::vector<int64_t>& shape);

/**
* @brief Construct a SparseMatrix from a COO format.
@@ -77,6 +77,18 @@ class SparseMatrix : public torch::CustomClassHolder {
const std::shared_ptr<CSR>& csc, torch::Tensor value,
const std::vector<int64_t>& shape);

/**
* @brief Construct a SparseMatrix from a Diag format.
* @param diag The Diag format
* @param value Values of the sparse matrix
* @param shape Shape of the sparse matrix
*
* @return SparseMatrix
*/
static c10::intrusive_ptr<SparseMatrix> FromDiagPointer(
const std::shared_ptr<Diag>& diag, torch::Tensor value,
const std::vector<int64_t>& shape);

/**
* @brief Create a SparseMatrix from tensors in COO format.
* @param indices COO coordinates with shape (2, nnz).
@@ -115,6 +127,16 @@ class SparseMatrix : public torch::CustomClassHolder {
torch::Tensor indptr, torch::Tensor indices, torch::Tensor value,
const std::vector<int64_t>& shape);

/**
* @brief Create a SparseMatrix with Diag format.
* @param value Values of the sparse matrix
* @param shape Shape of the sparse matrix
*
* @return SparseMatrix
*/
static c10::intrusive_ptr<SparseMatrix> FromDiag(
torch::Tensor value, const std::vector<int64_t>& shape);

/**
* @brief Create a SparseMatrix from a SparseMatrix using new values.
* @param mat An existing sparse matrix
@@ -142,13 +164,20 @@ class SparseMatrix : public torch::CustomClassHolder {
std::shared_ptr<CSR> CSRPtr();
/** @return CSC of the sparse matrix. The CSC is created if not exists. */
std::shared_ptr<CSR> CSCPtr();
/**
* @return Diagonal format of the sparse matrix. An error will be raised if
* it does not have a diagonal format.
*/
std::shared_ptr<Diag> DiagPtr();

/** @brief Check whether this sparse matrix has COO format. */
inline bool HasCOO() const { return coo_ != nullptr; }
/** @brief Check whether this sparse matrix has CSR format. */
inline bool HasCSR() const { return csr_ != nullptr; }
/** @brief Check whether this sparse matrix has CSC format. */
inline bool HasCSC() const { return csc_ != nullptr; }
/** @brief Check whether this sparse matrix has Diag format. */
inline bool HasDiag() const { return diag_ != nullptr; }

/** @return {row, col} tensors in the COO format. */
std::tuple<torch::Tensor, torch::Tensor> COOTensors();
@@ -191,9 +220,10 @@ class SparseMatrix : public torch::CustomClassHolder {
/** @brief Create the CSC format for the sparse matrix internally */
void _CreateCSC();

// COO/CSC/CSR pointers. Nullptr indicates non-existence.
// COO/CSC/CSR/Diag pointers. Nullptr indicates non-existence.
std::shared_ptr<COO> coo_;
std::shared_ptr<CSR> csr_, csc_;
std::shared_ptr<Diag> diag_;
// Value of the SparseMatrix
torch::Tensor value_;
// Shape of the SparseMatrix
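To make the new Diag surface area concrete, here is a small usage sketch. It assumes the dgl_sparse headers above are built and on the include path; the helper name DiagExample and the chosen shape are illustrative, not taken from the PR.

#include <sparse/sparse_matrix.h>
#include <torch/torch.h>

void DiagExample() {
  using namespace dgl::sparse;
  // A 4 x 4 diagonal matrix whose diagonal values are [1, 2, 3, 4].
  auto value = torch::tensor({1.0f, 2.0f, 3.0f, 4.0f});
  auto mat = SparseMatrix::FromDiag(value, {4, 4});
  TORCH_CHECK(mat->HasDiag());  // stored as Diag only, no indices materialized
  TORCH_CHECK(!mat->HasCOO());  // COO/CSR/CSC are created lazily on demand
}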
4 changes: 4 additions & 0 deletions dgl_sparse/src/elemenwise_op.cc
@@ -22,6 +22,10 @@ c10::intrusive_ptr<SparseMatrix> SpSpAdd(
const c10::intrusive_ptr<SparseMatrix>& A,
const c10::intrusive_ptr<SparseMatrix>& B) {
ElementwiseOpSanityCheck(A, B);
if (A->HasDiag() && B->HasDiag()) {
return SparseMatrix::FromDiagPointer(
A->DiagPtr(), A->value() + B->value(), A->shape());
}
Member
nit: I would usually prefer parallel if-else branches, e.g.,

if (...) {
  ...
} else {
  ...
}

@frozenbugs what would you think?

Collaborator
The current implementation is better. A function can have multiple early-return checks, so it is highly recommended to write:

if (...) {
  return ...
}
if (...) {
  return ...
}
blablabla
blablabla
blablabla
return ...

auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
auto sum = (torch_A + torch_B).coalesce();
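A sketch of the behavior this early return enables, under the same assumptions as the previous example. The include path for the SpSpAdd declaration and the helper name are assumptions; only the Diag + Diag semantics (adding the diagonal value vectors) come from the diff above.

#include <sparse/elementwise_op.h>  // assumed location of the SpSpAdd declaration
#include <sparse/sparse_matrix.h>
#include <torch/torch.h>

void DiagAddExample() {
  using namespace dgl::sparse;
  auto A = SparseMatrix::FromDiag(torch::ones({3}), {3, 3});
  auto B = SparseMatrix::FromDiag(torch::full({3}, 2.0f), {3, 3});
  auto C = SpSpAdd(A, B);     // takes the new Diag + Diag early return
  TORCH_CHECK(C->HasDiag());  // the result stays in Diag format
  // C->value() is expected to be [3, 3, 3]: the diagonals are simply added.
}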
1 change: 1 addition & 0 deletions dgl_sparse/src/python_binding.cc
@@ -36,6 +36,7 @@ TORCH_LIBRARY(dgl_sparse, m) {
m.def("from_coo", &SparseMatrix::FromCOO)
.def("from_csr", &SparseMatrix::FromCSR)
.def("from_csc", &SparseMatrix::FromCSC)
.def("from_diag", &SparseMatrix::FromDiag)
.def("spsp_add", &SpSpAdd)
.def("reduce", &Reduce)
.def("sum", &ReduceSum)
33 changes: 33 additions & 0 deletions dgl_sparse/src/sparse_format.cc
@@ -99,6 +99,39 @@ std::shared_ptr<CSR> CSRToCSC(const std::shared_ptr<CSR>& csr) {
return CSRFromOldDGLCSR(dgl_csc);
}

std::shared_ptr<COO> DiagToCOO(
const std::shared_ptr<Diag>& diag,
const c10::TensorOptions& indices_options) {
int64_t nnz = std::min(diag->num_rows, diag->num_cols);
auto indices = torch::arange(nnz, indices_options).repeat({2, 1});
return std::make_shared<COO>(
COO{diag->num_rows, diag->num_cols, indices, true, true});
}

std::shared_ptr<CSR> DiagToCSR(
const std::shared_ptr<Diag>& diag,
const c10::TensorOptions& indices_options) {
int64_t nnz = std::min(diag->num_rows, diag->num_cols);
auto indptr = torch::full(diag->num_rows + 1, nnz, indices_options);
torch::arange_out(indptr, nnz + 1);
auto indices = torch::arange(nnz, indices_options);
return std::make_shared<CSR>(
CSR{diag->num_rows, diag->num_cols, indptr, indices,
torch::optional<torch::Tensor>(), true});
}

std::shared_ptr<CSR> DiagToCSC(
const std::shared_ptr<Diag>& diag,
const c10::TensorOptions& indices_options) {
int64_t nnz = std::min(diag->num_rows, diag->num_cols);
auto indptr = torch::full(diag->num_cols + 1, nnz, indices_options);
torch::arange_out(indptr, nnz + 1);
auto indices = torch::arange(nnz, indices_options);
return std::make_shared<CSR>(
CSR{diag->num_cols, diag->num_rows, indptr, indices,
torch::optional<torch::Tensor>(), true});
}

std::shared_ptr<COO> COOTranspose(const std::shared_ptr<COO>& coo) {
auto dgl_coo = COOToOldDGLCOO(coo);
auto dgl_coo_tr = aten::COOTranspose(dgl_coo);
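To spell out what these conversions compute, the following standalone reconstruction (plain libtorch, not the functions above) builds the CSR arrays a 5 x 3 diagonal matrix reduces to. The torch::cat construction is just one way to produce the same indptr and is not the implementation used in the diff.

#include <torch/torch.h>

#include <algorithm>

void DiagCSRShape() {
  const int64_t num_rows = 5, num_cols = 3;
  const int64_t nnz = std::min(num_rows, num_cols);
  auto opts = torch::TensorOptions().dtype(torch::kInt64);
  // Rows 0..nnz-1 each hold a single entry at column i; the trailing rows are
  // empty, so indptr stays flat at nnz once the diagonal runs out.
  auto indptr = torch::cat({torch::arange(nnz + 1, opts),
                            torch::full({num_rows - nnz}, nnz, opts)});
  auto indices = torch::arange(nnz, opts);
  // indptr  == [0, 1, 2, 3, 3, 3]  (length num_rows + 1)
  // indices == [0, 1, 2]
}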
72 changes: 61 additions & 11 deletions dgl_sparse/src/sparse_matrix.cc
@@ -17,12 +17,18 @@ namespace sparse {

SparseMatrix::SparseMatrix(
const std::shared_ptr<COO>& coo, const std::shared_ptr<CSR>& csr,
const std::shared_ptr<CSR>& csc, torch::Tensor value,
const std::vector<int64_t>& shape)
: coo_(coo), csr_(csr), csc_(csc), value_(value), shape_(shape) {
const std::shared_ptr<CSR>& csc, const std::shared_ptr<Diag>& diag,
torch::Tensor value, const std::vector<int64_t>& shape)
: coo_(coo),
csr_(csr),
csc_(csc),
diag_(diag),
value_(value),
shape_(shape) {
TORCH_CHECK(
coo != nullptr || csr != nullptr || csc != nullptr, "At least ",
"one of CSR/COO/CSC is required to construct a SparseMatrix.")
coo != nullptr || csr != nullptr || csc != nullptr || diag != nullptr,
"At least one of CSR/COO/CSC/Diag is required to construct a "
"SparseMatrix.")
TORCH_CHECK(
shape.size() == 2, "The shape of a sparse matrix should be ",
"2-dimensional.");
@@ -51,24 +57,37 @@ SparseMatrix::SparseMatrix(
TORCH_CHECK(csc->indptr.device() == value.device());
TORCH_CHECK(csc->indices.device() == value.device());
}
if (diag != nullptr) {
TORCH_CHECK(value.size(0) == std::min(diag->num_rows, diag->num_cols));
}
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOOPointer(
const std::shared_ptr<COO>& coo, torch::Tensor value,
const std::vector<int64_t>& shape) {
return c10::make_intrusive<SparseMatrix>(coo, nullptr, nullptr, value, shape);
return c10::make_intrusive<SparseMatrix>(
coo, nullptr, nullptr, nullptr, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSRPointer(
const std::shared_ptr<CSR>& csr, torch::Tensor value,
const std::vector<int64_t>& shape) {
return c10::make_intrusive<SparseMatrix>(nullptr, csr, nullptr, value, shape);
return c10::make_intrusive<SparseMatrix>(
nullptr, csr, nullptr, nullptr, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSCPointer(
const std::shared_ptr<CSR>& csc, torch::Tensor value,
const std::vector<int64_t>& shape) {
return c10::make_intrusive<SparseMatrix>(nullptr, nullptr, csc, value, shape);
return c10::make_intrusive<SparseMatrix>(
nullptr, nullptr, csc, nullptr, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiagPointer(
const std::shared_ptr<Diag>& diag, torch::Tensor value,
const std::vector<int64_t>& shape) {
return c10::make_intrusive<SparseMatrix>(
nullptr, nullptr, nullptr, diag, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCOO(
@@ -97,6 +116,12 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromCSC(
return SparseMatrix::FromCSCPointer(csc, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::FromDiag(
torch::Tensor value, const std::vector<int64_t>& shape) {
auto diag = std::make_shared<Diag>(Diag{shape[0], shape[1]});
return SparseMatrix::FromDiagPointer(diag, value, shape);
}

c10::intrusive_ptr<SparseMatrix> SparseMatrix::ValLike(
const c10::intrusive_ptr<SparseMatrix>& mat, torch::Tensor value) {
TORCH_CHECK(
@@ -136,6 +161,13 @@ std::shared_ptr<CSR> SparseMatrix::CSCPtr() {
return csc_;
}

std::shared_ptr<Diag> SparseMatrix::DiagPtr() {
TORCH_CHECK(
diag_ != nullptr,
"Cannot get Diag sparse format from a non-diagonal sparse matrix");
return diag_;
}

std::tuple<torch::Tensor, torch::Tensor> SparseMatrix::COOTensors() {
auto coo = COOPtr();
return std::make_tuple(coo->indices.index({0}), coo->indices.index({1}));
@@ -175,7 +207,13 @@ c10::intrusive_ptr<SparseMatrix> SparseMatrix::Transpose() const {

void SparseMatrix::_CreateCOO() {
if (HasCOO()) return;
if (HasCSR()) {
if (HasDiag()) {
auto indices_options = torch::TensorOptions()
.dtype(torch::kInt64)
.layout(torch::kStrided)
.device(this->device());
coo_ = DiagToCOO(diag_, indices_options);
} else if (HasCSR()) {
coo_ = CSRToCOO(csr_);
} else if (HasCSC()) {
coo_ = CSCToCOO(csc_);
@@ -186,7 +224,13 @@ void SparseMatrix::_CreateCSR() {

void SparseMatrix::_CreateCSR() {
if (HasCSR()) return;
if (HasCOO()) {
if (HasDiag()) {
auto indices_options = torch::TensorOptions()
.dtype(torch::kInt64)
.layout(torch::kStrided)
.device(this->device());
csr_ = DiagToCSR(diag_, indices_options);
} else if (HasCOO()) {
csr_ = COOToCSR(coo_);
} else if (HasCSC()) {
csr_ = CSCToCSR(csc_);
@@ -197,7 +241,13 @@ void SparseMatrix::_CreateCSC() {

void SparseMatrix::_CreateCSC() {
if (HasCSC()) return;
if (HasCOO()) {
if (HasDiag()) {
auto indices_options = torch::TensorOptions()
.dtype(torch::kInt64)
.layout(torch::kStrided)
.device(this->device());
csc_ = DiagToCSC(diag_, indices_options);
} else if (HasCOO()) {
csc_ = COOToCSC(coo_);
} else if (HasCSR()) {
csc_ = CSRToCSC(csr_);
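A sketch of how this lazy materialization looks from the caller's side, under the same assumptions as the earlier SparseMatrix example; the helper name and shapes are illustrative.

#include <sparse/sparse_matrix.h>
#include <torch/torch.h>

#include <tuple>

void LazyConversionExample() {
  using namespace dgl::sparse;
  auto mat = SparseMatrix::FromDiag(torch::ones({3}), {3, 5});
  // Requesting COO data triggers _CreateCOO, which now routes through DiagToCOO.
  auto row_col = mat->COOTensors();
  auto row = std::get<0>(row_col);
  auto col = std::get<1>(row_col);
  TORCH_CHECK(row.equal(col));  // diagonal pattern: row i pairs with column i
  TORCH_CHECK(mat->HasDiag() && mat->HasCOO());  // both formats are now cached
}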
39 changes: 39 additions & 0 deletions dgl_sparse/src/spspmm.cc
@@ -116,10 +116,49 @@ tensor_list SpSpMMAutoGrad::backward(
return {torch::Tensor(), lhs_val_grad, torch::Tensor(), rhs_val_grad};
}

c10::intrusive_ptr<SparseMatrix> DiagSpSpMM(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
// Diag @ Diag
const int64_t m = lhs_mat->shape()[0];
const int64_t n = lhs_mat->shape()[1];
const int64_t p = rhs_mat->shape()[1];
const int64_t common_diag_len = std::min({m, n, p});
const int64_t new_diag_len = std::min(m, p);
auto slice = torch::indexing::Slice(0, common_diag_len);
auto new_val =
lhs_mat->value().index({slice}) * rhs_mat->value().index({slice});
new_val =
torch::constant_pad_nd(new_val, {0, new_diag_len - common_diag_len}, 0);
return SparseMatrix::FromDiag(new_val, {m, p});
}
if (lhs_mat->HasDiag() && !rhs_mat->HasDiag()) {
// Diag @ Sparse
auto row = rhs_mat->Indices().index({0});
auto val = lhs_mat->value().index_select(0, row) * rhs_mat->value();
return SparseMatrix::ValLike(rhs_mat, val);
}
if (!lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
// Sparse @ Diag
auto col = lhs_mat->Indices().index({1});
auto val = rhs_mat->value().index_select(0, col) * lhs_mat->value();
return SparseMatrix::ValLike(lhs_mat, val);
}
TORCH_CHECK(
false,
"For DiagSpSpMM, at least one of the sparse matries need to have kDiag "
"format");
return c10::intrusive_ptr<SparseMatrix>();
}

c10::intrusive_ptr<SparseMatrix> SpSpMM(
const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
_SpSpMMSanityCheck(lhs_mat, rhs_mat);
if (lhs_mat->HasDiag() || rhs_mat->HasDiag()) {
return DiagSpSpMM(lhs_mat, rhs_mat);
}
auto results = SpSpMMAutoGrad::apply(
lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
std::vector<int64_t> ret_shape({lhs_mat->shape()[0], rhs_mat->shape()[1]});
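Finally, a sketch of the Diag @ Sparse branch above: left-multiplying by a diagonal matrix scales row i of the right operand by the i-th diagonal value, which is what the index_select on the row indices implements. The include path for SpSpMM and the helper name are assumptions; the expected values follow from the branch logic, not from a test in the PR.

#include <sparse/sparse_matrix.h>
#include <sparse/spspmm.h>  // assumed location of the SpSpMM declaration
#include <torch/torch.h>

void DiagScaleExample() {
  using namespace dgl::sparse;
  // A 2 x 3 sparse matrix with entries at (0,1), (1,0), (1,2), all equal to 1.
  auto indices = torch::tensor({{0, 1, 1}, {1, 0, 2}});
  auto A = SparseMatrix::FromCOO(indices, torch::ones({3}), {2, 3});
  // D = diag([2, 3]); D @ A scales row 0 of A by 2 and row 1 by 3.
  auto D = SparseMatrix::FromDiag(torch::tensor({2.0f, 3.0f}), {2, 2});
  auto C = SpSpMM(D, A);  // dispatches to the Diag @ Sparse branch
  // C keeps A's sparsity pattern; C->value() is expected to be [2, 3, 3].
}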