Skip to content

Commit

Permalink
Backend jax: Support loss weights feature (#1670)
Browse files Browse the repository at this point in the history
  • Loading branch information
bonneted authored Mar 7, 2024
1 parent 5a188cf commit 25fa474
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,8 +371,6 @@ def closure():

def _compile_jax(self, lr, loss_fn, decay):
"""jax"""
if self.loss_weights is not None:
raise NotImplementedError("Loss weights are not supported for backend jax.")
# Initialize the network's parameters
if self.params is None:
key = jax.random.PRNGKey(config.jax_random_seed)
Expand All @@ -398,10 +396,12 @@ def outputs_fn(inputs):
# We use aux so that self.data.losses is a pure function.
aux = [outputs_fn, ext_params] if ext_params else [outputs_fn]
losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
# TODO: Add regularization loss, weighted losses
# TODO: Add regularization loss
if not isinstance(losses, list):
losses = [losses]
losses = jax.numpy.asarray(losses)
if self.loss_weights is not None:
losses *= jax.numpy.asarray(self.loss_weights)
return outputs_, losses

@jax.jit
Expand Down

0 comments on commit 25fa474

Please sign in to comment.