Skip to content

Commit

Permalink
Back out "Fix drectories structure discrepancy in internal tests and …
Browse files Browse the repository at this point in the history
…github CI"

Summary: The diff broke imports in Github, that do not match internal imports. Specifically, "from pearl.test.utils import xx" works internally but doesn't work on Github, "from test.utils import xx" works on Github, but doesn't work internally.

Reviewed By: rodrigodesalvobraz

Differential Revision:
D56643061

Privacy Context Container: L1202097

fbshipit-source-id: 49cf927d9576c3cefc25b3dbdd26d58f4132b916
  • Loading branch information
Dmytro Korenkevych authored and facebook-github-bot committed Apr 27, 2024
1 parent 3200afa commit b9ef4ed
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 35 deletions.
5 changes: 1 addition & 4 deletions pearl/utils/scripts/cb_benchmark/cb_benchmark_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,11 @@
from pearl.policy_learners.exploration_modules.contextual_bandits.ucb_exploration import (
UCBExploration,
)
from pearl.test.utils import prefix_dir
from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import (
SLCBEnvironment,
)


DATA_PATH: str = f"{prefix_dir()}utils/instantiations/environments/uci_datasets"

DATA_PATH: str = "./utils/instantiations/environments/uci_datasets"

"""
Experiment config
Expand Down
9 changes: 3 additions & 6 deletions pearl/utils/scripts/cb_benchmark/run_cb_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@
from pearl.replay_buffers.contextual_bandits.discrete_contextual_bandit_replay_buffer import (
DiscreteContextualBanditReplayBuffer,
)
from pearl.test.utils import prefix_dir
from pearl.utils.instantiations.environments.contextual_bandit_uci_environment import (
SLCBEnvironment,
)

from pearl.utils.instantiations.spaces.discrete_action import DiscreteActionSpace

from pearl.utils.scripts.cb_benchmark.cb_benchmark_config import (
Expand Down Expand Up @@ -274,18 +274,15 @@ def run_cb_benchmarks(
"""

# Create UCI data directory if it does not already exist
uci_data_path: str = f"{prefix_dir()}utils/instantiations/environments/uci_datasets"
save_results_path: str = (
f"{prefix_dir()}utils/scripts/cb_benchmark/experiments_results"
)

uci_data_path = "./utils/instantiations/environments/uci_datasets"
if not os.path.exists(uci_data_path):
os.makedirs(uci_data_path)

# Download UCI data
download_uci_data(data_path=uci_data_path)

# Create folder for result if it does not already exist
save_results_path: str = "./utils/scripts/cb_benchmark/experiments_results"
if not os.path.exists(save_results_path):
os.makedirs(save_results_path)

Expand Down
12 changes: 2 additions & 10 deletions test/unit/test_tutorials/test_rec_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# pyre-strict


import os
import random
import unittest
from typing import List, Optional, Tuple
Expand Down Expand Up @@ -190,23 +189,16 @@ def setUp(self) -> None:
def test_rec_system(self) -> None:
# load environment
model = SequenceClassificationModel(100).to(device)
if os.path.exists("../Pearl"):
# Github CI
model_dir = "tutorials/single_item_recommender_system_example/"
else:
# internal Meta tests
model_dir = "pearl/tutorials/single_item_recommender_system_example/"

model.load_state_dict(
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
torch.load(
os.path.join(model_dir, "env_model_state_dict.pt"),
"pearl/tutorials/single_item_recommender_system_example/env_model_state_dict.pt",
weights_only=True,
)
)
# Note: in the tutorial the directory "pearl" must be replaced by "Pearl"
actions = torch.load(
os.path.join(model_dir, "news_embedding_small.pt"),
"pearl/tutorials/single_item_recommender_system_example/news_embedding_small.pt",
weights_only=True,
)
history_length = 8
Expand Down
15 changes: 0 additions & 15 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
This file contains helpers for unittest creation
"""

import os
from typing import Tuple

import torch
Expand Down Expand Up @@ -41,17 +40,3 @@ def create_normal_pdf_training_data(
) # corresponding pdf of mvn
y_corrupted = y + 0.01 * torch.randn(num_data_points) # noise corrupted targets
return x, y_corrupted


def prefix_dir() -> str:
"""
Returns the path needed to go from the current working directory while running
tests to the second-level Pearl packages, depending on the platform being run.
On the GitHub setup, this is "pearl/". In the internal Meta setup, this is "".
"""
if os.path.exists("../Pearl"):
# github CI
return "pearl/"
else:
# internal Meta tests
return ""

0 comments on commit b9ef4ed

Please sign in to comment.