Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[x86] add instance norm #5860

Merged
merged 1 commit into from
Apr 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lite/backends/x86/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ if(WITH_AVX AND AVX_FOUND)
math_library(conv_utils AVX2 TRUE)
math_library(conv_depthwise_pack8 AVX2 TRUE)
math_library(conv_depthwise_pack4 AVX2 TRUE)
math_library(instance_norm AVX2 TRUE)
endif()
math_library(im2col)
math_library(sample_prob)
Expand Down
173 changes: 173 additions & 0 deletions lite/backends/x86/math/instance_norm.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/backends/x86/math/instance_norm.h"
#include <immintrin.h>
#include <cmath>

namespace paddle {
namespace lite {
namespace x86 {
namespace math {

// Instance normalization over an NCHW tensor:
//   out = scale * (in - mean) / sqrt(variance + epsilon) + bias
// with mean/variance computed independently for each (n, c) plane over its
// height * width spatial elements.
//
// in:             input data, laid out [n, c, height, width]
// out:            output data, same shape as `in`
// n, c:           batch size and channel count
// height, width:  spatial dimensions
// epsilon:        stability term added to the variance before the sqrt
// scale, bias:    optional per-channel affine parameters (length c each);
//                 either may be nullptr
// saved_mean:     out-param, length n * c: per-plane mean
// saved_variance: out-param, length n * c. NOTE: this stores
//                 1 / sqrt(variance + epsilon) (the inverse stddev),
//                 not the raw variance.
void instance_norm(const float* in,
                   float* out,
                   const int n,
                   const int c,
                   const int height,
                   const int width,
                   const float epsilon,
                   const float* scale,
                   const float* bias,
                   float* saved_mean,
                   float* saved_variance) {
  int nc = n * c;
  int spatial_size = height * width;

// Pass 1: compute the per-plane mean and inverse stddev.
#pragma omp parallel for
  for (int i = 0; i < nc; ++i) {
    const float* in_p = in + i * spatial_size;
    float sum_spatial = 0.f;
    float summ_spatial = 0.f;
    for (int h = 0; h < height; ++h) {
      int w = width;
      // Four independent accumulators per quantity hide add/FMA latency.
      __m128 sum0 = _mm_set1_ps(0.f);
      __m128 sum1 = _mm_set1_ps(0.f);
      __m128 sum2 = _mm_set1_ps(0.f);
      __m128 sum3 = _mm_set1_ps(0.f);
      __m128 square_sum0 = _mm_set1_ps(0.f);
      __m128 square_sum1 = _mm_set1_ps(0.f);
      __m128 square_sum2 = _mm_set1_ps(0.f);
      __m128 square_sum3 = _mm_set1_ps(0.f);
      __m128 in0, in1, in2, in3;
      for (; w > 15; w -= 16) {
        in0 = _mm_loadu_ps(in_p);
        in1 = _mm_loadu_ps(in_p + 4);
        in2 = _mm_loadu_ps(in_p + 8);
        in3 = _mm_loadu_ps(in_p + 12);
        // accumulate x
        sum0 = _mm_add_ps(sum0, in0);
        sum1 = _mm_add_ps(sum1, in1);
        sum2 = _mm_add_ps(sum2, in2);
        sum3 = _mm_add_ps(sum3, in3);
        // accumulate x * x
        square_sum0 = _mm_fmadd_ps(in0, in0, square_sum0);
        square_sum1 = _mm_fmadd_ps(in1, in1, square_sum1);
        square_sum2 = _mm_fmadd_ps(in2, in2, square_sum2);
        square_sum3 = _mm_fmadd_ps(in3, in3, square_sum3);

        in_p += 16;
      }
      for (; w > 7; w -= 8) {
        in0 = _mm_loadu_ps(in_p);
        in1 = _mm_loadu_ps(in_p + 4);
        sum0 = _mm_add_ps(sum0, in0);
        sum1 = _mm_add_ps(sum1, in1);
        square_sum0 = _mm_fmadd_ps(in0, in0, square_sum0);
        square_sum1 = _mm_fmadd_ps(in1, in1, square_sum1);
        in_p += 8;
      }
      for (; w > 3; w -= 4) {
        in0 = _mm_loadu_ps(in_p);
        sum0 = _mm_add_ps(sum0, in0);
        square_sum0 = _mm_fmadd_ps(in0, in0, square_sum0);
        in_p += 4;
      }
      // Scalar tail: at most 3 remaining elements of the row.
      float sum = 0.f;
      float summ = 0.f;
      for (; w > 0; w--) {
        sum += *in_p;
        summ += (*in_p) * (*in_p);
        in_p++;
      }

      // Reduce the four vector accumulators pairwise, then horizontally:
      // after the two hadds, lane 0 holds the total sum and lane 1 the
      // total square sum.
      sum0 = _mm_add_ps(sum0, sum1);
      sum2 = _mm_add_ps(sum2, sum3);
      square_sum0 = _mm_add_ps(square_sum0, square_sum1);
      square_sum2 = _mm_add_ps(square_sum2, square_sum3);

      sum0 = _mm_add_ps(sum0, sum2);
      square_sum0 = _mm_add_ps(square_sum0, square_sum2);

      __m128 r = _mm_hadd_ps(sum0, square_sum0);
      r = _mm_hadd_ps(r, r);
      float buf[4];
      _mm_storeu_ps(buf, r);
      sum += buf[0];
      summ += buf[1];
      sum_spatial += sum;
      summ_spatial += summ;
    }
    float mean = sum_spatial / spatial_size;
    // float variance = summ_spatial / spatial_size - mean * mean;
    // The following form has higher precision than the commented one above.
    float variance = (summ_spatial - mean * mean * spatial_size) / spatial_size;
    float std = 1.f / sqrtf(variance + epsilon);

    saved_mean[i] = mean;
    saved_variance[i] = std;  // inverse stddev; see function comment
  }
// Pass 2: out = (in - mean) * (scale / std) + bias, vectorized 8/4/1-wide.
#pragma omp parallel for
  for (int i = 0; i < nc; ++i) {
    const float* in_p = in + i * spatial_size;
    float* out_p = out + i * spatial_size;
    int j = spatial_size;
    // Fold the per-channel scale into the inverse stddev once per plane.
    const float sstd_val =
        scale == nullptr ? saved_variance[i] : scale[i % c] * saved_variance[i];
    const float bias_val = bias == nullptr ? 0.f : bias[i % c];
    const float mean_val = saved_mean[i];
    const __m128 vsstd = _mm_set1_ps(sstd_val);
    const __m128 vbias = _mm_set1_ps(bias_val);
    const __m128 vmean = _mm_set1_ps(mean_val);
    __m128 in0, in1, submean0, submean1, out0, out1;

    for (; j > 7; j -= 8) {
      in0 = _mm_loadu_ps(in_p);
      in1 = _mm_loadu_ps(in_p + 4);
      submean0 = _mm_sub_ps(in0, vmean);
      submean1 = _mm_sub_ps(in1, vmean);
      out0 = _mm_fmadd_ps(submean0, vsstd, vbias);
      out1 = _mm_fmadd_ps(submean1, vsstd, vbias);

      _mm_storeu_ps(out_p, out0);
      _mm_storeu_ps(out_p + 4, out1);

      in_p += 8;
      out_p += 8;
    }
    for (; j > 3; j -= 4) {
      in0 = _mm_loadu_ps(in_p);
      submean0 = _mm_sub_ps(in0, vmean);
      out0 = _mm_fmadd_ps(submean0, vsstd, vbias);

      _mm_storeu_ps(out_p, out0);

      in_p += 4;
      out_p += 4;
    }
    // Scalar tail: at most 3 remaining elements.
    for (; j > 0; j--) {
      *out_p = (*in_p - mean_val) * sstd_val + bias_val;
      in_p++;
      out_p++;
    }
  }
}

} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
37 changes: 37 additions & 0 deletions lite/backends/x86/math/instance_norm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

namespace paddle {
namespace lite {
namespace x86 {
namespace math {

// Instance normalization over an NCHW tensor: normalizes each (n, c) plane
// over its height * width spatial elements, then applies the optional
// per-channel affine transform (scale, bias).
//
// in:             input data, [n, c, height, width]
// out:            output data, same shape as `in`
// scale, bias:    optional per-channel parameters (length c); may be nullptr
// saved_mean:     out-param, length n * c: per-plane mean
// saved_variance: out-param, length n * c. NOTE(review): the implementation
//                 stores 1 / sqrt(variance + epsilon) here, not the raw
//                 variance — keep callers consistent with that.
void instance_norm(const float* in,
float* out,
const int n,
const int c,
const int height,
const int width,
const float epsilon,
const float* scale,
const float* bias,
float* saved_mean,
float* saved_variance);

} // namespace math
} // namespace x86
} // namespace lite
} // namespace paddle
7 changes: 4 additions & 3 deletions lite/kernels/x86/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@ add_kernel(cast_compute_x86 X86 basic SRCS cast_compute.cc DEPS ${lite_kernel_de
add_kernel(slice_compute_x86 X86 basic SRCS slice_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reshape_compute_x86 X86 basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
if(WITH_AVX AND AVX_FOUND)
add_kernel(conv_depthwise_x86 X86 basic SRCS conv_depthwise.cc DEPS ${lite_kernel_deps} conv_utils conv_depthwise_pack8 conv_depthwise_pack4)
add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col conv_depthwise_x86 conv_bias)
add_kernel(conv_depthwise_x86 X86 basic SRCS conv_depthwise.cc DEPS ${lite_kernel_deps} conv_utils conv_depthwise_pack8 conv_depthwise_pack4)
add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col conv_depthwise_x86 conv_bias)
add_kernel(instance_norm_compute_x86 X86 basic SRCS instance_norm_compute.cc DEPS ${lite_kernel_deps} instance_norm)
else()
add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col conv_bias)
add_kernel(conv_compute_x86 X86 basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} blas im2col vol2col conv_bias)
endif()
# lite_cc_library(elementwise_compute_x86 SRCS elementwise_compute.cc DEPS ${lite_kernel_deps} elementwise_sub_op elementwise_add_op)
# lite_cc_library(softmax_compute_x86 SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
Expand Down
76 changes: 76 additions & 0 deletions lite/kernels/x86/instance_norm_compute.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/kernels/x86/instance_norm_compute.h"
#include <immintrin.h>
#include <cmath>
#include "lite/backends/x86/math/instance_norm.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

// No-op: this kernel precomputes nothing; all work happens in Run().
void InstanceNormCompute::PrepareForRun() {}

// Unpacks the op parameters and dispatches to the x86 math kernel, which
// normalizes each (n, c) plane of the NCHW input tensor.
void InstanceNormCompute::Run() {
  auto& param = this->Param<param_t>();

  // Input is expected as a 4-D NCHW tensor.
  const auto& x_dims = param.x->dims();
  const int batch = x_dims[0];
  const int channel = x_dims[1];
  const int h = x_dims[2];
  const int w = x_dims[3];

  const float* x_data = param.x->data<float>();
  // Scale and bias are optional; forward nullptr to the math kernel.
  const float* scale_data =
      param.scale == nullptr ? nullptr : param.scale->data<float>();
  const float* bias_data =
      param.bias == nullptr ? nullptr : param.bias->data<float>();
  float* y_data = param.out->mutable_data<float>();
  float* mean_data = param.saved_mean->mutable_data<float>();
  float* variance_data = param.saved_variance->mutable_data<float>();
  const float eps = param.epsilon;

  lite::x86::math::instance_norm(x_data,
                                 y_data,
                                 batch,
                                 channel,
                                 h,
                                 w,
                                 eps,
                                 scale_data,
                                 bias_data,
                                 mean_data,
                                 variance_data);
}

} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle

// Registers the float/NCHW x86 kernel for the `instance_norm` op.
REGISTER_LITE_KERNEL(instance_norm,
                     kX86,
                     kFloat,
                     kNCHW,
                     paddle::lite::kernels::x86::InstanceNormCompute,
                     def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindInput("Scale", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
    // TODO(review): also bind the MeanOut, VarianceOut and ReserveSpace
    // arguments, mirroring the arm backend's norm kernel registration —
    // confirm against the op definition before relying on this in graphs
    // that consume those outputs.
    .BindOutput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindOutput("SavedMean", {LiteType::GetTensorTy(TARGET(kX86))})
    .BindOutput("SavedVariance", {LiteType::GetTensorTy(TARGET(kX86))})
    .Finalize();
40 changes: 40 additions & 0 deletions lite/kernels/x86/instance_norm_compute.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {

// X86 float-precision kernel implementing the `instance_norm` operator:
// per-(n, c) plane normalization over the spatial dims of an NCHW input,
// delegating the math to lite::x86::math::instance_norm.
class InstanceNormCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::InstanceNormParam;

// No-op: nothing to precompute for this kernel.
void PrepareForRun() override;

// Reads param_t and runs instance normalization on the input tensor.
void Run() override;

virtual ~InstanceNormCompute() = default;

private:
};

} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
2 changes: 2 additions & 0 deletions lite/tests/kernels/instance_norm_compute_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ TEST(InstanceNorm, precision) {
ignored_outs = {"saved_mean", "saved_variance"};
#elif defined(LITE_WITH_ARM)
place = TARGET(kARM);
#elif defined(LITE_WITH_X86)
place = TARGET(kX86);
#else
return;
#endif
Expand Down