diff --git a/deepxde/model.py b/deepxde/model.py index 4ebdf6859..9ecd92a56 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -371,8 +371,6 @@ def closure(): def _compile_jax(self, lr, loss_fn, decay): """jax""" - if self.loss_weights is not None: - raise NotImplementedError("Loss weights are not supported for backend jax.") # Initialize the network's parameters if self.params is None: key = jax.random.PRNGKey(config.jax_random_seed) @@ -398,10 +396,12 @@ def outputs_fn(inputs): # We use aux so that self.data.losses is a pure function. aux = [outputs_fn, ext_params] if ext_params else [outputs_fn] losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux) - # TODO: Add regularization loss, weighted losses + # TODO: Add regularization loss if not isinstance(losses, list): losses = [losses] losses = jax.numpy.asarray(losses) + if self.loss_weights is not None: + losses *= jax.numpy.asarray(self.loss_weights) return outputs_, losses @jax.jit