Skip to content

Commit

Permalink
inital fixes for testing shift shapes for validity
Browse files Browse the repository at this point in the history
  • Loading branch information
tylera-nvidia authored and cliffburdick committed Jul 20, 2023
1 parent ee4aafb commit 7d2a18c
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions include/matx/operators/shift.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ namespace matx
__MATX_INLINE__ ShiftOp(T1 op, T2 shift) : op_(op), shift_(shift)
{
static_assert(DIM < Rank(), "Dimension to shift must be less than rank of tensor");
ASSERT_COMPATIBLE_OP_SIZES(shift_);
ASSERT_COMPATIBLE_OP_SIZES(op_);
}

template <typename... Is>
Expand Down Expand Up @@ -103,12 +105,14 @@ namespace matx

static __MATX_INLINE__ constexpr __MATX_HOST__ __MATX_DEVICE__ int32_t Rank()
{
return detail::get_rank<T1>();
return detail::matx_max(detail::get_rank<T1>(), detail::get_rank<T2>());
}

constexpr __MATX_INLINE__ __MATX_HOST__ __MATX_DEVICE__ auto Size(int dim) const noexcept
{
return op_.Size(dim);
index_t size1 = detail::get_expanded_size<Rank()>(op_, dim);
index_t size2 = detail::get_expanded_size<Rank()>(shift_, dim);
return detail::matx_max(size1,size2);
}

template<typename R> __MATX_INLINE__ auto operator=(const R &rhs) { return set(*this, rhs); }
Expand Down

0 comments on commit 7d2a18c

Please sign in to comment.