Skip to content

Commit

Permalink
add zero-checks to axpy-like operations
Browse files (browse the repository at this point in the history)
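This prevents NaNs from polluting the output: in IEEE arithmetic 0 * NaN and 0 * Inf are NaN, so a zero scaling factor would otherwise propagate NaN/Inf values already stored in the output instead of contributing zero.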
  • Loading branch information
upsj committed Mar 18, 2024
1 parent 4b31772 commit f7c8214
Show file tree
Hide file tree
Showing 23 changed files with 173 additions and 41 deletions.
7 changes: 5 additions & 2 deletions common/cuda_hip/matrix/csr_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,9 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_merge_path_spmv(
merge_path_spmv<items_per_thread>(
num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
[&alpha_val](const type& x) { return alpha_val * x; },
[&beta_val](const type& x) { return beta_val * x; });
[&beta_val](const type& x) {
return is_zero(beta_val) ? zero(beta_val) : beta_val * x;
});
}


Expand Down Expand Up @@ -480,7 +482,8 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv(
device_classical_spmv<subwarp_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
return is_zero(beta_val) ? alpha_val * x
: alpha_val * x + beta_val * y;
});
}

Expand Down
5 changes: 4 additions & 1 deletion common/cuda_hip/matrix/ell_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,10 @@ __global__ __launch_bounds__(default_block_size) void spmv(
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(
alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
is_zero(beta_val)
? alpha_val * x
: alpha_val * x +
static_cast<arithmetic_type>(beta_val * y));
});
}
}
Expand Down
4 changes: 3 additions & 1 deletion common/cuda_hip/matrix/sellp_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ __global__ __launch_bounds__(default_block_size) void advanced_spmv_kernel(
}
}
c[row * c_stride + column_id] =
beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
is_zero(beta[0])
? alpha[0] * val
: beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
}
}

Expand Down
3 changes: 2 additions & 1 deletion common/cuda_hip/matrix/sparsity_csr_kernels.hpp.inc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ __global__ __launch_bounds__(spmv_block_size) void abstract_classical_spmv(
device_classical_spmv<subwarp_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
return is_zero(beta_val) ? alpha_val * x
: alpha_val * x + beta_val * y;
});
}

Expand Down
28 changes: 22 additions & 6 deletions common/unified/matrix/dense_kernels.template.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,22 @@ void scale(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto x) {
x(row, col) *= alpha[col];
// Test the scaling factor itself (not a freshly built zero, which is
// always zero). A zero factor overwrites the entry rather than
// multiplying, so NaN/Inf already stored in x is not propagated.
if (is_zero(alpha[col])) {
    x(row, col) = zero(alpha[col]);
} else {
    x(row, col) *= alpha[col];
}
},
x->get_size(), alpha->get_const_values(), x);
} else {
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto x) {
x(row, col) *= alpha[0];
if (is_zero(alpha[0])) {
x(row, col) = zero(alpha[0]);
} else {
x(row, col) *= alpha[0];
}
},
x->get_size(), alpha->get_const_values(), x);
}
Expand Down Expand Up @@ -130,7 +138,9 @@ void add_scaled(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto x, auto y) {
y(row, col) += alpha[0] * x(row, col);
if (is_nonzero(alpha[0])) {
y(row, col) += alpha[0] * x(row, col);
}
},
x->get_size(), alpha->get_const_values(), x, y);
}
Expand All @@ -153,7 +163,9 @@ void sub_scaled(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto row, auto col, auto alpha, auto x, auto y) {
y(row, col) -= alpha[0] * x(row, col);
if (is_nonzero(alpha[0])) {
y(row, col) -= alpha[0] * x(row, col);
}
},
x->get_size(), alpha->get_const_values(), x, y);
}
Expand All @@ -170,7 +182,9 @@ void add_scaled_diag(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto i, auto alpha, auto diag, auto y) {
y(i, i) += alpha[0] * diag[i];
if (is_nonzero(alpha[0])) {
y(i, i) += alpha[0] * diag[i];
}
},
x->get_size()[0], alpha->get_const_values(), x->get_const_values(), y);
}
Expand All @@ -186,7 +200,9 @@ void sub_scaled_diag(std::shared_ptr<const DefaultExecutor> exec,
run_kernel(
exec,
[] GKO_KERNEL(auto i, auto alpha, auto diag, auto y) {
y(i, i) -= alpha[0] * diag[i];
if (is_nonzero(alpha[0])) {
y(i, i) -= alpha[0] * diag[i];
}
},
x->get_size()[0], alpha->get_const_values(), x->get_const_values(), y);
}
Expand Down
9 changes: 6 additions & 3 deletions dpcpp/matrix/csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,8 +490,10 @@ void abstract_merge_path_spmv(
merge_path_spmv<items_per_thread>(
num_rows, val, col_idxs, row_ptrs, srow, b, c, row_out, val_out,
[&alpha_val](const type& x) { return alpha_val * x; },
[&beta_val](const type& x) { return beta_val * x; }, item_ct1,
shared_row_ptrs);
[&beta_val](const type& x) {
return is_zero(beta_val) ? zero(beta_val) : beta_val * x;
},
item_ct1, shared_row_ptrs);
}

template <int items_per_thread, typename matrix_accessor,
Expand Down Expand Up @@ -713,7 +715,8 @@ void abstract_classical_spmv(
device_classical_spmv<subgroup_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
return is_zero(beta_val) ? alpha_val * x
: alpha_val * x + beta_val * y;
},
item_ct1);
}
Expand Down
5 changes: 4 additions & 1 deletion dpcpp/matrix/ell_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,10 @@ void spmv(
num_stored_elements_per_row, b, c, c_stride,
[&alpha_val, &beta_val](const auto& x, const OutputValueType& y) {
return static_cast<OutputValueType>(
alpha_val * x + static_cast<arithmetic_type>(beta_val * y));
is_zero(beta_val)
? alpha_val * x
: alpha_val * x +
static_cast<arithmetic_type>(beta_val * y));
},
item_ct1, storage);
}
Expand Down
4 changes: 3 additions & 1 deletion dpcpp/matrix/sellp_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,9 @@ void advanced_spmv_kernel(size_type num_rows, size_type num_right_hand_sides,
}
}
c[row * c_stride + column_id] =
beta[0] * c[row * c_stride + column_id] + alpha[0] * val;
is_zero(beta[0])
? alpha[0] * val
: alpha[0] * val + beta[0] * c[row * c_stride + column_id];
}
}

Expand Down
3 changes: 2 additions & 1 deletion dpcpp/matrix/sparsity_csr_kernels.dp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ void abstract_classical_spmv(
device_classical_spmv<subgroup_size>(
num_rows, val, col_idxs, row_ptrs, b, c,
[&alpha_val, &beta_val](const type& x, const type& y) {
return alpha_val * x + beta_val * y;
return is_zero(beta_val) ? alpha_val * x
: alpha_val * x + beta_val * y;
},
item_ct1);
}
Expand Down
2 changes: 1 addition & 1 deletion omp/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
#pragma omp parallel for
for (size_type row = 0; row < a->get_size()[0]; ++row) {
for (size_type j = 0; j < c->get_size()[1]; ++j) {
auto sum = c_vals(row, j) * vbeta;
auto sum = is_zero(vbeta) ? zero(vbeta) : c_vals(row, j) * vbeta;
for (size_type k = row_ptrs[row];
k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
arithmetic_type val = a_vals(k);
Expand Down
2 changes: 1 addition & 1 deletion omp/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ void apply(std::shared_ptr<const DefaultExecutor> exec,
#pragma omp parallel for
for (size_type row = 0; row < c->get_size()[0]; ++row) {
for (size_type col = 0; col < c->get_size()[1]; ++col) {
c->at(row, col) *= zero<ValueType>();
c->at(row, col) = zero<ValueType>();
}
}
}
Expand Down
4 changes: 3 additions & 1 deletion omp/matrix/ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
const auto alpha_val = arithmetic_type{alpha->at(0, 0)};
const auto beta_val = arithmetic_type{beta->at(0, 0)};
auto out = [&](auto i, auto j, auto value) {
return alpha_val * value + beta_val * arithmetic_type{c->at(i, j)};
return is_zero(beta_val) ? alpha_val * value
: alpha_val * value +
beta_val * arithmetic_type{c->at(i, j)};
};
if (num_rhs == 1) {
spmv_small_rhs<1>(exec, a, b, c, out);
Expand Down
6 changes: 5 additions & 1 deletion omp/matrix/fbcsr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
for (IndexType ibrow = 0; ibrow < nbrows; ++ibrow) {
for (IndexType row = ibrow * bs; row < (ibrow + 1) * bs; ++row) {
for (IndexType rhs = 0; rhs < nvecs; rhs++) {
c->at(row, rhs) *= vbeta;
if (is_zero(vbeta)) {
c->at(row, rhs) = zero(vbeta);
} else {
c->at(row, rhs) *= vbeta;
}
}
}
for (IndexType inz = row_ptrs[ibrow]; inz < row_ptrs[ibrow + 1];
Expand Down
3 changes: 2 additions & 1 deletion omp/matrix/sellp_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,8 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
const auto alpha_val = alpha->at(0, 0);
const auto beta_val = beta->at(0, 0);
auto out = [&](auto i, auto j, auto value) {
return alpha_val * value + beta_val * c->at(i, j);
return is_zero(beta_val) ? alpha_val * value
: alpha_val * value + beta_val * c->at(i, j);
};
if (num_rhs == 1) {
spmv_small_rhs<1>(exec, a, b, c, out);
Expand Down
4 changes: 3 additions & 1 deletion omp/matrix/sparsity_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,9 @@ void advanced_spmv(std::shared_ptr<const OmpExecutor> exec,
val * static_cast<arithmetic_type>(b->at(col_idxs[k], j));
}
c->at(row, j) = static_cast<OutputValueType>(
vbeta * static_cast<arithmetic_type>(c->at(row, j)) +
(is_zero(vbeta)
? zero(vbeta)
: vbeta * static_cast<arithmetic_type>(c->at(row, j))) +
valpha * temp_val);
}
}
Expand Down
2 changes: 1 addition & 1 deletion reference/matrix/csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,
auto c_vals = acc::helper::build_rrm_accessor<arithmetic_type>(c);
for (size_type row = 0; row < a->get_size()[0]; ++row) {
for (size_type j = 0; j < c->get_size()[1]; ++j) {
auto sum = c_vals(row, j) * vbeta;
auto sum = is_zero(vbeta) ? zero(vbeta) : c_vals(row, j) * vbeta;
for (size_type k = row_ptrs[row];
k < static_cast<size_type>(row_ptrs[row + 1]); ++k) {
arithmetic_type val = a_vals(k);
Expand Down
36 changes: 24 additions & 12 deletions reference/matrix/dense_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void apply(std::shared_ptr<const ReferenceExecutor> exec,
} else {
for (size_type row = 0; row < c->get_size()[0]; ++row) {
for (size_type col = 0; col < c->get_size()[1]; ++col) {
c->at(row, col) *= zero<ValueType>();
c->at(row, col) = zero<ValueType>();
}
}
}
Expand Down Expand Up @@ -133,7 +133,11 @@ void scale(std::shared_ptr<const ReferenceExecutor> exec,
if (alpha->get_size()[1] == 1) {
for (size_type i = 0; i < x->get_size()[0]; ++i) {
for (size_type j = 0; j < x->get_size()[1]; ++j) {
x->at(i, j) *= alpha->at(0, 0);
if (is_zero(alpha->at(0, 0))) {
x->at(i, j) = zero<ValueType>();
} else {
x->at(i, j) *= alpha->at(0, 0);
}
}
}
} else {
Expand Down Expand Up @@ -178,9 +182,11 @@ void add_scaled(std::shared_ptr<const ReferenceExecutor> exec,
const matrix::Dense<ValueType>* x, matrix::Dense<ValueType>* y)
{
if (alpha->get_size()[1] == 1) {
for (size_type i = 0; i < x->get_size()[0]; ++i) {
for (size_type j = 0; j < x->get_size()[1]; ++j) {
y->at(i, j) += alpha->at(0, 0) * x->at(i, j);
if (is_nonzero(alpha->at(0, 0))) {
for (size_type i = 0; i < x->get_size()[0]; ++i) {
for (size_type j = 0; j < x->get_size()[1]; ++j) {
y->at(i, j) += alpha->at(0, 0) * x->at(i, j);
}
}
}
} else {
Expand All @@ -202,9 +208,11 @@ void sub_scaled(std::shared_ptr<const ReferenceExecutor> exec,
const matrix::Dense<ValueType>* x, matrix::Dense<ValueType>* y)
{
if (alpha->get_size()[1] == 1) {
for (size_type i = 0; i < x->get_size()[0]; ++i) {
for (size_type j = 0; j < x->get_size()[1]; ++j) {
y->at(i, j) -= alpha->at(0, 0) * x->at(i, j);
if (is_nonzero(alpha->at(0, 0))) {
for (size_type i = 0; i < x->get_size()[0]; ++i) {
for (size_type j = 0; j < x->get_size()[1]; ++j) {
y->at(i, j) -= alpha->at(0, 0) * x->at(i, j);
}
}
}
} else {
Expand All @@ -227,8 +235,10 @@ void add_scaled_diag(std::shared_ptr<const ReferenceExecutor> exec,
matrix::Dense<ValueType>* y)
{
const auto diag_values = x->get_const_values();
for (size_type i = 0; i < x->get_size()[0]; i++) {
y->at(i, i) += alpha->at(0, 0) * diag_values[i];
if (is_nonzero(alpha->at(0, 0))) {
for (size_type i = 0; i < x->get_size()[0]; i++) {
y->at(i, i) += alpha->at(0, 0) * diag_values[i];
}
}
}

Expand All @@ -242,8 +252,10 @@ void sub_scaled_diag(std::shared_ptr<const ReferenceExecutor> exec,
matrix::Dense<ValueType>* y)
{
const auto diag_values = x->get_const_values();
for (size_type i = 0; i < x->get_size()[0]; i++) {
y->at(i, i) -= alpha->at(0, 0) * diag_values[i];
if (is_nonzero(alpha->at(0, 0))) {
for (size_type i = 0; i < x->get_size()[0]; i++) {
y->at(i, i) -= alpha->at(0, 0) * diag_values[i];
}
}
}

Expand Down
4 changes: 2 additions & 2 deletions reference/matrix/ell_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,

for (size_type j = 0; j < c->get_size()[1]; j++) {
for (size_type row = 0; row < a->get_size()[0]; row++) {
arithmetic_type result = c->at(row, j);
result *= beta_val;
arithmetic_type result =
is_zero(beta_val) ? zero(beta_val) : beta_val * c->at(row, j);
for (size_type i = 0; i < num_stored_elements_per_row; i++) {
arithmetic_type val = a_vals(row + i * stride);
auto col = a->col_at(row, i);
Expand Down
6 changes: 5 additions & 1 deletion reference/matrix/fbcsr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,11 @@ void advanced_spmv(const std::shared_ptr<const ReferenceExecutor>,
for (IndexType ibrow = 0; ibrow < nbrows; ++ibrow) {
for (IndexType row = ibrow * bs; row < (ibrow + 1) * bs; ++row) {
for (IndexType rhs = 0; rhs < nvecs; rhs++) {
c->at(row, rhs) *= vbeta;
if (is_zero(vbeta)) {
c->at(row, rhs) = zero(vbeta);
} else {
c->at(row, rhs) *= vbeta;
}
}
}

Expand Down
6 changes: 5 additions & 1 deletion reference/matrix/sellp_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,11 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,
break;
}
for (size_type j = 0; j < c->get_size()[1]; j++) {
c->at(global_row, j) *= vbeta;
if (is_nonzero(vbeta)) {
c->at(global_row, j) *= vbeta;
} else {
c->at(global_row, j) = zero<ValueType>();
}
}
for (size_type i = 0; i < slice_lengths[slice]; i++) {
auto val = a->val_at(row, slice_sets[slice], i);
Expand Down
4 changes: 3 additions & 1 deletion reference/matrix/sparsity_csr_kernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,9 @@ void advanced_spmv(std::shared_ptr<const ReferenceExecutor> exec,
val * static_cast<arithmetic_type>(b->at(col_idxs[k], j));
}
c->at(row, j) = static_cast<OutputValueType>(
vbeta * static_cast<arithmetic_type>(c->at(row, j)) +
(is_zero(vbeta)
? zero(vbeta)
: vbeta * static_cast<arithmetic_type>(c->at(row, j))) +
valpha * temp_val);
}
}
Expand Down
Loading

0 comments on commit f7c8214

Please sign in to comment.