Add observation space to single item recommendation system tutorial
Summary: Add observation space to single item recommendation system tutorial

Reviewed By: jb3618columbia

Differential Revision: D56854463

fbshipit-source-id: 4eb24338f1b1cb881c20ceda6e39d6a944a74903
rodrigodesalvobraz authored and facebook-github-bot committed May 1, 2024
1 parent 33862b8 commit d1d4dad
Showing 3 changed files with 21 additions and 8 deletions.
1 change: 0 additions & 1 deletion pearl/api/environment.py
@@ -32,7 +32,6 @@ def action_space(self) -> ActionSpace:
"""Returns the action space of the environment."""
pass

# FIXME: add this and in implement in all concrete subclasses
@property
@abstractmethod
def observation_space(self) -> Space:
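With the FIXME comment gone, observation_space is now, like action_space, an abstract property that every concrete Environment subclass must implement. A minimal sketch of the resulting interface, reconstructed from the hunk above (the observation_space docstring and the ABC base are assumptions; the real file contains additional members not shown in this diff):

from abc import ABC, abstractmethod

from pearl.api.action_space import ActionSpace
from pearl.api.space import Space


class Environment(ABC):
    # Sketch only: the actual class declares further methods (e.g. reset) omitted here.

    @property
    @abstractmethod
    def action_space(self) -> ActionSpace:
        """Returns the action space of the environment."""
        pass

    @property
    @abstractmethod
    def observation_space(self) -> Space:
        """Returns the observation space of the environment."""
        pass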
7 changes: 7 additions & 0 deletions test/unit/test_tutorials/test_rec_system.py
@@ -19,6 +19,7 @@
from pearl.api.action_space import ActionSpace
from pearl.api.environment import Environment
from pearl.api.observation import Observation
from pearl.api.space import Space
from pearl.history_summarization_modules.lstm_history_summarization_module import (
LSTMHistorySummarizationModule,
)
@@ -40,6 +41,7 @@
)
from pearl.utils.functional_utils.experimentation.set_seed import set_seed
from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning
from pearl.utils.instantiations.spaces.box import BoxSpace
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

set_seed(0)
@@ -147,9 +149,14 @@ def __init__(
        self.state: torch.Tensor = torch.zeros((self.history_length, 100)).to(device)
        self._action_space: DiscreteActionSpace = DiscreteActionSpace(self.actions[0])

    @property
    def action_space(self) -> ActionSpace:
        return DiscreteActionSpace(self.actions[0])

    @property
    def observation_space(self) -> Space:
        return BoxSpace(low=torch.zeros((1,)), high=torch.ones((1,)))

    def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:
        self.state: torch.Tensor = torch.zeros((self.history_length, 100))
        self.t = 0
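Collected in one place, the pattern this commit adds to the tutorial environment: the candidate items form a DiscreteActionSpace, and the newly declared observation space is a one-element BoxSpace bounded in [0, 1], which satisfies the abstract property on Environment shown above. A hedged sketch of just the space declarations (the class name and the __init__ signature are placeholders; the two property bodies are taken verbatim from the hunk above, and the real class also defines reset and the rest of the environment dynamics):

import torch

from pearl.api.action_space import ActionSpace
from pearl.api.space import Space
from pearl.utils.instantiations.spaces.box import BoxSpace
from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace


class RecommenderEnv:
    # Placeholder class illustrating only the space declarations from the tutorial.

    def __init__(self, actions) -> None:
        # As in the tutorial, actions[0] is the candidate item set used to build the action space.
        self.actions = actions

    @property
    def action_space(self) -> ActionSpace:
        return DiscreteActionSpace(self.actions[0])

    @property
    def observation_space(self) -> Space:
        # Scalar observations bounded between 0 and 1.
        return BoxSpace(low=torch.zeros((1,)), high=torch.ones((1,)))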
14 changes: 14 additions & 7 deletions (single item recommendation system tutorial notebook)
@@ -36,7 +36,7 @@
"metadata": {
"id": "nFomZD4OjZLK"
},
"execution_count": 2,
"execution_count": null,
"outputs": []
},
{
@@ -66,7 +66,7 @@
"id": "5i2jE98RjhK1",
"outputId": "ad80e72d-51cb-4594-a97e-2f4cf5466667"
},
"execution_count": 3,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@@ -220,6 +220,7 @@
"from pearl.api.action_space import ActionSpace\n",
"from pearl.api.environment import Environment\n",
"from pearl.api.observation import Observation\n",
"from pearl.api.space import Space\n",
"from pearl.history_summarization_modules.lstm_history_summarization_module import (\n",
" LSTMHistorySummarizationModule,\n",
")\n",
@@ -241,6 +242,7 @@
")\n",
"from pearl.utils.functional_utils.experimentation.set_seed import set_seed\n",
"from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning\n",
"from pearl.utils.instantiations.spaces.box import BoxSpace\n",
"from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace\n",
"import matplotlib.pyplot as plt\n",
"\n",
@@ -249,7 +251,7 @@
"metadata": {
"id": "Lp6pRjTDjpDo"
},
"execution_count": 4,
"execution_count": null,
"outputs": []
},
{
@@ -345,9 +347,14 @@
" self.state: torch.Tensor = torch.zeros((self.history_length, 100)).to(device)\n",
" self._action_space: DiscreteActionSpace = DiscreteActionSpace(self.actions[0])\n",
"\n",
" @property\n",
" def action_space(self) -> ActionSpace:\n",
" return DiscreteActionSpace(self.actions[0])\n",
"\n",
" @property\n",
" def observation_space(self) -> Space:\n",
" return BoxSpace(low=torch.zeros((1,)), high=torch.ones((1,)))\n",
"\n",
" def reset(self, seed: Optional[int] = None) -> Tuple[Observation, ActionSpace]:\n",
" self.state: torch.Tensor = torch.zeros((self.history_length, 100))\n",
" self.t = 0\n",
@@ -402,7 +409,7 @@
"id": "BYdPfGgZp8HN",
"outputId": "6be7ab71-02d3-4d82-c66e-2ca60655efd1"
},
"execution_count": 5,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@@ -469,7 +476,7 @@
"id": "GDnAlQQNqC7z",
"outputId": "e55dc40f-f5ad-48ba-f5ac-e1efb3f63a1c"
},
"execution_count": 6,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@@ -1549,7 +1556,7 @@
"id": "hewvpLU_qHhO",
"outputId": "64d05cd0-c71b-4337-def6-441f9fbd59f5"
},
"execution_count": 6,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
@@ -3635,7 +3642,7 @@
"id": "xuuCmTfoqMg9",
"outputId": "7fe1ee42-6697-4443-9f9a-953a1cbac5fa"
},
"execution_count": 7,
"execution_count": null,
"outputs": [
{
"output_type": "stream",
