Skip to content

Commit

Permalink
add unit test for biharmonic
Browse files Browse the repository at this point in the history
Signed-off-by: jjyaoao <jjyaoao@126.com>
  • Loading branch information
jjyaoao committed Jun 23, 2023
1 parent 3759ba3 commit 391b6d5
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions test/equation/test_biharmonic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import paddle
import pytest
from paddle import nn

from ppsci import equation

__all__ = []


@pytest.mark.parametrize("dim", (2, 3))
def test_biharmonic(dim):
"""Test for biharmonic equation."""
batch_size = 13
input_dims = ("x", "y", "z")[:dim]
output_dims = ("u",)

q = -1.0
D = 1.0

# generate input data
x = paddle.randn([batch_size, 1])
y = paddle.randn([batch_size, 1])
x.stop_gradient = False
y.stop_gradient = False
input_data = paddle.concat([x, y], axis=1)
if dim == 3:
z = paddle.randn([batch_size, 1])
z.stop_gradient = False
input_data = paddle.concat([x, y, z], axis=1)

# build NN model
model = nn.Sequential(
nn.Linear(len(input_dims), len(output_dims)),
nn.Tanh(),
)

# manually generate output
u = model(input_data)

# use self-defined jacobian and hessian
def jacobian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
return paddle.grad(y, x, create_graph=True)[0]

def hessian(y: "paddle.Tensor", x: "paddle.Tensor") -> "paddle.Tensor":
return jacobian(jacobian(y, x), x)

# compute expected result
expected_result = -q / D

# compute fourth order derivative
for var_i in (x, y):
for var_j in (x, y):
expected_result += hessian(hessian(u, var_i), var_j)
if dim == 3:
for var_i in (x, y, z):
for var_j in (x, y, z):
expected_result += hessian(hessian(u, var_i), var_j)

# compute result using built-in Biharmonic module
biharmonic_equation = equation.Biharmonic(dim=dim, q=q, D=D)
data_dict = {
"x": x,
"y": y,
"u": u,
}
if dim == 3:
data_dict["z"] = z
test_result = biharmonic_equation.equations["biharmonic"](data_dict)

# check result whether is equal
assert paddle.allclose(expected_result, test_result)


if __name__ == "__main__":
pytest.main()

0 comments on commit 391b6d5

Please sign in to comment.