[PaddlePaddle hackathon] + ADD CELU (#36088)
* update

* update

* update

* try make CI pass

* doc typo

* update doc string
JunnYu authored Oct 13, 2021
1 parent 0c31579 commit d7064f0
Showing 11 changed files with 461 additions and 0 deletions.
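
Before the per-file diffs, a quick sketch of what the new operator computes and how it is driven from Python. The NumPy reference follows the formula in the CELUOpMaker doc string added below; the paddle call mirrors the new unit test in this commit and assumes paddle.nn.functional.celu is available once the Python-side files of this PR (not all rendered on this page) are in place.

import numpy as np
import paddle
import paddle.nn.functional as F

def celu_ref(x, alpha=1.0):
    # out = max(0, x) + min(0, alpha * (exp(x / alpha) - 1))
    return np.maximum(0.0, x) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))

x_np = np.random.uniform(-1, 1, [2, 4]).astype("float32")
out = F.celu(paddle.to_tensor(x_np), alpha=0.2)
np.testing.assert_allclose(out.numpy(), celu_ref(x_np, alpha=0.2), rtol=1e-5, atol=1e-6)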
74 changes: 74 additions & 0 deletions paddle/fluid/operators/activation_op.cc
@@ -560,6 +560,28 @@ Applies the following element-wise computation on the input according to
}
};

class CELUOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input is a multi-dimensional Tensor. The data type is "
"float32 or float64.");
AddOutput("Out",
"The output is a multi-dimensional Tensor which has same "
"dimension and data type as the ``x``.");
AddAttr<float>("alpha", "The alpha value of CELU").SetDefault(1.0f);
AddComment(R"DOC(
CELU Activation Operator.
Applies the following element-wise computation on the input according to
https://arxiv.org/abs/1704.07483.
$$out = \max(0, x) + \min(0, \alpha * (e^{x/\alpha} - 1))$$
)DOC");
}
};

class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
@@ -982,6 +1004,29 @@ class ELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
}
};

// celu grad: dx = dout if x > 0 else dout * exp(x / alpha)
// celu grad grad: ddout = ddx if x > 0 else ddx * exp(x / alpha)
//                 dx    = 0   if x > 0 else ddx * dout * exp(x / alpha) / alpha
template <typename T>
class CELUDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
public:
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("celu_grad_grad");

op->SetInput("X", this->Input("X"));
op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(this->Attrs());

// X@GRAD: dx
op->SetOutput("DX", this->InputGrad("X"));
// Out@GRAD@GRAD: ddy
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
}
};

// sqrt Grad: dx = 0.5 * dy / y
// sqrt GradGrad: ddy = 0.5 * ddx / y, dy = -1 * dx * ddx
template <typename T>
@@ -1353,6 +1398,35 @@ REGISTER_OP_CPU_KERNEL(

/* ========================================================================== */

/* ======================== celu register ============================ */
REGISTER_OPERATOR(
celu, ops::ActivationOp, ops::CELUOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::CELUGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
ops::ActFwdInplaceInferer);
REGISTER_OPERATOR(celu_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer,
ops::CELUDoubleGradMaker<paddle::framework::OpDesc>,
ops::CELUDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
celu_grad_grad,
ops::ActivationOpDoubleGrad<ops::CELUGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);

REGISTER_ACTIVATION_CPU_KERNEL(celu, CELU, CELUFunctor, CELUGradFunctor);
REGISTER_OP_CPU_KERNEL(
celu_grad_grad, ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<float>>,
ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<double>>,
ops::CELUDoubleGradKernel<plat::CPUDeviceContext,
ops::CELUGradGradFunctor<plat::float16>>);

/* ========================================================================== */

/* =========================== sqrt register ============================= */
REGISTER_OPERATOR(
sqrt, ops::ActivationOp, ops::SqrtOpMaker, ops::ActivationOpInferVarType,
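
The registration above wires up three operators: celu (forward), celu_grad (first derivative, generated by the generic ActivationGradOpMaker), and celu_grad_grad (second derivative, generated by CELUDoubleGradMaker). As a sanity check on the first-derivative formula stated in the comment above CELUDoubleGradMaker, a small NumPy sketch compares it against a central finite difference; this is illustrative only and not part of the commit.

import numpy as np

def celu(x, alpha):
    return np.maximum(0.0, x) + np.minimum(0.0, alpha * (np.exp(x / alpha) - 1.0))

def celu_grad(dout, x, alpha):
    # dx = dout if x > 0 else dout * exp(x / alpha)
    return np.where(x > 0, dout, dout * np.exp(x / alpha))

alpha, eps = 0.2, 1e-6
# keep x away from 0 so the finite difference never straddles the boundary
x = np.concatenate([np.random.uniform(-1.0, -0.01, 8),
                    np.random.uniform(0.01, 1.0, 8)])
numeric = (celu(x + eps, alpha) - celu(x - eps, alpha)) / (2 * eps)
np.testing.assert_allclose(celu_grad(np.ones_like(x), x, alpha), numeric, atol=1e-7)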
66 changes: 66 additions & 0 deletions paddle/fluid/operators/activation_op.cu
@@ -1202,6 +1202,59 @@ struct CudaELUGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CudaCELUFunctor : public BaseActivationFunctor<T> {
using CT = typename details::MPTypeTrait<T>::Type;
CT zero = static_cast<CT>(0.0f);
CT one = static_cast<CT>(1.0f);
float alpha;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}

// celu(x) = max(0, x) + min(0, alpha * (exp(x/alpha) - 1))
__device__ __forceinline__ T operator()(const T& arg_x) const {
CT x = static_cast<CT>(arg_x);
CT temp = static_cast<CT>(alpha) * (exp(x / static_cast<CT>(alpha)) - one);
CT res = (x > zero ? x : zero) + (temp > zero ? zero : temp);
return static_cast<T>(res);
}
};

template <typename T>
struct CudaCELUGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
MPType zero = static_cast<MPType>(0.0f);
MPType one = static_cast<MPType>(1.0f);
float alpha;

typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}

// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout, if alpha <= 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha <= 0 and x <= 0
__device__ __forceinline__ T operator()(const T& arg_dout,
const T& arg_x) const {
MPType dout = static_cast<MPType>(arg_dout);
MPType x = static_cast<MPType>(arg_x);
MPType a = static_cast<MPType>(alpha);
MPType temp_a_pos = static_cast<MPType>(alpha > 0.0f);
MPType temp_a_neg = static_cast<MPType>(alpha <= 0.0f);
MPType temp_x_pos = static_cast<MPType>(x > zero);
MPType temp_x_neg = static_cast<MPType>(x <= zero);
return static_cast<T>(
dout *
(temp_a_pos * temp_x_pos + temp_a_pos * temp_x_neg * exp(x / a) +
temp_a_neg * temp_x_pos + exp(x / a) * temp_a_neg * temp_x_neg));
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename DeviceContext, typename Functor>
class ActivationCudaKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -1341,6 +1394,19 @@ REGISTER_OP_CUDA_KERNEL(
ops::ELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* ======================== celu register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(celu, CELU, CudaCELUFunctor,
CudaCELUGradFunctor);

REGISTER_OP_CUDA_KERNEL(
celu_grad_grad, ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::CELUGradGradFunctor<float>>,
ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::CELUGradGradFunctor<double>>,
ops::CELUDoubleGradKernel<plat::CUDADeviceContext,
ops::CELUGradGradFunctor<plat::float16>>);
/* ========================================================================== */

/* =========================== relu register ============================ */
#ifdef PADDLE_WITH_HIP
REGISTER_ACTIVATION_CUDA_KERNEL(relu, Relu, CudaReluFunctor,
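
The CUDA backward functor above writes the piecewise derivative as a sum of masked terms (temp_a_pos, temp_a_neg, temp_x_pos, temp_x_neg) instead of branching per thread. A NumPy transcription of that expression, checked against the plain piecewise form; illustrative only.

import numpy as np

def cuda_celu_grad_ref(dout, x, alpha):
    # transcription of CudaCELUGradFunctor's masked-sum expression
    a_pos = float(alpha > 0.0)
    a_neg = float(alpha <= 0.0)
    x_pos = (x > 0).astype(x.dtype)
    x_neg = (x <= 0).astype(x.dtype)
    return dout * (a_pos * x_pos + a_pos * x_neg * np.exp(x / alpha) +
                   a_neg * x_pos + np.exp(x / alpha) * a_neg * x_neg)

alpha = 0.2
x = np.random.uniform(-1, 1, 16)
dout = np.random.uniform(-1, 1, 16)
piecewise = np.where(x > 0, dout, dout * np.exp(x / alpha))
np.testing.assert_allclose(cuda_celu_grad_ref(dout, x, alpha), piecewise)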
111 changes: 111 additions & 0 deletions paddle/fluid/operators/activation_op.h
@@ -1389,6 +1389,51 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CELUFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) =
(x < static_cast<T>(0))
.select(static_cast<T>(alpha) *
((x / static_cast<T>(alpha)).exp() - static_cast<T>(1)),
x);
}
};

template <typename T>
struct CELUGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp_a_pos = static_cast<T>(alpha > 0);
auto temp_a_neg = static_cast<T>(alpha <= 0);
auto temp_x_pos = (x > static_cast<T>(0)).template cast<T>();
auto temp_x_neg = (x <= static_cast<T>(0)).template cast<T>();

// dx = dout, if alpha > 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha > 0 and x <= 0
// dx = dout, if alpha <= 0 and x > 0
// dx = dout * (x/alpha).exp(), if alpha <= 0 and x <= 0
dx.device(d) =
dout * temp_a_pos * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_pos * temp_x_neg +
dout * temp_a_neg * temp_x_pos +
dout * (x / static_cast<T>(alpha)).exp() * temp_a_neg * temp_x_neg;
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

// FIXME(qijun) /~https://github.com/PaddlePaddle/Paddle/issues/5198
template <typename T>
struct PowFunctor : public BaseActivationFunctor<T> {
@@ -1775,6 +1820,45 @@ struct ELUGradGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct CELUGradGradFunctor : public BaseActivationFunctor<T> {
float alpha;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"alpha", &alpha}};
}
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* X,
const framework::Tensor* ddX, framework::Tensor* ddOut,
const framework::Tensor* dOut, framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "CELUGradGrad"));
auto x = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(X, "Input", "X", "CELUGradGrad"));

if (dX) {
auto dx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dX, "Output", "DX", "CELUGradGrad"));
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Output", "DOut", "CELUGradGrad"));
dx.device(*d) = ddx * dout / static_cast<T>(alpha) *
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>();
}

if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "CELUGradGrad"));
ddout.device(*d) = ddx *
((x > static_cast<T>(0)).template cast<T>() +
(x / static_cast<T>(alpha)).exp() *
(x <= static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
};

template <typename T>
struct SqrtGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
@@ -2107,6 +2191,33 @@ class ELUDoubleGradKernel
}
};

template <typename DeviceContext, typename Functor>
class CELUDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *X, *ddX, *dOut;
X = ddX = dOut = nullptr;
framework::Tensor *dX, *ddOut;
dX = ddOut = nullptr;

ExtractDoubleGradTensorWithInputDOut(ctx, &X, &ddX, &dX, &dOut, &ddOut);

if (dX) dX->mutable_data<T>(X->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(ctx.GetPlace());

auto& place = ctx.template device_context<DeviceContext>();

Functor functor;
auto attrs = functor.GetAttrs();
for (auto& attr : attrs) {
*attr.second = ctx.Attr<float>(attr.first);
}
functor(place, X, ddX, ddOut, dOut, dX);
}
};

template <typename DeviceContext, typename Functor>
class SqrtDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
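
CELUGradGradFunctor above implements the two second-order outputs: DDOut = ddx * celu'(x), with celu'(x) = 1 for x > 0 and exp(x / alpha) otherwise, and DX = ddx * dout * celu''(x), with celu''(x) = 0 for x > 0 and exp(x / alpha) / alpha otherwise. A NumPy sketch of those closed forms, with the second derivative checked against a finite difference of the first; illustrative only, not the kernel code.

import numpy as np

alpha, eps = 0.2, 1e-5
# keep x away from 0 so the finite difference of celu' does not straddle x = 0
x = np.concatenate([np.random.uniform(-1.0, -0.1, 8),
                    np.random.uniform(0.1, 1.0, 8)])
ddx = np.random.uniform(-1, 1, 16)   # forwarded gradient DDX
dout = np.random.uniform(-1, 1, 16)  # upstream gradient DOut

d1 = lambda v: np.where(v > 0, 1.0, np.exp(v / alpha))   # celu'(x)
d2 = np.where(x > 0, 0.0, np.exp(x / alpha) / alpha)     # celu''(x)

ddout = ddx * d1(x)    # mirrors the DDOut branch of the functor
dx = ddx * dout * d2   # mirrors the DX branch of the functor

numeric_d2 = (d1(x + eps) - d1(x - eps)) / (2 * eps)
np.testing.assert_allclose(d2, numeric_d2, atol=1e-4)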
27 changes: 27 additions & 0 deletions python/paddle/fluid/tests/unittests/test_activation_nn_grad.py
@@ -22,6 +22,7 @@
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker
import paddle.nn.functional as F

from decorator_helper import prog_scope

@@ -168,6 +169,32 @@ def test_grad(self):
self.func(p)


class TestCELUDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 4, 4, 4]
eps = 1e-6
alpha = 0.2
dtype = np.float64
SEED = 0

x = layers.data('x', shape, False, dtype)
x.persistable = True

y = F.celu(x, alpha=alpha)
np.random.seed(SEED)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)

def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)


class TestSqrtDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
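
The new test case plugs into the existing double-grad harness (gradient_checker.double_grad_check) on CPU and, when CUDA is compiled in, on GPU as well. A small sketch of running only this case standalone; it assumes an installed Paddle build and that the working directory is python/paddle/fluid/tests/unittests so the module imports resolve.

import unittest

# hypothetical standalone run of the new case added by this commit
from test_activation_nn_grad import TestCELUDoubleGradCheck

suite = unittest.TestLoader().loadTestsFromTestCase(TestCELUDoubleGradCheck)
unittest.TextTestRunner(verbosity=2).run(suite)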