Merge lars op (#35476)
* A first attempt at cudaLaunchCooperativeKernel (see the launch sketch after this list)

* fix bugs

* Totally replace the lars cuda kernel

* Fix bugs

* Add a test for lars merge

* Adding lars_momentum_op infer_shape

* Fix code

* use avg_numel instead of max_numel to determine the grid num (as in the sketch below)

* modify unittest files about lars op

* Finally converges when merged-lars works

* fix ctest files

* add merged_operation kernel when cuda version is older than 11

* Fix code style

* fix ctest failure

* fix error

* fix all ctest errors and change the lars compute code for cpu

* fix bugs on v100.

* revert python modification about lars

* revert python modification code
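The launch sketch referenced above: a minimal host-side illustration of a fused LARS update launched via cudaLaunchCooperativeKernel, with the grid size derived from the average numel of the merged parameters. This is not the actual Paddle code; MergedLarsKernel and LarsParamPack are hypothetical names.

```cuda
#include <cuda_runtime.h>

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical container of device pointers for all merged parameters.
struct LarsParamPack {
  float** params;
  float** grads;
  float** velocities;
  int64_t* numels;  // element count of each parameter (device copy)
  int op_num;       // number of merged parameters
};

// Device body omitted: a cooperative kernel can call
// cooperative_groups::this_grid().sync() between the norm reduction and the
// parameter update, which is why cudaLaunchCooperativeKernel is needed.
__global__ void MergedLarsKernel(LarsParamPack pack, float mu,
                                 float lars_coeff, float epsilon) {}

cudaError_t LaunchMergedLars(LarsParamPack pack,
                             const std::vector<int64_t>& host_numels,
                             float mu, float lars_coeff, float epsilon,
                             cudaStream_t stream) {
  constexpr int kBlockSize = 256;
  // Size the grid from the *average* numel, so one large parameter does not
  // force an oversized grid for every merged update.
  const int64_t total = std::accumulate(host_numels.begin(), host_numels.end(),
                                        static_cast<int64_t>(0));
  const int64_t avg_numel = total / static_cast<int64_t>(host_numels.size());
  int grid = static_cast<int>((avg_numel + kBlockSize - 1) / kBlockSize);
  grid = std::max(grid, 1);

  // A cooperative launch requires all blocks to be resident simultaneously,
  // so cap the grid by the device's occupancy limit.
  int dev = 0, sm_count = 0, blocks_per_sm = 0;
  cudaGetDevice(&dev);
  cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm,
                                                MergedLarsKernel, kBlockSize, 0);
  grid = std::min(grid, sm_count * blocks_per_sm);

  void* args[] = {&pack, &mu, &lars_coeff, &epsilon};
  return cudaLaunchCooperativeKernel((void*)MergedLarsKernel, dim3(grid),
                                     dim3(kBlockSize), args, 0, stream);
}
```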
JamesLim-sy authored Oct 13, 2021
1 parent 2441847 commit 0c31579
Showing 6 changed files with 594 additions and 302 deletions.
140 changes: 126 additions & 14 deletions paddle/fluid/operators/optimizers/lars_momentum_op.cc
@@ -13,46 +13,158 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/operators/optimizers/momentum_op.h"

namespace paddle {
namespace operators {

class LarsMomentumOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

protected:
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("Param"), "Input", "Param", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Grad"), "Input", "Grad", "LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("Velocity"), "Input", "Velocity",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasInputs("LearningRate"), "Input", "LearningRate",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("ParamOut"), "Output", "ParamOut",
"LarsMomentum");
OP_INOUT_CHECK(ctx->HasOutputs("VelocityOut"), "Output", "VelocityOut",
"LarsMomentum");
PADDLE_ENFORCE_EQ(
ctx->GetInputsVarType("Param").front(),
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The input var's type should be LoDTensor, but the received is %s",
ctx->GetInputsVarType("Param").front()));

auto lr_dims = ctx->GetInputsDim("LearningRate");
auto grad_dim = ctx->GetInputsDim("Grad");
auto param_dim = ctx->GetInputsDim("Param");
auto velocity_dim = ctx->GetInputsDim("Velocity");
auto lars_weight_decays =
ctx->Attrs().Get<std::vector<float>>("lars_weight_decay");
auto multi_precision = ctx->Attrs().Get<bool>("multi_precision");

PADDLE_ENFORCE_EQ(
param_dim.size(), grad_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) of LarsMomentumOp should have "
"same quantity. But number of Param is [%d] and Grad is [%d].",
param_dim.size(), grad_dim.size()));
PADDLE_ENFORCE_EQ(
param_dim.size(), velocity_dim.size(),
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp should "
"have same quantity. But number of Param is [%d] and Velocity "
"is [%d].",
param_dim.size(), velocity_dim.size()));
PADDLE_ENFORCE_EQ(
lars_weight_decays.size(), grad_dim.size(),
platform::errors::InvalidArgument(
"Attr(Lars_weight_decay) and "
"Input(Grad) of LarsMomentumOp should have same quantity. "
"But number of Lars_weight_decay is [%d] and Grad is [%d].",
lars_weight_decays.size(), grad_dim.size()));

if (multi_precision) {
OP_INOUT_CHECK(ctx->HasInputs("MasterParam"), "Input", "MasterParam",
"LarsMomentumMultiPrecision");
OP_INOUT_CHECK(ctx->HasOutputs("MasterParamOut"), "Output",
"MasterParamOut", "LarsMomentumMultiPrecision");
}
for (size_t i = 0; i < lr_dims.size(); ++i) {
PADDLE_ENFORCE_EQ(framework::product(lr_dims[i]), 1,
platform::errors::InvalidArgument(
"Learning_rate should be a scalar. But Received "
"LearningRate's dim [%s]",
framework::product(lr_dims[i])));
}

for (size_t i = 0; i < param_dim.size(); ++i) {
PADDLE_ENFORCE_EQ(ctx->GetInputsVarType("Grad")[i],
framework::proto::VarType::LOD_TENSOR,
platform::errors::InvalidArgument(
"The Var(%s)'s type should be LoDTensor, "
"but the received is %s",
ctx->Inputs("Grad")[i].front(),
ctx->GetInputsVarType("Grad")[i]));
PADDLE_ENFORCE_EQ(
param_dim[i], grad_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Grad) input of LarsMomentumOp shall "
"have same dimension. But Param`s dim is [%s] and Grad's dim "
"is [%s].",
param_dim[i], grad_dim[i]));
PADDLE_ENFORCE_EQ(
param_dim[i], velocity_dim[i],
platform::errors::InvalidArgument(
"Input(Param) and Input(Velocity) of LarsMomentumOp shall have "
"same dimension. But Param dim [%s] differs with Velocity dim "
"[%s].",
param_dim[i], velocity_dim[i]));
}
ctx->SetOutputsDim("ParamOut", param_dim);
ctx->SetOutputsDim("VelocityOut", param_dim);
if (ctx->HasOutputs("MasterParamOut")) {
ctx->SetOutputsDim("MasterParamOut", param_dim);
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto input_data_type =
OperatorWithKernel::IndicateVarDataType(ctx, "Param");
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
};

class LarsMomentumOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Param",
"(LoDTensor, default LoDTensor<float>) "
"Input parameter that has to be updated");
"Input parameter that has to be updated")
.AsDuplicable();
AddInput("Grad",
"(LoDTensor, default LoDTensor<float>) "
"Input gradient of the parameter");
"Input gradient of the parameter")
.AsDuplicable();
AddInput("Velocity",
"(LoDTensor, default LoDTensor<float>) "
"Input velocity (corresponding to the parameter) "
"that has to be updated");
"that has to be updated")
.AsDuplicable();
AddInput("LearningRate",
"(LoDTensor, default LoDTensor<float>) "
"Input learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();

"Input learning rate")
.AsDuplicable();
AddInput("MasterParam", "FP32 master weight for AMP.")
.AsDuplicable()
.AsDispensable();
AddOutput("ParamOut",
"(LoDTensor) This output is updated parameter. "
"It shared memory with Input(Param).");
"It shared memory with Input(Param).")
.AsDuplicable();
AddOutput("VelocityOut",
"(LoDTensor) This output is updated velocity. "
"It shared memory with Input(Velocity).");
"It shared memory with Input(Velocity).")
.AsDuplicable();
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDuplicable()
.AsDispensable();

AddAttr<float>("mu", "(float) Momentum coefficient");
AddAttr<float>("lars_coeff", "(float, default 0.001) LARS coefficient.")
.SetDefault(0.001);
AddAttr<float>("lars_weight_decay",
"(float, default 0.0005) LARS weight decay")
.SetDefault(0.0005);
AddAttr<std::vector<float>>(
"lars_weight_decay",
"(std::vector<float>, default 0.0005) LARS weight decay params")
.SetDefault({0.0005});
AddAttr<float>("epsilon",
"(float, default 0.0) epsilon to avoid Division by Zero.")
.SetDefault(0.0);
@@ -96,7 +96,7 @@ class LarsMomentumOpVarTypeInference : public framework::VarTypeInference {

namespace ops = paddle::operators;
REGISTER_OPERATOR(
-    lars_momentum, ops::MomentumOp, ops::LarsMomentumOpMaker,
+    lars_momentum, ops::LarsMomentumOp, ops::LarsMomentumOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
ops::LarsMomentumOpVarTypeInference);
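For reference, a minimal scalar sketch of the per-parameter update this operator performs, assuming the standard LARS momentum rule (the exact kernels in lars_momentum_op.h / .cu may differ in detail). With the merged form registered above, the i-th merged parameter uses lars_weight_decay[i] as its weight_decay:

```cuda
#include <cmath>
#include <cstddef>

// Hedged sketch of one LARS momentum step for a single parameter (CPU form).
// weight_decay is the entry of the lars_weight_decay vector attribute that
// corresponds to this parameter.
void LarsUpdate(float* param, float* velocity, const float* grad, size_t numel,
                float lr, float mu, float lars_coeff, float weight_decay,
                float epsilon) {
  double p_norm = 0.0, g_norm = 0.0;
  for (size_t i = 0; i < numel; ++i) {
    p_norm += static_cast<double>(param[i]) * param[i];
    g_norm += static_cast<double>(grad[i]) * grad[i];
  }
  p_norm = std::sqrt(p_norm);
  g_norm = std::sqrt(g_norm);

  // Trust-ratio based local learning rate; fall back to the global lr when
  // either norm is zero.
  float local_lr = lr;
  if (p_norm > 0.0 && g_norm > 0.0) {
    local_lr = static_cast<float>(
        lr * lars_coeff * p_norm / (g_norm + weight_decay * p_norm + epsilon));
  }
  for (size_t i = 0; i < numel; ++i) {
    velocity[i] =
        mu * velocity[i] + local_lr * (grad[i] + weight_decay * param[i]);
    param[i] -= velocity[i];
  }
}
```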
