From 88a41f3722a49c14950939e9d1a031898297a2bf Mon Sep 17 00:00:00 2001
From: Yonathan Efroni
Date: Tue, 4 Jun 2024 17:18:44 -0700
Subject: [PATCH] Fixes for APS integration

Summary:
Minor modifications for APS training:

1) Log the actor loss on each step (otherwise we get a tensorboard error).

2) Removed the assertion that the actor type is GaussianActorNetwork or
   VanillaContinuousActorNetwork. This allows us to use customized actors that
   may not be of these types (generally, the user should be able to specify the
   actor rather than having one forced on them). Note: other actor-critic
   methods do not have such an assertion, so it is better to remove it for now.

Reviewed By: danielrjiang

Differential Revision: D58140074

fbshipit-source-id: 02cf789b7581275328c9c8a5fd3a1da033308c20
---
 .../sequential_decision_making/implicit_q_learning.py | 10 ----------
 .../policy_learners/sequential_decision_making/td3.py |  4 +++-
 2 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
index 87ada9a8..31411924 100644
--- a/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
+++ b/pearl/policy_learners/sequential_decision_making/implicit_q_learning.py
@@ -125,16 +125,6 @@ def __init__(
         self._expectile = expectile
         self._is_action_continuous: bool = action_space.is_continuous
 
-        # TODO: create actor network interfaces for discrete and continuous actor networks
-        # and use the continuous one in this test.
-        if self._is_action_continuous:
-            torch._assert(
-                actor_network_type == GaussianActorNetwork
-                or actor_network_type == VanillaContinuousActorNetwork,
-                "continuous action space requires a deterministic or a stochastic actor which works"
-                "with continuous action spaces",
-            )
-
         self._temperature_advantage_weighted_regression = (
             temperature_advantage_weighted_regression
         )
diff --git a/pearl/policy_learners/sequential_decision_making/td3.py b/pearl/policy_learners/sequential_decision_making/td3.py
index 1fb25ea9..774d2317 100644
--- a/pearl/policy_learners/sequential_decision_making/td3.py
+++ b/pearl/policy_learners/sequential_decision_making/td3.py
@@ -94,6 +94,7 @@ def __init__(
         self._actor_update_noise = actor_update_noise
         self._actor_update_noise_clip = actor_update_noise_clip
         self._critic_update_count = 0
+        self._last_actor_loss: float = 0.0
 
     def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
         # The actor and the critic updates are arranged in the following way
@@ -108,7 +109,8 @@ def learn_batch(self, batch: TransitionBatch) -> Dict[str, Any]:
             # see ddpg base class for actor update details
             actor_loss = self._actor_loss(batch)
             actor_loss.backward(retain_graph=True)
-            report["actor_loss"] = actor_loss.item()
+            self._last_actor_loss = actor_loss.item()
+            report["actor_loss"] = self._last_actor_loss
             self._critic_optimizer.zero_grad()
             critic_loss = self._critic_loss(batch)  # critic update
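
Context for change (1): TD3 delays actor updates relative to critic updates, so on
steps without an actor update the returned report previously had no "actor_loss"
key, which trips up per-step tensorboard logging. Below is a minimal,
self-contained sketch of the caching pattern the patch applies. It reuses the
names `_last_actor_loss` and `report["actor_loss"]` from the diff; the class name,
`actor_update_freq`, and the loss helpers are illustrative stand-ins, not Pearl's
actual API.

    # Illustrative sketch only; not Pearl's implementation.
    from typing import Any, Dict


    class DelayedActorCriticSketch:
        def __init__(self, actor_update_freq: int = 2) -> None:
            self._actor_update_freq = actor_update_freq
            self._critic_update_count = 0
            # Cached so "actor_loss" is present in the report even on steps
            # where the actor is not updated.
            self._last_actor_loss: float = 0.0

        def learn_batch(self, batch: Any) -> Dict[str, Any]:
            report: Dict[str, Any] = {}
            self._critic_update_count += 1
            # Delayed policy update: the actor is only trained every
            # `actor_update_freq` critic updates.
            if self._critic_update_count % self._actor_update_freq == 0:
                actor_loss = self._compute_actor_loss(batch)  # stand-in helper
                self._last_actor_loss = float(actor_loss)
            # Reported unconditionally, so the logger sees the same keys every step.
            report["actor_loss"] = self._last_actor_loss
            report["critic_loss"] = self._compute_critic_loss(batch)  # stand-in helper
            return report

        def _compute_actor_loss(self, batch: Any) -> float:
            return 0.0  # placeholder for the real actor loss computation

        def _compute_critic_loss(self, batch: Any) -> float:
            return 0.0  # placeholder for the real critic loss computation

With this pattern, every call to learn_batch returns a dict with a stable set of
keys, so a per-step tensorboard writer never encounters a missing "actor_loss"
entry on critic-only update steps.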