Skip to content

Commit

Permalink
review update:
Browse files Browse the repository at this point in the history
- refactoring
- test for matrix generator

Co-authored-by: Tobias Ribizel <ribizel@kit.edu>
Co-authored-by: Yu-Hsiang M. Tsai <yhmtsai@gmail.com>
  • Loading branch information
3 people committed Apr 25, 2023
1 parent aa417dc commit b4b2fb2
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 11 deletions.
19 changes: 8 additions & 11 deletions core/test/utils/matrix_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,8 +566,7 @@ std::unique_ptr<MatrixType> generate_tridiag_matrix(
*/
template <typename ValueType, typename IndexType>
gko::matrix_data<ValueType, IndexType> generate_tridiag_inverse_matrix_data(
gko::size_type size, std::array<ValueType, 3> coeffs,
std::shared_ptr<const gko::Executor> exec)
gko::size_type size, std::array<ValueType, 3> coeffs)
{
auto lower = coeffs[0];
auto diag = coeffs[1];
Expand All @@ -588,17 +587,15 @@ gko::matrix_data<ValueType, IndexType> generate_tridiag_inverse_matrix_data(
if (i == j) {
md.nonzeros.emplace_back(i, j,
alpha[i] * beta[j + 1] / alpha.back());
} else if (i < j) {
auto sign = static_cast<ValueType>((i + j) % 2 ? -1 : 1);
auto val = sign *
static_cast<ValueType>(std::pow(upper, j - i)) *
alpha[i] * beta[j + 1] / alpha.back();
md.nonzeros.emplace_back(i, j, val);
} else {
auto sign = static_cast<ValueType>((i + j) % 2 ? -1 : 1);
auto off_diag = i < j ? upper : lower;
auto min_idx = std::min(i, j);
auto max_idx = std::max(i, j);
auto val = sign *
static_cast<ValueType>(std::pow(lower, i - j)) *
alpha[j] * beta[i + 1] / alpha.back();
static_cast<ValueType>(
std::pow(off_diag, max_idx - min_idx)) *
alpha[min_idx] * beta[max_idx + 1] / alpha.back();
md.nonzeros.emplace_back(i, j, val);
}
}
Expand All @@ -619,7 +616,7 @@ std::unique_ptr<MatrixType> generate_tridiag_inverse_matrix(
mtx->read(
generate_tridiag_inverse_matrix_data<typename MatrixType::value_type,
typename MatrixType::index_type>(
size, coeffs, exec));
size, coeffs));
return mtx;
}

Expand Down
53 changes: 53 additions & 0 deletions core/test/utils/matrix_generator_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <gtest/gtest.h>


#include "core/base/utils.hpp"
#include "core/test/utils.hpp"


Expand Down Expand Up @@ -270,4 +271,56 @@ TYPED_TEST(MatrixGenerator, CanGenerateBandMatrix)
}


TYPED_TEST(MatrixGenerator, CanGenerateTridiagMatrix)
{
using T = typename TestFixture::value_type;
using Dense = typename TestFixture::mtx_type;
auto dist = std::normal_distribution<gko::remove_complex<T>>(0, 1);
auto engine = std::default_random_engine(42);
auto lower = gko::test::detail::get_rand_value<T>(dist, engine);
auto diag = gko::test::detail::get_rand_value<T>(dist, engine);
auto upper = gko::test::detail::get_rand_value<T>(dist, engine);

auto mtx = gko::test::generate_tridiag_matrix<Dense>(
50, {lower, diag, upper}, this->exec);

GKO_ASSERT_IS_SQUARE_MATRIX(mtx);
for (gko::size_type i = 0; i < mtx->get_size()[0]; ++i) {
ASSERT_EQ(mtx->at(i, i), diag);
if (i > 0) {
ASSERT_EQ(mtx->at(i, i - 1), lower);
ASSERT_EQ(mtx->at(i - 1, i), upper);
}
}
}


TYPED_TEST(MatrixGenerator, CanGenerateTridiagInverseMatrix)
{
using T = typename TestFixture::value_type;
using Dense = typename TestFixture::mtx_type;
auto dist = std::normal_distribution<gko::remove_complex<T>>(0, 1);
auto engine = std::default_random_engine(42);
auto lower = gko::test::detail::get_rand_value<T>(dist, engine);
auto upper = gko::test::detail::get_rand_value<T>(dist, engine);
// make diagonally dominant
auto diag = std::abs(gko::test::detail::get_rand_value<T>(dist, engine)) +
std::abs(lower) + std::abs(upper);

auto mtx = gko::test::generate_tridiag_matrix<Dense>(
50, {lower, diag, upper}, this->exec);
auto inv_mtx = gko::test::generate_tridiag_inverse_matrix<Dense>(
50, {lower, diag, upper}, this->exec);

auto result = Dense::create(this->exec, mtx->get_size());
inv_mtx->apply(mtx, result);
auto id = Dense::create(this->exec, mtx->get_size());
id->fill(0.0);
for (gko::size_type i = 0; i < mtx->get_size()[0]; ++i) {
id->at(i, i) = gko::one<T>();
}
GKO_ASSERT_MTX_NEAR(result, id, r<T>::value * 10);
}


} // namespace

0 comments on commit b4b2fb2

Please sign in to comment.