Commit: update

AnnaTrainingG committed Mar 1, 2023
1 parent c930994 commit f1799fb
Showing 14 changed files with 681 additions and 74 deletions.
9 changes: 9 additions & 0 deletions paddle/fluid/operators/optimizers/adagrad_op.cc
@@ -43,14 +43,23 @@ class AdagradOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Grad", "(Tensor) Input gradient");
AddInput("Moment", "(Tensor) Second moment");
AddInput("LearningRate", "(Tensor) Learning rate");
AddInput("MasterParam", "FP32 master weight for AMP.").AsDispensable();

AddOutput("ParamOut", "(Tensor) Output parameter");
AddOutput("MomentOut", "(Tensor) Output second moment");
AddOutput("MasterParamOut",
"The updated FP32 master weight for AMP. "
"It shared memory with Input(MasterParam).")
.AsDispensable();

AddAttr<float>("epsilon",
"(float, default 1.0e-6) "
"Constant for numerical stability")
.SetDefault(1.0e-6f);
AddAttr<bool>("multi_precision",
"(bool, default false) "
"Whether to use multi-precision during weight updating.")
.SetDefault(false);
AddComment(R"DOC(
Adaptive Gradient Algorithm (Adagrad).
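
For reference, the update rule the kernels in this diff implement (the multi-precision path added here performs the same arithmetic on an FP32 master copy of the parameter and writes a cast-down copy back to the low-precision parameter):

$$ m_t = m_{t-1} + g_t \odot g_t, \qquad \theta_t = \theta_{t-1} - \frac{\eta \, g_t}{\sqrt{m_t} + \epsilon} $$

where $g_t$ is the gradient, $m_t$ the accumulated second moment ("Moment"), $\eta$ the learning rate, and $\epsilon$ the stability constant set by the epsilon attribute.
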
4 changes: 3 additions & 1 deletion paddle/fluid/pybind/eager_generator.h
@@ -205,6 +205,7 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
{"sparse_attention",
{"Q", "K", "V", "Offset", "Columns", "KeyPaddingMask", "AttnMask"}},
{"sgd", {"Param", "LearningRate", "Grad", "MasterParam"}},
{"adagrad", {"Param", "Grad", "Moment", "LearningRate", "MasterParam"}},
{"graph_khop_sampler", {"Row", "Eids", "Col_Ptr", "X"}},
{"nce",
{"Input",
@@ -361,6 +362,7 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
"Beta2PowOut",
"MasterParamOut"}},
{"sgd", {"ParamOut", "MasterParamOut"}},
{"adagrad", {"ParamOut", "MomentOut", "MasterParamOut"}},
{"lamb",
{"ParamOut",
"Moment1Out",
@@ -399,7 +401,7 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
"MasterParamOut"}},
{"ftrl", {"ParamOut", "SquaredAccumOut", "LinearAccumOut"}},
{"adadelta", {"ParamOut", "AvgSquaredGradOut", "AvgSquaredUpdateOut"}},
{"adagrad", {"ParamOut", "MomentOut"}},
{"adagrad", {"ParamOut", "MomentOut", "MasterParamOut"}},
{"adamax", {"ParamOut", "MomentOut", "InfNormOut"}},
{"dpsgd", {"ParamOut"}},
{"decayed_adagrad", {"ParamOut", "MomentOut"}},
11 changes: 6 additions & 5 deletions paddle/phi/api/yaml/legacy_ops.yaml
@@ -29,15 +29,16 @@
inplace : (param -> param_out), (avg_squared_grad -> moment_out), (avg_squared_update -> inf_norm_out)

- op : adagrad_
args : (Tensor param, Tensor grad, Tensor moment, Tensor learning_rate, float epsilon)
output : Tensor(param_out), Tensor(moment_out)
args : (Tensor param, Tensor grad, Tensor moment, Tensor learning_rate, Tensor master_param, float epsilon, bool multi_precision)
output : Tensor(param_out), Tensor(moment_out), Tensor(master_param_out)
infer_meta :
func : AdagradInferMeta
kernel :
func : adagrad {dense, dense, dense, dense -> dense, dense}
adagrad_dense_param_sparse_grad {dense, selected_rows, dense, dense -> dense, dense}
func : adagrad {dense, dense, dense, dense, dense -> dense, dense, dense}
adagrad_dense_param_sparse_grad {dense, selected_rows, dense, dense, dense -> dense, dense, dense}
data_type : param
inplace : (param -> param_out), (moment -> moment_out)
optional : master_param
inplace : (param -> param_out), (moment -> moment_out), (master_param -> master_param_out)

- op : adam_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor moment1, Tensor moment2, Tensor beta1_pow, Tensor beta2_pow, Tensor master_param, Tensor skip_update, Scalar beta1, Scalar beta2, Scalar epsilon, bool lazy_mode, int64_t min_row_size_to_use_multithread, bool multi_precision, bool use_global_beta_pow)
5 changes: 4 additions & 1 deletion paddle/phi/infermeta/multiary.cc
@@ -74,9 +74,12 @@ void AdagradInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& moment,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
float epsilon,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* moment_out) {
MetaTensor* moment_out,
MetaTensor* master_param_out) {
auto lr_dims = learning_rate.dims();
PADDLE_ENFORCE_EQ(
phi::product(lr_dims),
5 changes: 4 additions & 1 deletion paddle/phi/infermeta/multiary.h
@@ -53,9 +53,12 @@ void AdagradInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& moment,
const MetaTensor& learning_rate,
const MetaTensor& master_param,
float epsilon,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* moment_out);
MetaTensor* moment_out,
MetaTensor* master_param_out);

void AdamaxInferMeta(const MetaTensor& param,
const MetaTensor& grad,
10 changes: 8 additions & 2 deletions paddle/phi/kernels/adagrad_kernel.h
@@ -25,18 +25,24 @@ void AdagradDenseKernel(const Context& dev_ctx,
const DenseTensor& grad,
const DenseTensor& moment,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment_out);
DenseTensor* moment_out,
DenseTensor* master_param_outs);

template <typename T, typename Context>
void AdagradSparseKernel(const Context& dev_ctx,
const DenseTensor& param,
const SelectedRows& grad,
const DenseTensor& moment,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment_out);
DenseTensor* moment_out,
DenseTensor* master_param_outs);

} // namespace phi
38 changes: 38 additions & 0 deletions paddle/phi/kernels/cpu/adagrad_kernel.cc
@@ -28,6 +28,42 @@ size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
}
} // namespace

template <typename T>
struct DenseAdagradFunctor<phi::CPUContext, T> {
void operator()(const phi::CPUContext& ctx,
const DenseTensor& param_t,
const DenseTensor& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs) {
ctx.template Alloc<T>(param_out_tensor);
ctx.template Alloc<T>(moment_out_tensor);

T epsilon = static_cast<T>(epsilon_t);

auto param = EigenVector<T>::Flatten(param_t);

auto grad = EigenVector<T>::Flatten(grad_t);

auto moment = EigenVector<T>::Flatten(moment_t);

auto param_out = EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = EigenVector<T>::Flatten(*moment_out_tensor);
auto place = *ctx.eigen_device();

// Adagrad: accumulate squared gradients, then step by
// lr / (sqrt(accumulator) + epsilon) in the negative gradient direction.
moment_out.device(place) = moment + grad * grad;
auto* lr = learning_rate.data<T>();
param_out.device(place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
}
};

template <typename T>
struct SparseAdagradFunctor<phi::CPUContext, T> {
void operator()(const phi::CPUContext& context,
@@ -67,6 +103,8 @@ struct SparseAdagradFunctor<phi::CPUContext, T> {

template struct SparseAdagradFunctor<phi::CPUContext, float>;
template struct SparseAdagradFunctor<phi::CPUContext, double>;
template struct DenseAdagradFunctor<phi::CPUContext, float>;
template struct DenseAdagradFunctor<phi::CPUContext, double>;

} // namespace phi

89 changes: 86 additions & 3 deletions paddle/phi/kernels/gpu/adagrad_kernel.cu
@@ -13,16 +13,91 @@
// limitations under the License.

#include "paddle/phi/kernels/adagrad_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/selected_rows_functor.h"
#include "paddle/phi/kernels/impl/adagrad_kernel_impl.h"

namespace phi {

template <typename T, typename MT>
__global__ void AdagradGPUKernel(const T* param,
const T* grad,
const MT* moment,
const MT* lr,
const MT* master_param,
MT epsilon,
T* param_out,
MT* moment_out,
MT* master_param_out,
int num) {
auto idx = blockDim.x * blockIdx.x + threadIdx.x;
MT lr_data = static_cast<MT>(lr[0]);

for (int i = idx; i < num; i += blockDim.x * gridDim.x) {
MT grad_data = static_cast<MT>(grad[i]);
MT moment_out_data = static_cast<MT>(moment[i]) + grad_data * grad_data;
moment_out[i] = static_cast<MT>(moment_out_data);
auto in = master_param_out ? master_param[i] : static_cast<MT>(param[i]);
MT param_out_data =
in - (lr_data * grad_data) / (sqrt(moment_out_data) + epsilon);

param_out[i] = static_cast<MT>(param_out_data);

if (master_param_out) {
master_param_out[i] = param_out_data;
}
}
}

template <typename T>
struct DenseAdagradFunctor<phi::GPUContext, T> {
void operator()(const phi::GPUContext& ctx,
const DenseTensor& param_t,
const DenseTensor& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs) {
using MPDType = typename phi::dtype::template MPTypeTrait<T>::Type;
T* param_out_data = ctx.template Alloc<T>(param_out_tensor);
MPDType* moment_out_data = ctx.template Alloc<MPDType>(moment_out_tensor);
const MPDType* master_in_data =
multi_precision ? master_param->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision ? ctx.template Alloc<MPDType>(master_param_outs)
: nullptr;

MPDType epsilon = static_cast<MPDType>(epsilon_t);

int numel = param_t.numel();
auto config = phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, 1);
int grid = config.block_per_grid.x;
int block = config.thread_per_block.x;
auto stream = ctx.stream();
AdagradGPUKernel<T, MPDType>
<<<grid, block, 0, stream>>>(param_t.data<T>(),
grad_t.data<T>(),
moment_t.data<MPDType>(),
learning_rate.data<MPDType>(),
master_in_data,
epsilon,
param_out_data,
moment_out_data,
master_out_data,
numel);
}
};

template <typename T, int block_size>
__global__ void MergeGradKernel(const T* grad,
const int64_t* grad_rows,
@@ -123,11 +198,19 @@ struct SparseAdagradFunctor<phi::GPUContext, T> {

template struct SparseAdagradFunctor<phi::GPUContext, float>;
template struct SparseAdagradFunctor<phi::GPUContext, double>;
template struct DenseAdagradFunctor<phi::GPUContext, float>;
template struct DenseAdagradFunctor<phi::GPUContext, double>;
template struct DenseAdagradFunctor<phi::GPUContext, phi::dtype::float16>;

} // namespace phi

PD_REGISTER_KERNEL(
adagrad, GPU, ALL_LAYOUT, phi::AdagradDenseKernel, float, double) {}
PD_REGISTER_KERNEL(adagrad,
GPU,
ALL_LAYOUT,
phi::AdagradDenseKernel,
float,
double,
phi::dtype::float16) {}

PD_REGISTER_KERNEL(adagrad_dense_param_sparse_grad,
GPU,
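
The GPU path above keeps all arithmetic in the master precision: the moment and master weight live in MT (FP32 when the parameter is float16), and the low-precision parameter is only a cast-down copy of the updated master weight. A minimal, standalone host-side sketch of that per-element logic (illustrative names only, float/double standing in for float16/float; this is not Paddle code):

#include <cmath>
#include <cstdio>
#include <vector>

// T  = low-precision parameter/gradient type (float16 in the real kernel)
// MT = master/accumulator type (float in the real kernel)
template <typename T, typename MT>
void AdagradStep(std::vector<T>& param, const std::vector<T>& grad,
                 std::vector<MT>& moment, std::vector<MT>& master_param,
                 MT lr, MT epsilon) {
  for (size_t i = 0; i < param.size(); ++i) {
    MT g = static_cast<MT>(grad[i]);
    moment[i] += g * g;  // accumulate squared gradients in master precision
    master_param[i] -= lr * g / (std::sqrt(moment[i]) + epsilon);
    param[i] = static_cast<T>(master_param[i]);  // cast master weight back down
  }
}

int main() {
  // float stands in for float16, double for the FP32 master type.
  std::vector<float> param = {1.0f, 2.0f};
  std::vector<float> grad = {0.5f, -0.25f};
  std::vector<double> moment = {0.0, 0.0};
  std::vector<double> master(param.begin(), param.end());
  AdagradStep(param, grad, moment, master, 0.1, 1e-6);
  std::printf("param = {%f, %f}\n", param[0], param[1]);
  return 0;
}

In the kernel itself the same loop is written as a grid-stride CUDA loop, and master_param_out aliases master_param in place, matching the inplace mapping declared in legacy_ops.yaml.
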
63 changes: 35 additions & 28 deletions paddle/phi/kernels/impl/adagrad_kernel_impl.h
@@ -30,6 +30,21 @@ struct SparseAdagradFunctor {
DenseTensor* param);
};

template <typename DeviceContext, typename T>
struct DenseAdagradFunctor {
void operator()(const DeviceContext& ctx,
const DenseTensor& param_t,
const DenseTensor& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs);
};

template <typename DeviceContext, typename T>
phi::SelectedRows SquareSelectedRows(const DeviceContext& context,
const phi::SelectedRows& input) {
@@ -50,35 +65,24 @@ void AdagradDenseKernel(const Context& ctx,
const DenseTensor& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor) {
ctx.template Alloc<T>(param_out_tensor);
ctx.template Alloc<T>(moment_out_tensor);

T epsilon = static_cast<T>(epsilon_t);

auto param = EigenVector<T>::Flatten(param_t);

auto grad = EigenVector<T>::Flatten(grad_t);

auto moment = EigenVector<T>::Flatten(moment_t);

auto param_out = EigenVector<T>::Flatten(*param_out_tensor);
auto moment_out = EigenVector<T>::Flatten(*moment_out_tensor);
auto place = *ctx.eigen_device();

moment_out.device(place) = moment + grad * grad;
Eigen::DSizes<int, 1> m_dsize(moment_out_tensor->numel());
if (paddle::platform::is_cpu_place(ctx.GetPlace())) {
auto* lr = learning_rate.data<T>();
param_out.device(place) =
param - lr[0] * grad / (moment_out.sqrt() + epsilon);
} else {
auto lr = EigenVector<T>::Flatten(learning_rate);
param_out.device(place) =
param - lr.broadcast(m_dsize) * grad / (moment_out.sqrt() + epsilon);
}
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs) {
DenseAdagradFunctor<Context, T> functor;
functor(ctx,
param_t,
grad_t,
moment_t,
learning_rate,
master_param,
epsilon_t,
multi_precision,
param_out_tensor,
moment_out_tensor,
master_param_outs);
}

template <typename T, typename Context>
@@ -87,9 +91,12 @@ void AdagradSparseKernel(const Context& ctx,
const SelectedRows& grad_t,
const DenseTensor& moment_t,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* moment_out) {
DenseTensor* moment_out,
DenseTensor* master_param_outs) {
auto* param_out_tensor = param_out;
auto* moment_out_tensor = moment_out;

5 changes: 4 additions & 1 deletion paddle/phi/kernels/xpu/adagrad_kernel.cc
@@ -24,9 +24,12 @@ void AdagradDenseKernel(const Context& ctx,
const DenseTensor& grad,
const DenseTensor& moment,
const DenseTensor& learning_rate,
const paddle::optional<DenseTensor>& master_param,
float epsilon_t,
bool multi_precision,
DenseTensor* param_out_tensor,
DenseTensor* moment_out_tensor) {
DenseTensor* moment_out_tensor,
DenseTensor* master_param_outs) {
ctx.template Alloc<T>(param_out_tensor);
ctx.template Alloc<T>(moment_out_tensor);
