Convolution optimization and unit tests (#255)
Fix lvalue semantics for collapse

Co-authored-by: jluitjens <jluitjens@nvidia.com>
luitjens authored Aug 25, 2022
1 parent f1e1397 commit c44a681
Showing 5 changed files with 237 additions and 15 deletions.
4 changes: 2 additions & 2 deletions bench/00_transform/conv.cu
@@ -4,7 +4,7 @@
using namespace matx;

using conv_types =
-    nvbench::type_list<cuda::std::complex<float>, cuda::std::complex<double>>;
+    nvbench::type_list<cuda::std::complex<float>, cuda::std::complex<double>, float, double>;

/* Convolution benchmarks */
template <typename ValueType>
@@ -44,4 +44,4 @@ void conv1d_2d_batch(nvbench::state &state,
state.exec(
[&out, &at, &bt](nvbench::launch &launch) { conv1d(out, at, bt, MATX_C_MODE_FULL, launch.get_stream()); });
}
-NVBENCH_BENCH_TYPES(conv1d_2d_batch, NVBENCH_TYPE_AXES(conv_types));
\ No newline at end of file
+NVBENCH_BENCH_TYPES(conv1d_2d_batch, NVBENCH_TYPE_AXES(conv_types));
19 changes: 14 additions & 5 deletions include/matx/kernels/conv.cuh
@@ -44,7 +44,7 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
using outtype_strip = typename OutType::scalar_type;
int chunk_idx = blockIdx.y;
int batch_idx = blockIdx.x;
-  index_t filter_len = d_filter.Size(Rank-1);
+  int32_t filter_len = d_filter.Size(Rank-1);

// All but the last dim will be populated
auto bdims = BlockToIdx(d_in, batch_idx, 1);
@@ -62,13 +62,13 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
if constexpr (std::alignment_of_v < intype_strip >>
std::alignment_of_v<ftype_strip>) {
s_data =
-        matx::detail::AlignAddr<intype_strip>((uint8_t *)&s_exch[static_cast<index_t>(
+        matx::detail::AlignAddr<intype_strip>((uint8_t *)&s_exch[static_cast<int32_t>(
filter_len * filt_size_adj)]); // Start data portion after 2x the
// filter to remove conditionals and
// multiply by 0
}
else {
-    s_data = reinterpret_cast<intype_strip *>(&s_exch[static_cast<index_t>(
+    s_data = reinterpret_cast<intype_strip *>(&s_exch[static_cast<int32_t>(
filter_len *
filt_size_adj)]); // Start data portion after 2x the filter to
// remove conditionals and multiply by 0
@@ -80,7 +80,7 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
// duplicate tids based on this formula, but not all threads write out to
// memory. Some are only there to fetch data, while others both fetch and
// compute output
-  const index_t tid =
+  const int32_t tid =
static_cast<index_t>(chunk_idx) * (blockDim.x - filter_len + 1) +
threadIdx.x;
int offset = tid - filter_len + 1;
@@ -89,7 +89,7 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,

// Zero out shared memory since it's used later to index into where we want
// 0-valued taps
-  for (index_t i = threadIdx.x; i < filter_len + blockDim.x; i += blockDim.x) {
+  for (int32_t i = threadIdx.x; i < filter_len + blockDim.x; i += blockDim.x) {
s_data[i] = 0.0;
}

@@ -135,10 +135,19 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
// data in shared memory for blockDim-filt_len+1 to operate on. The rest sit
// idle through this process.
if (tid < full_len && (threadIdx.x < blockDim.x - filter_len + 1)) {
+#if 0
#pragma unroll
for (index_t r = 0; r < filter_len; r++) {
val = val + s_filter[r] * s_data[threadIdx.x + filter_len - 1 - r];
}
+#else
+    s_data += threadIdx.x + filter_len - 1;
+    for (int32_t r = 0; r < filter_len; r++) {
+      val = val + s_filter[0] * s_data[0];
+      s_data--;
+      s_filter++;
+    }
+#endif

if (mode == MATX_C_MODE_FULL) {
bdims[Rank - 1] = tid;
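The new inner loop (the #else branch above, with the old indexed form kept under #if 0) computes the base offset threadIdx.x + filter_len - 1 once and then walks two pointers, so each tap costs two pointer bumps instead of a fresh index calculation. The same hunks narrow the kernel's loop and thread indices from MatX's index_t (64-bit by default) to int32_t, presumably to cut register pressure and address-arithmetic cost. A minimal standalone sketch of the pointer-walk transformation, assuming a hypothetical kernel fir_inner_loop whose data argument holds n + filter_len - 1 samples (illustrative only, not the MatX Conv1D kernel):

    __global__ void fir_inner_loop(const float *data, const float *filter,
                                   float *out, int filter_len, int n)
    {
      int tid = blockIdx.x * blockDim.x + threadIdx.x;
      if (tid >= n) return;

      float val = 0.0f;

      // Indexed form (the #if 0 branch): recompute
      // "tid + filter_len - 1 - r" on every iteration.
      //   for (int r = 0; r < filter_len; r++)
      //     val += filter[r] * data[tid + filter_len - 1 - r];

      // Pointer-walk form (the #else branch): hoist the index arithmetic
      // out of the loop and step two pointers instead.
      const float *d = data + tid + filter_len - 1;
      const float *f = filter;
      for (int r = 0; r < filter_len; r++) {
        val += f[0] * d[0];
        d--;  // walk backward through the data window
        f++;  // walk forward through the filter taps
      }
      out[tid] = val;
    }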
56 changes: 53 additions & 3 deletions include/matx/operators/collapse.h
@@ -50,7 +50,7 @@ namespace matx
using matxop = bool;
using scalar_type = typename T1::scalar_type;
using shape_type = typename T1::shape_type;
-      using matxlvalue = bool;
+      using matxoplvalue = bool;

__MATX_INLINE__ LCollapseOp(const T1 &op) : op_(op)
{
@@ -61,7 +61,7 @@ namespace matx
      // compute size of collapsed dimension
size_ = 1;

-      // Collapse right-most dims
+      // Collapse left-most dims
#pragma unroll
for(int i = 0 ; i <= DIM; i++) {
size_ *= op_.Size(i);
@@ -85,7 +85,32 @@
auto ind = in[0];
#pragma unroll
for(int i = 0; i <= DIM; i++) {
-        index_t d = DIM - i;
+        int d = DIM - i;
out[d] = ind % op_.Size(d);
ind /= op_.Size(d);
}

+      return mapply(op_, out);
+    }
+
+    template <typename... Is>
+    __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
+    {
+      // indices coming in
+      std::array<index_t, Rank()> in{indices...};  // index coming in
+      std::array<index_t, T1::Rank()> out;         // index going out
+
+#pragma unroll
+      for(int i = 1; i < Rank(); i++) {
+        // copy all but first input index into out array
+        out[DIM+i] = in[i];
+      }
+
+      // expand first input index into DIM indices
+      auto ind = in[0];
+#pragma unroll
+      for(int i = 0; i <= DIM; i++) {
+        int d = DIM - i;
+        out[d] = ind % op_.Size(d);
+        ind /= op_.Size(d);
+      }
@@ -168,6 +193,31 @@ namespace matx
std::array<index_t, Rank()> in{indices...}; // index coming in
std::array<index_t, T1::Rank()> out; // index going out

+#pragma unroll
+      for(int i = 0 ; i < Rank() - 1; i++) {
+        // copy all but last index into out array
+        out[i] = in[i];
+      }
+
+      // expand last index into DIM indices
+      auto ind = in[Rank() - 1];
+#pragma unroll
+      for(int i = 0; i <= DIM; i++) {
+        index_t d = T1::Rank() - 1 - i;
+        out[d] = ind % op_.Size(d);
+        ind /= op_.Size(d);
+      }
+
+      return mapply(op_, out);
+    }
+
+    template <typename... Is>
+    __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ auto& operator()(Is... indices)
+    {
+      // indices coming in
+      std::array<index_t, Rank()> in{indices...};  // index coming in
+      std::array<index_t, T1::Rank()> out;         // index going out
+
#pragma unroll
for(int i = 0 ; i < Rank() - 1; i++) {
// copy all but last index into out array
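Both collapse operators gain a non-const operator() overload returning auto&, so a collapsed view can appear on the left-hand side of an assignment; the corrected matxoplvalue alias is what advertises that capability to the rest of the library. The index math itself is unchanged: one collapsed index is peeled back into per-dimension indices by a mod/div walk over the collapsed extents. A standalone sketch of that walk for a left collapse, using a hypothetical shape whose first two dims are collapsed (illustrative, not MatX code):

    #include <array>
    #include <cstdio>

    int main() {
      // Hypothetical 3D shape; dims 0..DIM are collapsed into one index.
      const std::array<int, 3> sizes = {4, 5, 6};
      const int DIM = 1;

      int ind = 17;              // collapsed index in [0, 4*5)
      std::array<int, 3> out{};

      // Same mod/div walk as the operators above: peel dims from the
      // innermost collapsed dim (DIM) outward to dim 0.
      for (int i = 0; i <= DIM; i++) {
        int d = DIM - i;
        out[d] = ind % sizes[d];
        ind /= sizes[d];
      }

      printf("expanded: (%d, %d)\n", out[0], out[1]);  // 17 = 3*5 + 2 -> (3, 2)
      return 0;
    }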
2 changes: 1 addition & 1 deletion include/matx/transforms/conv.h
@@ -166,7 +166,7 @@ inline void conv1d_impl(OutputType &o, const In1Type &i1, const In2Type &i2,
* @param stream CUDA stream
*/
template <typename OutputType, typename In1Type, typename In2Type>
-inline void conv1d(OutputType &&o, const In1Type &i1, const In2Type &i2,
+inline void conv1d(OutputType o, const In1Type &i1, const In2Type &i2,
matxConvCorrMode_t mode, cudaStream_t stream) {
if constexpr ( In1Type::Rank() > In2Type::Rank() ) {
// broadcast i2 path. clone i2 across batches
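conv1d now takes its output by value rather than by forwarding reference. With OutputType &&o, an lvalue argument deduces OutputType as a reference type, which likely broke the collapse-based lvalue path this commit fixes; passing by value yields a clean type and binds both lvalues and temporaries such as lcollapse<2>(remap<1>(O1, idx)). The copy is cheap because MatX tensors and operators are shallow views over shared storage, so writes through the copy land in the same data. A rough sketch of the idea with a hypothetical View handle (not MatX code):

    #include <cstdio>
    #include <memory>

    struct View {                      // hypothetical lightweight view type
      std::shared_ptr<float[]> data;   // shared underlying storage
      float &operator()(int i) { return data[i]; }
    };

    // Taking the view by value still writes through to the shared buffer,
    // and a by-value parameter accepts both lvalues and temporaries.
    void fill(View v, float x) {
      for (int i = 0; i < 4; i++) v(i) = x;
    }

    int main() {
      View a{std::shared_ptr<float[]>(new float[4]{})};
      fill(a, 1.0f);              // lvalue argument
      fill(View{a.data}, 2.0f);   // temporary view of the same storage
      printf("%g\n", a(0));       // prints 2: both calls wrote through
      return 0;
    }

The AdvancedRemapOp test below exercises exactly this path by passing collapsed remap views directly as conv1d outputs.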
171 changes: 167 additions & 4 deletions test/00_operators/OperatorTests.cu
@@ -614,7 +614,7 @@ TYPED_TEST(OperatorTestsNumericNonComplex, CollapseOp)
for(int n = 0; n < N; n++) {
for(int m = 0; m < M; m++) {
for(int k = 0; k < K; k++) {
-        EXPECT_TRUE(tiv(n,m,k) == tov(n,m*K+k));
+        ASSERT_TRUE(tiv(n,m,k) == tov(n,m*K+k));
}
}
}
@@ -637,7 +637,7 @@ TYPED_TEST(OperatorTestsNumericNonComplex, CollapseOp)
for(int n = 0; n < N; n++) {
for(int m = 0; m < M; m++) {
for(int k = 0; k < K; k++) {
-        EXPECT_TRUE(tiv(n,m,k) == tov(n*M+m,k));
+        ASSERT_TRUE(tiv(n,m,k) == tov(n*M+m,k));
}
}
}
@@ -658,7 +658,7 @@ TYPED_TEST(OperatorTestsNumericNonComplex, CollapseOp)
for(int n = 0; n < N; n++) {
for(int m = 0; m < M; m++) {
for(int k = 0; k < K; k++) {
-        EXPECT_TRUE(tiv(n,m,k) == tov(n*M*K+m*K+k));
+        ASSERT_TRUE(tiv(n,m,k) == tov(n*M*K+m*K+k));
}
}
}
@@ -679,7 +679,7 @@ TYPED_TEST(OperatorTestsNumericNonComplex, CollapseOp)
for(int n = 0; n < N; n++) {
for(int m = 0; m < M; m++) {
for(int k = 0; k < K; k++) {
-        EXPECT_TRUE(tiv(n,m,k) == tov(n*M*K+m*K+k));
+        ASSERT_TRUE(tiv(n,m,k) == tov(n*M*K+m*K+k));
}
}
}
@@ -2618,6 +2618,169 @@ TEST(OperatorTests, Cast)
MATX_EXIT_HANDLER();
}

+TEST(OperatorTestsAdvanced, AdvancedRemapOp)
+{
+  typedef cuda::std::complex<float> complex;
+  MATX_ENTER_HANDLER();
+
+  int I = 4;
+  int J = 4;
+  int K = 14;
+  int L = 133;
+
+  int F = 4096;
+  int P = 288;
+
+  int M = 2;
+
+  auto idx = matx::make_tensor<int, 1>({M});
+
+  idx(0) = 1;
+  idx(1) = 3;
+
+  auto A = matx::make_tensor<complex, 4>({I, J, K, L});
+  // collapsed tensor
+  auto B = matx::make_tensor<complex, 2>({I * M * K, L});
+
+  auto index = [&] (int i, int j, int k, int l) {
+    return i * J * K * L +
+           j * K * L +
+           k * L +
+           l;
+  };
+  for (int i = 0; i < I ; i++) {
+    for (int j = 0; j < J ; j++) {
+      for (int k = 0; k < K ; k++) {
+        for (int l = 0; l < L ; l++) {
+          float val = (float)index(i,j,k,l);
+          A(i,j,k,l) = complex(val, val/100);
+        }
+      }
+    }
+  }

+  (B = 0).run();
+
+  auto rop = remap<1>(A, idx);
+  auto lop = lcollapse<2>(rop);
+
+  ASSERT_EQ(lop.Rank() , 2);
+  ASSERT_EQ(lop.Size(1) , A.Size(3));
+  ASSERT_EQ(lop.Size(0) , I * M * K);
+
+  (B = lop).run();
+
+  cudaDeviceSynchronize();
+
+  for (int i = 0; i < I; i++) {
+    for (int m = 0; m < M; m++) {
+      for (int k = 0; k < K; k++) {
+        for (int l = 0; l < L; l++) {
+          int j = idx(m);
+          int fidx = i * M * K + m * K + k;
+          float val = (float)index(i,j,k,l);
+          complex expected_val = complex(val,val/100);
+          complex a_val = A(i,j,k,l);
+          complex b_val = B(fidx, l);
+          complex lop_val = lop(fidx, l);
+          complex rop_val = rop(i, m, k, l);
+
+          // printf("fidx: %d, i: %d, j: %d, k: %d, l: %d, val: %f,%f\n", fidx, i, j, k, l, val, val/100);
+          // printf("a_val: %f, %f, rop_val: %f, %f, lop_val: %f, %f, b_val: %f, %f\n",
+          //        a_val.real(), a_val.imag(),
+          //        rop_val.real(), rop_val.imag(),
+          //        lop_val.real(), lop_val.imag(),
+          //        b_val.real(), b_val.imag());
+          ASSERT_EQ(a_val, expected_val);
+          ASSERT_EQ(rop_val, expected_val);
+          ASSERT_EQ(lop_val, expected_val);
+          ASSERT_EQ(b_val, expected_val);
+
+          ASSERT_EQ(B(fidx, l) , lop(fidx, l));
+        }
+      }
+    }
+  }
+

+  // convolution test
+  auto O1 = matx::make_tensor<complex, 4>({I, J, K, F + P + L - 1});
+  auto O2 = matx::make_tensor<complex, 4>({I, J, K, F + P + L - 1});
+  auto O3 = matx::make_tensor<complex, 4>({I, J, K, F + P + L - 1});
+  auto O4 = matx::make_tensor<complex, 4>({I, J, K, F + P + L - 1});
+
+  auto C = matx::make_tensor<complex, 3>({I, K, F + P});
+  // collapsed tensor
+  auto D = matx::make_tensor<complex, 2>({I * M * K, F + P});
+
+  auto indexc = [&] (int i, int j, int k) {
+    return i * C.Size(1) * C.Size(2) +
+           j * C.Size(2) +
+           k;
+  };
+
+  for (int i = 0; i < I ; i++) {
+    for (int j = 0; j < J ; j++) {
+      for (int k = 0; k < K ; k++) {
+        float val = (float) indexc(i,j,k);
+        C(i,j,k) = complex(val, val/100);
+      }
+    }
+  }
+
+  A.PrefetchDevice(0);
+  B.PrefetchDevice(0);
+  C.PrefetchDevice(0);
+  D.PrefetchDevice(0);
+  O1.PrefetchDevice(0);
+  O2.PrefetchDevice(0);
+  O3.PrefetchDevice(0);
+  O4.PrefetchDevice(0);
+
+  cudaDeviceSynchronize();

+  auto o1op = lcollapse<2>(remap<1>(O1, idx));
+  auto o2op = lcollapse<2>(remap<1>(O2, idx));
+  auto o3op = lcollapse<2>(remap<1>(O3, idx));
+  auto o4op = lcollapse<2>(remap<1>(O4, idx));
+
+  auto cop = C.Clone<4>({matxKeepDim, M, matxKeepDim, matxKeepDim});
+  auto rcop = lcollapse<2>(remap<1>(cop, idx));
+
+  (O1 = 1).run();
+  (O2 = 2).run();
+  (O3 = 3).run();
+  (O4 = 4).run();
+
+  (B = lop).run();
+  (D = rcop).run();
+
+  // two operators as input
+  matx::conv1d(o1op, lop, rcop, matx::matxConvCorrMode_t::MATX_C_MODE_FULL, 0);
+
+  // one tensor and one operator as input
+  matx::conv1d(o2op, B, rcop, matx::matxConvCorrMode_t::MATX_C_MODE_FULL, 0);
+
+  // one tensor and one operator as input
+  matx::conv1d(o3op, lop, D, matx::matxConvCorrMode_t::MATX_C_MODE_FULL, 0);
+
+  // two tensors as input
+  matx::conv1d(o4op, B, D, matx::matxConvCorrMode_t::MATX_C_MODE_FULL, 0);
+
+  cudaDeviceSynchronize();
+
+  for (int i = 0; i < o1op.Size(0); i++) {
+    for (int l = 0; l < o1op.Size(1); l++) {
+      ASSERT_EQ(o1op(i,l), o2op(i,l));
+      ASSERT_EQ(o2op(i,l), o3op(i,l));
+      ASSERT_EQ(o3op(i,l), o4op(i,l));
+    }
+  }
+
+  MATX_EXIT_HANDLER();
+}


TYPED_TEST(OperatorTestsFloat, Print)
{
MATX_ENTER_HANDLER();
