From b3aaf6ace84758259b35148894b717e32a18d09c Mon Sep 17 00:00:00 2001
From: HydrogenSulfate <490868991@qq.com>
Date: Mon, 19 Jun 2023 22:07:47 +0800
Subject: [PATCH] add Stan activation (#398)

---
 ppsci/arch/activation.py | 22 ++++++++++++++++++++++
 ppsci/arch/mlp.py        | 17 ++++++++++-------
 2 files changed, 32 insertions(+), 7 deletions(-)

diff --git a/ppsci/arch/activation.py b/ppsci/arch/activation.py
index fa2a2e943..31d82a799 100644
--- a/ppsci/arch/activation.py
+++ b/ppsci/arch/activation.py
@@ -19,6 +19,27 @@
 from paddle import nn
 
 
+class Stan(nn.Layer):
+    """Self-scalable Tanh.
+    paper: https://arxiv.org/abs/2204.12589v1
+
+    Args:
+        out_features (int, optional): Output features. Defaults to 1.
+    """
+
+    def __init__(self, out_features: int = 1):
+        super().__init__()
+        self.beta = self.create_parameter(
+            shape=(out_features,),
+            default_initializer=nn.initializer.Constant(1),
+        )
+
+    def forward(self, x):
+        # TODO: manually broadcast beta to x.shape to avoid a backward error for now
+        return F.tanh(x) * (1 + paddle.broadcast_to(self.beta, x.shape) * x)
+        # return F.tanh(x) * (1 + self.beta * x)
+
+
 class Swish(nn.Layer):
     def __init__(self, beta: float = 1.0):
         super().__init__()
@@ -67,6 +88,7 @@ def forward(self, x):
     "swish": Swish(),
     "tanh": nn.Tanh(),
     "identity": nn.Identity(),
+    "stan": Stan,
 }
 
 
diff --git a/ppsci/arch/mlp.py b/ppsci/arch/mlp.py
index c450607cc..577cf1b28 100644
--- a/ppsci/arch/mlp.py
+++ b/ppsci/arch/mlp.py
@@ -82,6 +82,7 @@ def __init__(
         self.input_keys = input_keys
         self.output_keys = output_keys
         self.linears = []
+        self.acts = []
         if isinstance(hidden_size, (tuple, list)):
             if num_layers is not None:
                 raise ValueError(
@@ -107,16 +108,18 @@ def __init__(
                 if weight_norm
                 else nn.Linear(cur_size, _size)
             )
+            # initialize activation function
+            self.acts.append(
+                act_mod.get_activation(activation)
+                if activation != "stan"
+                else act_mod.get_activation(activation)(_size)
+            )
             cur_size = _size
-        self.linears = nn.LayerList(self.linears)
+        self.linears = nn.LayerList(self.linears)
+        self.acts = nn.LayerList(self.acts)
 
         self.last_fc = nn.Linear(cur_size, len(self.output_keys))
 
-        # initialize activation function
-        self.act = nn.LayerList(
-            [act_mod.get_activation(activation) for _ in range(len(hidden_size))]
-        )
-
         self.skip_connection = skip_connection
 
     def forward_tensor(self, x):
@@ -130,7 +133,7 @@ def forward_tensor(self, x):
                     y = y + skip
                 else:
                     skip = y
-            y = self.act[i](y)
+            y = self.acts[i](y)
 
         y = self.last_fc(y)
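
Note (illustrative, not part of the patch): the Stan activation added above computes stan(x) = tanh(x) * (1 + beta * x), where beta is a trainable per-feature scale initialized to 1. Because beta is sized to a layer's width, the activation dict stores the Stan class itself rather than a shared instance, and mlp.py constructs one Stan per hidden layer with that layer's out_features. Below is a minimal standalone sketch of the same idea; the tensor shapes and the check at the end are hypothetical and only for illustration.

import paddle
import paddle.nn.functional as F
from paddle import nn


class Stan(nn.Layer):
    """Self-scalable Tanh: stan(x) = tanh(x) * (1 + beta * x), beta trainable per feature."""

    def __init__(self, out_features: int = 1):
        super().__init__()
        # one trainable scale per output feature, initialized to 1
        self.beta = self.create_parameter(
            shape=(out_features,),
            default_initializer=nn.initializer.Constant(1),
        )

    def forward(self, x):
        # broadcast beta across the batch dimension explicitly, as in the patch
        return F.tanh(x) * (1 + paddle.broadcast_to(self.beta, x.shape) * x)


if __name__ == "__main__":
    x = paddle.randn([4, 16])      # hypothetical batch of 4 samples, 16 features
    act = Stan(out_features=16)    # beta has one entry per feature
    print(act(x).shape)            # expected: [4, 16]

With this patch applied, an MLP built through ppsci.arch should be able to select the new activation by name (activation="stan") and receive a per-layer Stan instance sized to each hidden layer's width; the remaining constructor arguments are whatever ppsci.arch.MLP already accepts.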