add Stan activation (PaddlePaddle#398)
HydrogenSulfate authored Jun 19, 2023
1 parent b36daa5 commit b3aaf6a
Showing 2 changed files with 32 additions and 7 deletions.
22 changes: 22 additions & 0 deletions ppsci/arch/activation.py
@@ -19,6 +19,27 @@
import paddle
import paddle.nn.functional as F
from paddle import nn


class Stan(nn.Layer):
    """Self-scalable Tanh activation.

    Paper: https://arxiv.org/abs/2204.12589v1

    Args:
        out_features (int, optional): Size of the learnable `beta` parameter,
            i.e. the width of the layer this activation follows. Defaults to 1.
    """

    def __init__(self, out_features: int = 1):
        super().__init__()
        self.beta = self.create_parameter(
            shape=(out_features,),
            default_initializer=nn.initializer.Constant(1),
        )

    def forward(self, x):
        # TODO: broadcast beta to x.shape manually for now to avoid a backward error
        return F.tanh(x) * (1 + paddle.broadcast_to(self.beta, x.shape) * x)
        # return F.tanh(x) * (1 + self.beta * x)


class Swish(nn.Layer):
    def __init__(self, beta: float = 1.0):
        super().__init__()
@@ -67,6 +88,7 @@ def forward(self, x):
"swish": Swish(),
"tanh": nn.Tanh(),
"identity": nn.Identity(),
"stan": Stan,
}
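The mapping stores ready-made instances for stateless activations, but maps "stan" to the class itself so callers can size its learnable parameter per layer. A minimal sketch of the lookup `get_activation` presumably performs — the real function sits outside this diff, so its structure and error message here are assumptions:

def get_activation(act_name: str):
    """Hypothetical sketch: look up an activation by name in act_dict."""
    if act_name not in act_dict:
        raise ValueError(f"Activation {act_name} not found in act_dict")
    # For "stan" this returns the Stan class; the caller instantiates it,
    # e.g. get_activation("stan")(hidden_size).
    return act_dict[act_name]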


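Stan computes tanh(x) * (1 + beta * x) with beta learnable per feature and initialized to ones, so at initialization it behaves like tanh(x) * (1 + x). A quick sanity check of the new class (tensor shapes here are illustrative):

import paddle

from ppsci.arch import activation as act_mod

# "stan" maps to the class, so instantiate it with the layer width.
stan = act_mod.get_activation("stan")(64)
x = paddle.randn([8, 64])
y = stan(x)  # tanh(x) * (1 + beta * x), beta initialized to ones
print(y.shape)  # [8, 64]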
17 changes: 10 additions & 7 deletions ppsci/arch/mlp.py
@@ -82,6 +82,7 @@ def __init__(
        self.input_keys = input_keys
        self.output_keys = output_keys
        self.linears = []
        self.acts = []
        if isinstance(hidden_size, (tuple, list)):
            if num_layers is not None:
                raise ValueError(
@@ -107,16 +108,18 @@
                if weight_norm
                else nn.Linear(cur_size, _size)
            )
            # initialize activation function; Stan gets a fresh instance sized
            # to this layer's width because its beta is a learnable parameter
            self.acts.append(
                act_mod.get_activation(activation)
                if activation != "stan"
                else act_mod.get_activation(activation)(_size)
            )
            cur_size = _size

        self.linears = nn.LayerList(self.linears)
        self.acts = nn.LayerList(self.acts)
        self.last_fc = nn.Linear(cur_size, len(self.output_keys))

        self.skip_connection = skip_connection

    def forward_tensor(self, x):
@@ -130,7 +133,7 @@ def forward_tensor(self, x):
                    y = y + skip
                else:
                    skip = y
            y = self.acts[i](y)

        y = self.last_fc(y)

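With this change each hidden layer of a Stan-activated MLP owns its own Stan instance sized to that layer's width, instead of sharing one activation module across layers. A usage sketch — the constructor signature is inferred from the parameters visible in this diff (input_keys, output_keys, num_layers, hidden_size, activation), so treat it as an assumption:

import paddle

import ppsci

model = ppsci.arch.MLP(
    input_keys=("x", "y"),
    output_keys=("u",),
    num_layers=5,
    hidden_size=128,
    activation="stan",  # each hidden layer gets its own Stan(hidden_size)
)
out = model({"x": paddle.randn([16, 1]), "y": paddle.randn([16, 1])})
print(out["u"].shape)  # [16, 1]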