Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making the shift parameter in shift an operator #234

Merged
merged 1 commit into from
Aug 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 28 additions & 30 deletions include/matx_tensor_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1692,44 +1692,36 @@ auto __MATX_INLINE__ reverse(Op t)
* of the tensor.
*/
namespace detail {
template <int DIM, typename T1>
class ShiftOp : public BaseOp<ShiftOp<DIM, T1>>
template <int DIM, typename T1, typename T2>
class ShiftOp : public BaseOp<ShiftOp<DIM, T1, T2>>
{
private:
typename base_type<T1>::type op_;
index_t shift_;
index_t base_;
T2 shift_;

public:
using matxop = bool;
using matxoplvalue = bool;
using scalar_type = typename T1::scalar_type;

__MATX_INLINE__ ShiftOp(T1 op, index_t shift) : op_(op), shift_(shift)
__MATX_INLINE__ ShiftOp(T1 op, T2 shift) : op_(op), shift_(shift)
{
static_assert(DIM < Rank(), "Dimension to shift must be less than rank of tensor");

if (shift < 0) {
while (-shift > Size(DIM)) {
shift += Size(DIM);
}

base_ = Size(DIM) + shift;
}
else {
while (shift > Size(DIM)) {
shift -= Size(DIM);
}

base_ = shift;
}
}

template <typename... Is>
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto operator()(Is... indices) const
{
auto tup = cuda::std::make_tuple(indices...);
cuda::std::get<DIM>(tup) = (base_ + cuda::std::get<DIM>(tup)) % Size(DIM);
auto shift = get_value(shift_, indices...);


shift = (shift + cuda::std::get<DIM>(tup)) % Size(DIM);

if(shift<0) shift+= Size(DIM);

cuda::std::get<DIM>(tup) = shift;

return mapply(op_, tup);
}

Expand All @@ -1752,22 +1744,25 @@ auto __MATX_INLINE__ reverse(Op t)
* @tparam DIM
* The dimension to be shifted
*
* @tparam Op
* @tparam OpT
* Type of operator or view
*
* @tparam ShiftOpT
* Type of the operator for the shift
*
* @param op
* Operator or view to shift
*
* @param s
* Amount to shift forward
* Operator which returns the shift
*
* @returns
* New operator with shifted indices
*/
template <int DIM, typename Op>
auto __MATX_INLINE__ shift(Op op, index_t s)
template <int DIM, typename OpT, typename ShiftOpT>
auto __MATX_INLINE__ shift(OpT op, ShiftOpT s)
{
return detail::ShiftOp<DIM, Op>(op, s);
return detail::ShiftOp<DIM, OpT, ShiftOpT>(op, s);
};


Expand All @@ -1781,9 +1776,12 @@ auto __MATX_INLINE__ reverse(Op t)
* @tparam DIMS...
* The dimensions targeted for shifts
*
* @tparam Op
* @tparam OpT
* Type of operator or view
*
* @tparam ShiftsT
* Type of the shift operators
*
* @param op
* Operator or view to shift
*
Expand All @@ -1793,16 +1791,16 @@ auto __MATX_INLINE__ reverse(Op t)
* @returns
* New operator with shifted indices
*/
template <int DIM, int... DIMS, typename Op, typename... Shifts>
auto __MATX_INLINE__ shift(Op op, index_t s, Shifts... shifts)
template <int DIM, int... DIMS, typename OpT, typename ShiftT, typename... ShiftsT>
auto __MATX_INLINE__ shift(OpT op, ShiftT s, ShiftsT... shifts)
{
static_assert(sizeof...(DIMS) == sizeof...(shifts), "shift: number of DIMs must match number of shifts");

// recursively call shift on remaining bits
auto rop = shift<DIMS...>(op, shifts...);

// construct shift op
return detail::ShiftOp<DIM, decltype(rop)>(rop, s);
return detail::ShiftOp<DIM, decltype(rop), decltype(s)>(rop, s);
};

namespace detail {
Expand Down
34 changes: 24 additions & 10 deletions test/00_operators/OperatorTests.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2171,14 +2171,16 @@ TYPED_TEST(OperatorTestsAll, RepMat)
MATX_EXIT_HANDLER();
}

TYPED_TEST(OperatorTestsNumeric, Shift)
TYPED_TEST(OperatorTestsNumeric, ShiftOp)
{
MATX_ENTER_HANDLER();
index_t count0 = 100;
index_t count1 = 201;
tensor_t<TypeParam, 2> t2({count0, count1});
tensor_t<TypeParam, 2> t2s({count0, count1});
tensor_t<TypeParam, 2> t2s2({count0, count1});
tensor_t<int, 0> t0;
t0() = 5;

for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
Expand All @@ -2192,7 +2194,19 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
EXPECT_TRUE(
ASSERT_TRUE(
MatXUtils::MatXTypeCompare(t2s(i, j), t2((i + 5) % count0, j)));
}
}
}

{
(t2s = shift<0>(t2, t0)).run();
cudaStreamSynchronize(0);

for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
ASSERT_TRUE(
MatXUtils::MatXTypeCompare(t2s(i, j), t2((i + 5) % count0, j)));
}
}
Expand All @@ -2204,7 +2218,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
EXPECT_TRUE(
ASSERT_TRUE(
MatXUtils::MatXTypeCompare(t2s(i, j), t2(i, (j + 5) % count1)));
}
}
Expand All @@ -2216,7 +2230,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
EXPECT_TRUE(MatXUtils::MatXTypeCompare(
ASSERT_TRUE(MatXUtils::MatXTypeCompare(
t2s(i, j), t2((i + 6) % count0, (j + 5) % count1)));
}
}
Expand All @@ -2228,7 +2242,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
EXPECT_TRUE(MatXUtils::MatXTypeCompare(
ASSERT_TRUE(MatXUtils::MatXTypeCompare(
t2s(i, j), t2((i + (count0 + 1) / 2) % count0,
(j + (count1 + 1) / 2) % count1)));
}
Expand All @@ -2241,7 +2255,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
EXPECT_TRUE(MatXUtils::MatXTypeCompare(
ASSERT_TRUE(MatXUtils::MatXTypeCompare(
t2s(i, j),
t2((i + (count0) / 2) % count0, (j + (count1) / 2) % count1)));
}
Expand All @@ -2256,7 +2270,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
index_t idim = i < 5 ? (t2.Size(0) - 5 + i) : (i - 5);
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t2s(i, j), t2(idim, j)));
ASSERT_TRUE(MatXUtils::MatXTypeCompare(t2s(i, j), t2(idim, j)));
}
}
}
Expand All @@ -2268,7 +2282,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)
for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
index_t jdim = j < 5 ? (t2.Size(1) - 5 + j) : (j - 5);
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t2s(i, j), t2(i, jdim)));
ASSERT_TRUE(MatXUtils::MatXTypeCompare(t2s(i, j), t2(i, jdim)));
}
}
}
Expand All @@ -2280,7 +2294,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t2s(i, j), t2(i, j)));
ASSERT_TRUE(MatXUtils::MatXTypeCompare(t2s(i, j), t2(i, j)));
}
}
}
Expand All @@ -2294,7 +2308,7 @@ TYPED_TEST(OperatorTestsNumeric, Shift)

for (index_t i = 0; i < count0; i++) {
for (index_t j = 0; j < count1; j++) {
EXPECT_TRUE(MatXUtils::MatXTypeCompare(t2s(i, j), t2s2(i, j)));
ASSERT_TRUE(MatXUtils::MatXTypeCompare(t2s(i, j), t2s2(i, j)));
}
}
}
Expand Down