Skip to content

Commit

Permalink
Add abs2() operator for squared abs() (#568)
Browse files Browse the repository at this point in the history
* Add abs2() operator for squared abs()

Add abs2() operator that is equivalent to abs()*abs() without
the extra loads and the unnecessary sqrt operators associated with
complex abs() calculations.
  • Loading branch information
tbensonatl authored Jan 22, 2024
1 parent 532f699 commit d9b045b
Show file tree
Hide file tree
Showing 4 changed files with 111 additions and 0 deletions.
20 changes: 20 additions & 0 deletions docs_input/api/math/misc/abs2.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
.. _abs2_func:

abs2
====

Squared absolute value. For complex numbers, this is the squared
complex magnitude, or real(t)\ :sup:`2` + imag(t)\ :sup:`2`. For real numbers,
this is equivalent to the squared value, or t\ :sup:`2`.

.. doxygenfunction:: abs2(Op t)

Examples
~~~~~~~~

.. literalinclude:: ../../../../test/00_operators/OperatorTests.cu
:language: cpp
:start-after: example-begin abs2-test-1
:end-before: example-end abs2-test-1
:dedent:

14 changes: 14 additions & 0 deletions include/matx/operators/scalar_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,20 @@ template <typename T> struct ExpjF {
};
template <typename T> using ExpjOp = UnOp<T, ExpjF<T>>;

template <typename T> struct Abs2F {
static __MATX_INLINE__ std::string str() { return "abs2"; }

static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto op(T v)
{
if constexpr (is_complex_v<T>) {
return v.real() * v.real() + v.imag() * v.imag();
}
else {
return v * v;
}
}
};
template <typename T> using Abs2Op = UnOp<T, Abs2F<T>>;

template <typename T> static __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto _internal_normcdf(T v1)
{
Expand Down
10 changes: 10 additions & 0 deletions include/matx/operators/unary_operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,15 @@ namespace matx
*/
Op abs(Op t) {}

/**
* Compute squared absolute value of every element in the tensor. For complex numbers
* this returns the squared magnitude, or real(t)^2 + imag(t)^2. For real numbers
* this returns the squared value, or t*t.
* @param t
* Tensor or operator input
*/
Op abs2(Op t) {}

/**
* Compute the sine of every element in the tensor
* @param t
Expand Down Expand Up @@ -379,6 +388,7 @@ namespace matx
#endif
DEFINE_UNARY_OP(norm, detail::NormOp);
DEFINE_UNARY_OP(abs, detail::AbsOp);
DEFINE_UNARY_OP(abs2, detail::Abs2Op);
DEFINE_UNARY_OP(sin, detail::SinOp);
DEFINE_UNARY_OP(cos, detail::CosOp);
DEFINE_UNARY_OP(tan, detail::TanOp);
Expand Down
67 changes: 67 additions & 0 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1582,6 +1582,73 @@ TYPED_TEST(OperatorTestsAllExecs, OperatorFuncs)
MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsNumericAllExecs, Abs2)
{
MATX_ENTER_HANDLER();
using TestType = std::tuple_element_t<0, TypeParam>;
using ExecType = std::tuple_element_t<1, TypeParam>;
using inner_type = typename inner_op_type_t<TestType>::type;

ExecType exec{};

auto sync = [&exec]() constexpr {
if constexpr (std::is_same_v<ExecType,cudaExecutor>) {
cudaDeviceSynchronize();
}
};

if constexpr (std::is_same_v<TestType, cuda::std::complex<float>> &&
std::is_same_v<ExecType,cudaExecutor>) {
// example-begin abs2-test-1
auto x = make_tensor<cuda::std::complex<float>>({});
auto y = make_tensor<float>({});
x() = { 1.5f, 2.5f };
(y = abs2(x)).run();
cudaDeviceSynchronize();
ASSERT_NEAR(y(), 1.5f*1.5f+2.5f*2.5f, 1.0e-6);
// example-end abs2-test-1
}

auto x = make_tensor<TestType>({});
auto y = make_tensor<inner_type>({});
if constexpr (is_complex_v<TestType>) {
x() = TestType{2.0, 2.0};
(y = abs2(x)).run(exec);
sync();
ASSERT_NEAR(y(), 8.0, 1.0e-6);
} else {
x() = 2.0;
(y = abs2(x)).run(exec);
sync();
ASSERT_NEAR(y(), 4.0, 1.0e-6);

// Test with higher rank tensor
auto x3 = make_tensor<TestType>({3,3,3});
auto y3 = make_tensor<TestType>({3,3,3});
for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < 3; k++) {
x3(i,j,k) = static_cast<TestType>(i*9 + j*3 + k);
}
}
}

(y3 = abs2(x3)).run(exec);
sync();

for (int i = 0; i < 3; i++) {
for (int j = 0; j < 3; j++) {
for (int k = 0; k < 3; k++) {
TestType v = static_cast<TestType>(i*9 + j*3 + k);
ASSERT_NEAR(y3(i,j,k), v*v, 1.0e-6);
}
}
}
}

MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsFloatNonComplexAllExecs, OperatorFuncsR2C)
{
MATX_ENTER_HANDLER();
Expand Down

0 comments on commit d9b045b

Please sign in to comment.