-
Notifications
You must be signed in to change notification settings - Fork 175
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Fix CB errors and add unitest. -) add cb unitests that replicates the cb tutorial. removed previous broken uci test. -) fixed a bug in joint CB. Reviewed By: rodrigodesalvobraz Differential Revision: D56336066 fbshipit-source-id: 7d2cd9c9d8201aaefc701cb5e3b6facbbb31a48b
- Loading branch information
1 parent
9a8ea26
commit 3200afa
Showing
5 changed files
with
172 additions
and
38 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||
|
||
|
||
import os | ||
import unittest | ||
|
||
import torch | ||
from pearl.action_representation_modules.one_hot_action_representation_module import ( | ||
OneHotActionTensorRepresentationModule, | ||
) | ||
from pearl.pearl_agent import PearlAgent | ||
from pearl.policy_learners.contextual_bandits.neural_bandit import NeuralBandit | ||
from pearl.policy_learners.contextual_bandits.neural_linear_bandit import ( | ||
NeuralLinearBandit, | ||
) | ||
from pearl.policy_learners.exploration_modules.contextual_bandits.squarecb_exploration import ( | ||
SquareCBExploration, | ||
) | ||
from pearl.policy_learners.exploration_modules.contextual_bandits.thompson_sampling_exploration import ( | ||
ThompsonSamplingExplorationLinear, | ||
) | ||
from pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration import ( | ||
UCBExploration, | ||
) | ||
from pearl.replay_buffers.sequential_decision_making.fifo_off_policy_replay_buffer import ( | ||
FIFOOffPolicyReplayBuffer, | ||
) | ||
from pearl.utils.functional_utils.experimentation.set_seed import set_seed | ||
from pearl.utils.functional_utils.train_and_eval.online_learning import online_learning | ||
from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import ( | ||
SLCBEnvironment, | ||
) | ||
from pearl.utils.uci_data import download_uci_data | ||
|
||
set_seed(0) | ||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
device_id = 0 if torch.cuda.is_available() else -1 | ||
|
||
""" | ||
This is a unit test version of the CB tutorial. | ||
It is meant to check whether code changes break the tutorial. | ||
It is therefore important that the tutorial and the code here are kept in sync. | ||
As part of that synchronization, the markdown cells in the tutorial are | ||
kept here as multi-line strings. | ||
For it to run quickly, the number of steps used for training is reduced. | ||
""" | ||
|
||
|
||
class TestCBTutorials(unittest.TestCase): | ||
def setUp(self) -> None: | ||
super().setUp() | ||
|
||
def test_cb_tutorials(self) -> None: | ||
# load environment | ||
device = -1 | ||
|
||
# Download UCI dataset if doesn't exist | ||
uci_data_path = "./utils/instantiations/environments/uci_datasets" | ||
if not os.path.exists(uci_data_path): | ||
os.makedirs(uci_data_path) | ||
download_uci_data(data_path=uci_data_path) | ||
|
||
# Built CB environment using the pendigits UCI dataset | ||
pendigits_uci_dict = { | ||
"path_filename": os.path.join(uci_data_path, "pendigits/pendigits.tra"), | ||
"action_embeddings": "discrete", | ||
"delim_whitespace": False, | ||
"ind_to_drop": [], | ||
"target_column": 16, | ||
} | ||
env = SLCBEnvironment(**pendigits_uci_dict) # pyre-ignore | ||
|
||
# experiment code | ||
number_of_steps = 200 | ||
record_period = 400 | ||
|
||
""" | ||
SquareCB | ||
""" | ||
# Create a Neural SquareCB pearl agent with 1-hot action representation | ||
action_representation_module = OneHotActionTensorRepresentationModule( | ||
max_number_actions=env.unique_labels_num, | ||
) | ||
|
||
agent = PearlAgent( | ||
policy_learner=NeuralBandit( | ||
feature_dim=env.observation_dim + env.unique_labels_num, | ||
hidden_dims=[64, 16], | ||
training_rounds=10, | ||
learning_rate=0.01, | ||
action_representation_module=action_representation_module, | ||
exploration_module=SquareCBExploration( | ||
gamma=env.observation_dim * env.unique_labels_num * number_of_steps | ||
), | ||
), | ||
replay_buffer=FIFOOffPolicyReplayBuffer(100_000), | ||
device_id=device, | ||
) | ||
|
||
_ = online_learning( | ||
agent=agent, | ||
env=env, | ||
number_of_steps=number_of_steps, | ||
print_every_x_steps=100, | ||
record_period=record_period, | ||
learn_after_episode=True, | ||
) | ||
|
||
# Neural LinUCB | ||
action_representation_module = OneHotActionTensorRepresentationModule( | ||
max_number_actions=env.unique_labels_num, | ||
) | ||
|
||
agent = PearlAgent( | ||
policy_learner=NeuralLinearBandit( | ||
feature_dim=env.observation_dim + env.unique_labels_num, | ||
hidden_dims=[64, 16], | ||
state_features_only=False, | ||
training_rounds=10, | ||
learning_rate=0.01, | ||
action_representation_module=action_representation_module, | ||
exploration_module=UCBExploration(alpha=1.0), | ||
), | ||
replay_buffer=FIFOOffPolicyReplayBuffer(100_000), | ||
device_id=device, | ||
) | ||
|
||
_ = online_learning( | ||
agent=agent, | ||
env=env, | ||
number_of_steps=number_of_steps, | ||
print_every_x_steps=100, | ||
record_period=record_period, | ||
learn_after_episode=True, | ||
) | ||
|
||
# Neural LinTS | ||
|
||
action_representation_module = OneHotActionTensorRepresentationModule( | ||
max_number_actions=env.unique_labels_num, | ||
) | ||
|
||
agent = PearlAgent( | ||
policy_learner=NeuralLinearBandit( | ||
feature_dim=env.observation_dim + env.unique_labels_num, | ||
hidden_dims=[64, 16], | ||
state_features_only=False, | ||
training_rounds=10, | ||
learning_rate=0.01, | ||
action_representation_module=action_representation_module, | ||
exploration_module=ThompsonSamplingExplorationLinear(), | ||
), | ||
replay_buffer=FIFOOffPolicyReplayBuffer(100_000), | ||
device_id=-1, | ||
) | ||
|
||
_ = online_learning( | ||
agent=agent, | ||
env=env, | ||
number_of_steps=number_of_steps, | ||
print_every_x_steps=100, | ||
record_period=record_period, | ||
learn_after_episode=True, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters