Move layer norm to phi #40193

Merged: 26 commits, Mar 17, 2022
Changes from 14 commits
Commits (26)
6064b49
update
phlrain Mar 3, 2022
09a40fa
fix bugs; test=develop
phlrain Mar 6, 2022
3b26bf0
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 6, 2022
6f2833d
update; test=develop
phlrain Mar 6, 2022
4b1cb8a
fix test compile error; test=develop
phlrain Mar 6, 2022
13d4751
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 6, 2022
f533350
fix cpu compile error; test=develop
phlrain Mar 6, 2022
521f190
fix test error; test=develop
phlrain Mar 7, 2022
a1af264
fix layer_norm_op plugin error; test=develop
phlrain Mar 7, 2022
79e3f4b
fix error; test=develop
phlrain Mar 7, 2022
6ebdc83
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 7, 2022
9504426
fix test bug; test=develop
phlrain Mar 8, 2022
4fcad84
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 8, 2022
7bbbec0
Merge branch 'develop' into move_layer_norm_to_phi
phlrain Mar 10, 2022
334427c
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 11, 2022
88d15dc
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 11, 2022
09bf412
update; test=develop
phlrain Mar 11, 2022
63bc6b5
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 11, 2022
7dd3e9d
polish code; test=develop
phlrain Mar 11, 2022
39eb6ff
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 11, 2022
c527480
fix bugs; test=develop
phlrain Mar 12, 2022
7b3681a
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 12, 2022
e46bd87
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 15, 2022
72a18c9
remove unused dependency; test=develop
phlrain Mar 15, 2022
fa6c9e6
Merge branch 'develop' of /~https://github.com/PaddlePaddle/Paddle into…
phlrain Mar 15, 2022
871f90b
polish code; test=develop
phlrain Mar 16, 2022
2 changes: 1 addition & 1 deletion paddle/fluid/inference/tensorrt/convert/layer_norm_op.cc
@@ -11,7 +11,7 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"

#include "paddle/fluid/inference/tensorrt/convert/op_converter.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"

paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.cu
@@ -17,7 +17,7 @@
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/inference/tensorrt/plugin/layer_norm_op_plugin.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"

namespace paddle {
namespace inference {
@@ -83,7 +83,7 @@ int LayerNormPlugin::enqueue(int batch_size, const void *const *inputs,
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream);

paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
phi::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps);
return cudaGetLastError() != cudaSuccess;
@@ -177,7 +177,7 @@ int LayerNormPluginDynamic::enqueue(
cudaMemcpyAsync(bias_d, bias_.data(), sizeof(float) * feature_size,
cudaMemcpyHostToDevice, stream);

paddle::operators::LayerNormDirectCUDAFunctor<float> layer_norm;
phi::LayerNormDirectCUDAFunctor<float> layer_norm;
layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
variance_d, begin_norm_axis, eps);
} else {
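For reference, a minimal sketch (not part of this diff) of how the plugin-side call reads after the move: the functor comes from paddle/phi/kernels/layer_norm_kernel.h instead of the fluid operator header, with the argument order shown above. The wrapper function and buffer preparation are illustrative assumptions.

// Illustrative sketch only: the wrapper name and buffer setup are assumed,
// while the functor type and argument order match the plugin code above.
#include <vector>
#include <cuda_runtime.h>
#include "paddle/phi/kernels/layer_norm_kernel.h"

void RunLayerNormDirect(cudaStream_t stream, const float *input,
                        const std::vector<int> &input_shape, const float *bias_d,
                        const float *scale_d, float *output, float *mean_d,
                        float *variance_d, int begin_norm_axis, float eps) {
  // The functor now lives in the phi namespace rather than paddle::operators.
  phi::LayerNormDirectCUDAFunctor<float> layer_norm;
  layer_norm(stream, input, input_shape, bias_d, scale_d, output, mean_d,
             variance_d, begin_norm_axis, eps);
}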
24 changes: 15 additions & 9 deletions paddle/fluid/operators/fused/fused_dropout_test.h
@@ -25,14 +25,16 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"

namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace memory = paddle::memory;

USE_OP_ITSELF(dropout);
USE_OP(layer_norm);
USE_OP_ITSELF(layer_norm);

template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
@@ -136,18 +138,23 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,
const platform::CUDADeviceContext &ctx) {
framework::Scope scope;
auto place = ctx.GetPlace();
paddle::optional<const framework::LoDTensor &> scale_opt = paddle::none;
if (scale.size() > 0) {
auto var_scale = scope.Var("Scale");
auto tensor_scale = var_scale->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(scale, ctx, tensor_scale);
tensor_scale->Resize({cols});
scale_opt = *tensor_scale;
}

paddle::optional<const framework::LoDTensor &> bias_opt = paddle::none;
if (bias.size() > 0) {
auto var_bias = scope.Var("Bias");
auto tensor_bias = var_bias->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(bias, ctx, tensor_bias);
tensor_bias->Resize({cols});

bias_opt = *tensor_bias;
}

auto var_x = scope.Var("X");
@@ -157,20 +164,19 @@ void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,

auto var_y = scope.Var("Y");
auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
tensor_y->Resize({rows, cols});

auto var_mean = scope.Var("Mean");
auto tensor_mean = var_mean->GetMutable<framework::LoDTensor>();
tensor_mean->Resize({rows});

auto var_variance = scope.Var("Variance");
auto tensor_variance = var_variance->GetMutable<framework::LoDTensor>();

framework::AttributeMap attrs;
attrs.insert({"epsilon", epsilon});

auto op = framework::OpRegistry::CreateOp(
"layer_norm", {{"X", {"X"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}},
{{"Y", {"Y"}}, {"Mean", {"Mean"}}, {"Variance", {"Variance"}}}, attrs);
op->Run(scope, place);
tensor_variance->Resize({rows});
ctx.Wait();
phi::LayerNormKernel<T>(static_cast<const phi::GPUContext &>(ctx), *tensor_x,
scale_opt, bias_opt, 1e-5, 1, false, tensor_y,
tensor_mean, tensor_variance);
framework::TensorToVector(*tensor_y, ctx, y);
framework::TensorToVector(*tensor_mean, ctx, means);
framework::TensorToVector(*tensor_variance, ctx, vars);
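The helper above no longer builds a layer_norm operator through OpRegistry; it calls the phi kernel directly. A condensed sketch of that pattern follows (illustrative, not the PR code): the helper name and null-pointer handling are assumptions, while the optional Scale/Bias handling and the argument values (epsilon 1e-5, begin_norm_axis 1, is_test false) mirror the test code in the diff.

// Condensed sketch of the direct phi kernel call used by the test; the helper
// name and the null-pointer handling of optional inputs are assumptions made
// for this example.
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"
#include "paddle/utils/optional.h"

namespace framework = paddle::framework;

template <typename T>
void RunLayerNormViaPhi(const phi::GPUContext &ctx, const framework::LoDTensor &x,
                        const framework::LoDTensor *scale,  // may be null
                        const framework::LoDTensor *bias,   // may be null
                        framework::LoDTensor *y, framework::LoDTensor *mean,
                        framework::LoDTensor *variance) {
  // Optional inputs travel as paddle::optional; paddle::none marks a missing
  // Scale or Bias, mirroring the scale_opt/bias_opt handling in the diff above.
  paddle::optional<const framework::LoDTensor &> scale_opt = paddle::none;
  paddle::optional<const framework::LoDTensor &> bias_opt = paddle::none;
  if (scale != nullptr) scale_opt = *scale;
  if (bias != nullptr) bias_opt = *bias;

  phi::LayerNormKernel<T>(ctx, x, scale_opt, bias_opt, /*epsilon=*/1e-5,
                          /*begin_norm_axis=*/1, /*is_test=*/false, y, mean,
                          variance);
}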
paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu
@@ -14,6 +14,7 @@ limitations under the License. */

#include <time.h>

#include <iostream>
Contributor:
The debug output below has been removed; this include should be removed here as well.

Collaborator (Author):
done

#include <random>
#include <vector>

@@ -192,7 +193,6 @@ struct TestFusedLayernormResidualDropoutBias {
residual_vec[i * cols + j] + out2[i * cols + j];
}
}

LayerNorm<T>(scale_vec, layernorm_bias_vec, correct_out, &correct_means,
&correct_vars, &correct_layernorm_out, epsilon, rows, cols,
*ctx);
@@ -264,6 +264,7 @@ struct TestFusedLayernormResidualDropoutBias {

template <typename T>
static void BaseTest(const bool is_fp16 = false) {
std::cerr << "1" << std::endl;
Contributor:
Remove this debug output.

Collaborator (Author):
done

const int rows = 16;
T default_diff = !is_fp16 ? static_cast<T>(1e-4) : static_cast<T>(1e-2);
for (auto cols : {16, 17}) {
17 changes: 9 additions & 8 deletions paddle/fluid/operators/layer_norm_kernel.cu.h
@@ -758,12 +758,14 @@ __global__ __launch_bounds__(THREADS_PER_CTA) void ln_bwd_1024_final_kernel(
*/
template <typename T, typename U, typename ScaleT = U,
typename MaskType = uint8_t>
void ln_bwd_1024_kernel_driver(
const platform::CUDADeviceContext &dev_ctx, const int rows, const int cols,
float epsilon, const T *x_ptr, const ScaleT *scale_ptr, const U *mean_ptr,
const U *var_ptr, const T *dout_ptr, T *dx_ptr, ScaleT *dscale_ptr,
ScaleT *dbias_ptr, const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0), T *d_dropout_src_ptr = nullptr) {
void ln_bwd_1024_kernel_driver(const phi::GPUContext &dev_ctx, const int rows,
const int cols, float epsilon, const T *x_ptr,
const ScaleT *scale_ptr, const U *mean_ptr,
const U *var_ptr, const T *dout_ptr, T *dx_ptr,
ScaleT *dscale_ptr, ScaleT *dbias_ptr,
const MaskType *mask_ptr = nullptr,
T factor = static_cast<T>(0),
T *d_dropout_src_ptr = nullptr) {
auto stream = dev_ctx.stream();
if (cols == 1024) {
// step-1: compute dx and reduced part results of dscale and dbias.
@@ -1334,8 +1336,7 @@ static void LayerNormBackward(
const U *mean, const U *var, T *d_x,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_scale,
LayerNormScaleBiasT<T, U, ScaleBiasWithSameTypeX> *d_bias, float epsilon,
int64_t batch_size, int64_t feature_size,
const platform::CUDADeviceContext &dev_ctx) {
int64_t batch_size, int64_t feature_size, const phi::GPUContext &dev_ctx) {
auto stream = dev_ctx.stream();
#ifdef __HIPCC__
const int kMaxBlockDim = 256;
10 changes: 1 addition & 9 deletions paddle/fluid/operators/layer_norm_op.cc
@@ -12,10 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/layer_norm_op.h"

#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"

#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
@@ -278,10 +277,3 @@ REGISTER_OPERATOR(layer_norm, ops::LayerNormOp, ops::LayerNormOpMaker,
ops::LayerNormGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(layer_norm_grad, ops::LayerNormGradOp,
ops::LayerNormGradNoNeedBufferVarInferer);
REGISTER_OP_CPU_KERNEL(
layer_norm, ops::LayerNormKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CPUDeviceContext, double>);
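The REGISTER_OP_CPU_KERNEL entries above are removed because the kernels are now registered on the phi side. The phi kernel sources are not part of this excerpt; as a rough sketch under that assumption, a phi CPU registration for layer_norm typically takes the following form.

// Sketch of a phi-style kernel registration (the concrete registration added
// by this PR lives in the phi kernel sources, which are not shown here).
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/layer_norm_kernel.h"

PD_REGISTER_KERNEL(layer_norm,            // kernel name
                   CPU,                   // backend
                   ALL_LAYOUT,            // data layout
                   phi::LayerNormKernel,  // kernel function
                   float,
                   double) {}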