Skip to content

Commit

Permalink
Fixed convolution mode SAME and added unit tests (#217)
Browse files Browse the repository at this point in the history
  • Loading branch information
cliffburdick authored Jul 17, 2022
1 parent 395dd25 commit 22c752c
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 22 deletions.
6 changes: 3 additions & 3 deletions include/kernels/matx_conv_kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -149,11 +149,11 @@ __global__ void Conv1D(OutType d_out, InType d_in, FilterType d_filter,
int start_tid, stop_tid;
if (filter_len & 1) {
start_tid = (filter_len - 1) >> 1;
stop_tid = signal_len - ((filter_len - 1) >> 1);
stop_tid = full_len - ((filter_len - 1) >> 1);
}
else {
start_tid = filter_len >> 1;
stop_tid = signal_len - (filter_len >> 1) + 1;
start_tid = (filter_len >> 1) - 1;
stop_tid = full_len - (filter_len >> 1) + 1;
}

if (tid >= start_tid && tid <= stop_tid) {
Expand Down
84 changes: 66 additions & 18 deletions test/00_transform/ConvCorr.cu
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,11 @@
using namespace matx;

constexpr index_t a_len = 256;
constexpr index_t b_len = 16;
constexpr index_t c_len = a_len + b_len - 1;
constexpr index_t b_len_even = 16;
constexpr index_t b_len_odd = 15;
constexpr index_t c_len_full_even = a_len + b_len_even - 1;
constexpr index_t c_len_full_odd = a_len + b_len_odd - 1;
constexpr index_t c_len_same = a_len;

template <typename T>
class CorrelationConvolutionTest : public ::testing::Test {
Expand All @@ -50,7 +53,6 @@ protected:
{
CheckTestTypeSupport<T>();
pb = std::make_unique<detail::MatXPybind>();
pb->InitTVGenerator<T>("00_transforms", "conv_operators", {a_len, b_len});

// Half precision needs a bit more tolerance when compared to
// fp32
Expand All @@ -63,8 +65,11 @@ protected:

std::unique_ptr<detail::MatXPybind> pb;
tensor_t<T, 1> av{{a_len}};
tensor_t<T, 1> bv{{b_len}};
tensor_t<T, 1> cv{{c_len}};
tensor_t<T, 1> bv_even{{b_len_even}};
tensor_t<T, 1> bv_odd{{b_len_odd}};
tensor_t<T, 1> cv_full_even{{c_len_full_even}};
tensor_t<T, 1> cv_same{{c_len_same}};
tensor_t<T, 1> cv_full_odd{{c_len_full_odd}};
float thresh = 0.01f;
};

Expand All @@ -76,51 +81,94 @@ class CorrelationConvolutionTestFloatTypes
TYPED_TEST_SUITE(CorrelationConvolutionTestFloatTypes, MatXFloatTypes);

// Real/real direct 1D convolution
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolution)
TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionFullEven)
{
MATX_ENTER_HANDLER();
this->pb->template InitTVGenerator<TypeParam>("00_transforms", "conv_operators", {a_len, b_len_even});
this->pb->RunTVGenerator("conv");
this->pb->NumpyToTensorView(this->av, "a_op");
this->pb->NumpyToTensorView(this->bv, "b_op");
conv1d(this->cv, this->av, this->bv, MATX_C_MODE_FULL, 0);
this->pb->NumpyToTensorView(this->bv_even, "b_op");
conv1d(this->cv_full_even, this->av, this->bv_even, MATX_C_MODE_FULL, 0);

MATX_TEST_ASSERT_COMPARE(this->pb, this->cv, "conv", this->thresh);
MATX_TEST_ASSERT_COMPARE(this->pb, this->cv_full_even, "conv_full", this->thresh);
MATX_EXIT_HANDLER();
}

TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionSameEven)
{
MATX_ENTER_HANDLER();
this->pb->template InitTVGenerator<TypeParam>("00_transforms", "conv_operators", {a_len, b_len_even});
this->pb->RunTVGenerator("conv");
this->pb->NumpyToTensorView(this->av, "a_op");
this->pb->NumpyToTensorView(this->bv_even, "b_op");
conv1d(this->cv_same, this->av, this->bv_even, MATX_C_MODE_SAME, 0);

MATX_TEST_ASSERT_COMPARE(this->pb, this->cv_same, "conv_same", this->thresh);
MATX_EXIT_HANDLER();
}

TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionFullOdd)
{
MATX_ENTER_HANDLER();
this->pb->template InitTVGenerator<TypeParam>("00_transforms", "conv_operators", {a_len, b_len_odd});
this->pb->RunTVGenerator("conv");
this->pb->NumpyToTensorView(this->av, "a_op");
this->pb->NumpyToTensorView(this->bv_odd, "b_op");
conv1d(this->cv_full_odd, this->av, this->bv_odd, MATX_C_MODE_FULL, 0);

MATX_TEST_ASSERT_COMPARE(this->pb, this->cv_full_odd, "conv_full", this->thresh);
MATX_EXIT_HANDLER();
}

TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionSameOdd)
{
MATX_ENTER_HANDLER();
this->pb->template InitTVGenerator<TypeParam>("00_transforms", "conv_operators", {a_len, b_len_odd});
this->pb->RunTVGenerator("conv");
this->pb->NumpyToTensorView(this->av, "a_op");
this->pb->NumpyToTensorView(this->bv_odd, "b_op");
conv1d(this->cv_same, this->av, this->bv_odd, MATX_C_MODE_SAME, 0);

MATX_TEST_ASSERT_COMPARE(this->pb, this->cv_same, "conv_same", this->thresh);
MATX_EXIT_HANDLER();
}

TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DConvolutionSwap)
{
MATX_ENTER_HANDLER();
this->pb->template InitTVGenerator<TypeParam>("00_transforms", "conv_operators", {a_len, b_len_even});
this->pb->RunTVGenerator("conv");
this->pb->NumpyToTensorView(this->av, "a_op");
this->pb->NumpyToTensorView(this->bv, "b_op");
conv1d(this->cv, this->bv, this->av, MATX_C_MODE_FULL, 0);
this->pb->NumpyToTensorView(this->bv_even, "b_op");
conv1d(this->cv_full_even, this->bv_even, this->av, MATX_C_MODE_FULL, 0);

MATX_TEST_ASSERT_COMPARE(this->pb, this->cv, "conv", this->thresh);
MATX_TEST_ASSERT_COMPARE(this->pb, this->cv_full_even, "conv_full", this->thresh);
MATX_EXIT_HANDLER();
}

TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DCorrelation)
{
MATX_ENTER_HANDLER();
this->pb->template InitTVGenerator<TypeParam>("00_transforms", "conv_operators", {a_len, b_len_even});
this->pb->RunTVGenerator("corr");
this->pb->NumpyToTensorView(this->av, "a_op");
this->pb->NumpyToTensorView(this->bv, "b_op");
corr(this->cv, this->av, this->bv, MATX_C_MODE_FULL, MATX_C_METHOD_DIRECT, 0);
this->pb->NumpyToTensorView(this->bv_even, "b_op");
corr(this->cv_full_even, this->av, this->bv_even, MATX_C_MODE_FULL, MATX_C_METHOD_DIRECT, 0);

MATX_TEST_ASSERT_COMPARE(this->pb, this->cv, "corr", this->thresh);
MATX_TEST_ASSERT_COMPARE(this->pb, this->cv_full_even, "corr", this->thresh);
MATX_EXIT_HANDLER();
}

TYPED_TEST(CorrelationConvolutionTestFloatTypes, Direct1DCorrelationSwap)
{
MATX_ENTER_HANDLER();
this->pb->template InitTVGenerator<TypeParam>("00_transforms", "conv_operators", {a_len, b_len_even});
this->pb->RunTVGenerator("corr_swap");
this->pb->NumpyToTensorView(this->av, "a_op");
this->pb->NumpyToTensorView(this->bv, "b_op");
corr(this->cv, this->bv, this->av, MATX_C_MODE_FULL, MATX_C_METHOD_DIRECT, 0);
this->pb->NumpyToTensorView(this->bv_even, "b_op");
corr(this->cv_full_even, this->bv_even, this->av, MATX_C_MODE_FULL, MATX_C_METHOD_DIRECT, 0);

MATX_TEST_ASSERT_COMPARE(this->pb, this->cv, "corr_swap", this->thresh);
MATX_TEST_ASSERT_COMPARE(this->pb, this->cv_full_even, "corr_swap", this->thresh);
MATX_EXIT_HANDLER();
}

Expand Down
3 changes: 2 additions & 1 deletion test/test_vectors/generators/00_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ def __init__(self, dtype: str, size: List[int]):
}

def conv(self):
self.res['conv'] = np.convolve(self.a, self.b, 'full')
self.res['conv_full'] = np.convolve(self.a, self.b, 'full')
self.res['conv_same'] = np.convolve(self.a, self.b, 'same')
return self.res

def corr(self):
Expand Down

0 comments on commit 22c752c

Please sign in to comment.