Skip to content

Commit

Permalink
fix test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
shentanyue committed Jan 20, 2022
1 parent 40a4967 commit 627c42d
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 14 deletions.
20 changes: 10 additions & 10 deletions lite/kernels/host/unfold_compute.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ void UnfoldCompute<T, PType>::Run() {
lite::Tensor* output = param.Y;
auto input_dims = input->dims();
const int batch_size = static_cast<int>(input_dims[0]);
output->template mutable_data<T>();

std::vector<int> kernel_sizes = param.kernel_sizes;
std::vector<int> strides = param.strides;
Expand All @@ -101,12 +102,11 @@ void UnfoldCompute<T, PType>::Run() {
paddings[3],
strides[1]);

std::vector<int64_t> output_shape(
{input_dims[0],
input_dims[1] * kernel_sizes[0] * kernel_sizes[1],
output_height * output_width});
output->Resize(output_shape);
output->template mutable_data<T>();
// std::vector<int64_t> output_shape(
// {input_dims[0],
// input_dims[1] * kernel_sizes[0] * kernel_sizes[1],
// output_height * output_width});
// output->Resize(output_shape);

DDim input_shape({input_dims[1], input_dims[2], input_dims[3]});
DDim output_matrix_shape({input_dims[1],
Expand Down Expand Up @@ -136,15 +136,15 @@ REGISTER_LITE_KERNEL(unfold, kHost, kFloat, kNCHW, unfold_float, def)
.Finalize();

using unfold_int32 =
paddle::lite::kernels::host::UnfoldCompute<int, PRECISION(kInt32)>;
REGISTER_LITE_KERNEL(unfold, kHost, kInt32, kNCHW, unfold_int32, def_int32)
paddle::lite::kernels::host::UnfoldCompute<int, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(unfold, kHost, kFloat, kNCHW, unfold_int32, def_int32)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt32))})
.Finalize();

using unfold_int64 =
paddle::lite::kernels::host::UnfoldCompute<int64_t, PRECISION(kInt64)>;
REGISTER_LITE_KERNEL(unfold, kHost, kInt64, kNCHW, unfold_int64, def_int64)
paddle::lite::kernels::host::UnfoldCompute<int64_t, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(unfold, kHost, kFloat, kNCHW, unfold_int64, def_int64)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kHost), PRECISION(kInt64))})
.Finalize();
Expand Down
6 changes: 3 additions & 3 deletions lite/operators/unfold_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ namespace lite {
namespace operators {

bool UnfoldOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
CHECK(param_.X);
CHECK(param_.Y);

const auto x_dims = param_.X->dims();
CHECK_OR_FALSE(x_dims.size() == 4);
CHECK_EQ(x_dims.size(), 4);
return true;
}

Expand Down
2 changes: 1 addition & 1 deletion lite/tests/unittest_py/op/test_unfold_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def sample_program_configs(self, draw):
H = draw(st.integers(min_value=2, max_value=64))
W = draw(st.integers(min_value=2, max_value=64))
in_shape = draw(st.sampled_from([[N, C, H, W]]))
in_dtype = draw(st.sampled_from([np.float32]))
in_dtype = draw(st.sampled_from([np.float32, ]))

def generate_X_data():
return np.random.normal(0.0, 5.0, in_shape).astype(in_dtype)
Expand Down

0 comments on commit 627c42d

Please sign in to comment.