Skip to content

Commit

Permalink
Update Jacobi write for scalar.
Browse files Browse the repository at this point in the history
  • Loading branch information
pratikvn committed Aug 5, 2021
1 parent 8d380da commit 9425740
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 20 deletions.
55 changes: 35 additions & 20 deletions core/preconditioner/jacobi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,29 +149,44 @@ void Jacobi<ValueType, IndexType>::write(mat_data &data) const
make_temporary_clone(this->get_executor()->get_master(), this);
data = {local_clone->get_size(), {}};

const auto ptrs = local_clone->parameters_.block_pointers.get_const_data();
for (size_type block = 0; block < local_clone->get_num_blocks(); ++block) {
const auto scheme = local_clone->get_storage_scheme();
const auto group_data = local_clone->blocks_.get_const_data() +
scheme.get_group_offset(block);
const auto block_size = ptrs[block + 1] - ptrs[block];
const auto precisions = local_clone->parameters_.storage_optimization
.block_wise.get_const_data();
const auto prec =
precisions ? precisions[block] : precision_reduction();
GKO_PRECONDITIONER_JACOBI_RESOLVE_PRECISION(ValueType, prec, {
const auto block_data =
reinterpret_cast<const resolved_precision *>(group_data) +
scheme.get_block_offset(block);
for (IndexType row = 0; row < block_size; ++row) {
for (IndexType col = 0; col < block_size; ++col) {
if (parameters_.max_block_size == 1) {
for (IndexType row = 0; row < data.size[0]; ++row) {
for (IndexType col = 0; col < data.size[1]; ++col) {
if (row == col) {
data.nonzeros.emplace_back(
ptrs[block] + row, ptrs[block] + col,
static_cast<ValueType>(
block_data[row + col * scheme.get_stride()]));
row, col,
static_cast<ValueType>(local_clone->get_blocks()[row]));
}
}
});
}
} else {
const auto ptrs =
local_clone->parameters_.block_pointers.get_const_data();
for (size_type block = 0; block < local_clone->get_num_blocks();
++block) {
const auto scheme = local_clone->get_storage_scheme();
const auto group_data = local_clone->blocks_.get_const_data() +
scheme.get_group_offset(block);
const auto block_size = ptrs[block + 1] - ptrs[block];
const auto precisions =
local_clone->parameters_.storage_optimization.block_wise
.get_const_data();
const auto prec =
precisions ? precisions[block] : precision_reduction();
GKO_PRECONDITIONER_JACOBI_RESOLVE_PRECISION(ValueType, prec, {
const auto block_data =
reinterpret_cast<const resolved_precision *>(group_data) +
scheme.get_block_offset(block);
for (IndexType row = 0; row < block_size; ++row) {
for (IndexType col = 0; col < block_size; ++col) {
data.nonzeros.emplace_back(
ptrs[block] + row, ptrs[block] + col,
static_cast<ValueType>(
block_data[row + col * scheme.get_stride()]));
}
}
});
}
}
}

Expand Down
25 changes: 25 additions & 0 deletions reference/test/preconditioner/jacobi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,31 @@ TYPED_TEST(Jacobi, GeneratesCorrectMatrixData)
}


TYPED_TEST(Jacobi, ScalarJacobiGeneratesCorrectMatrixData)
{
using value_type = typename TestFixture::value_type;
using index_type = typename TestFixture::index_type;
using Bj = typename TestFixture::Bj;
gko::matrix_data<value_type, index_type> data;
using tpl = typename decltype(data)::nonzero_type;
auto csr = gko::share(
gko::matrix::Csr<value_type, index_type>::create(this->exec));
csr->copy_from(gko::lend(this->mtx));
auto scalar_j = this->scalar_j_factory->generate(csr);

scalar_j->write(data);

auto tol = r<value_type>::value;
ASSERT_EQ(data.size, gko::dim<2>{5});
ASSERT_EQ(data.nonzeros.size(), 5);
GKO_EXPECT_NONZERO_NEAR(data.nonzeros[0], tpl(0, 0, 1 / 4.0), tol);
GKO_EXPECT_NONZERO_NEAR(data.nonzeros[1], tpl(1, 1, 1 / 4.0), tol);
GKO_EXPECT_NONZERO_NEAR(data.nonzeros[2], tpl(2, 2, 1 / 4.0), tol);
GKO_EXPECT_NONZERO_NEAR(data.nonzeros[3], tpl(3, 3, 1 / 4.0), tol);
GKO_EXPECT_NONZERO_NEAR(data.nonzeros[4], tpl(4, 4, 1 / 4.0), tol);
}


TYPED_TEST(Jacobi, GeneratesCorrectMatrixDataWithAdaptivePrecision)
{
using value_type = typename TestFixture::value_type;
Expand Down

0 comments on commit 9425740

Please sign in to comment.