Commit

much cleaner code
saleml committed Sep 12, 2022
1 parent 2d954c8 commit c6d3651
Showing 25 changed files with 570 additions and 778 deletions.
2 changes: 2 additions & 0 deletions .gitignore

@@ -177,3 +177,5 @@ wandb/
 # slurm-specific stuff
 *.sh
 *.out
+
+scripts.py
15 changes: 7 additions & 8 deletions dynamic_programming.py

@@ -10,7 +10,6 @@
 from gfn.estimators import LogEdgeFlowEstimator
 from gfn.modules import Tabular, Uniform
 from gfn.parametrizations.edge_flows import FMParametrization
-from gfn.preprocessors import EnumPreprocessor
 from gfn.validate import validate
 
 parser = ArgumentParser()
@@ -28,11 +27,10 @@
 
 logit_PB = Uniform(output_dim=env.n_actions - 1)
 
-preprocessor = EnumPreprocessor(env)
 
 all_states = env.all_states
 
-all_states_indices = preprocessor(all_states)
+all_states_indices = env.get_states_indices(all_states)
 
 # Zeroth step: Define the necessary containers
 Y = set()  # Contains the state indices that do not need more visits
@@ -57,19 +55,20 @@
     state_prime = all_states[[s_prime_index]]
 
     backward_mask = state_prime.backward_masks[0]
-    pb_logits = logit_PB(preprocessor(state_prime))
+    pb_logits = logit_PB(env.get_states_indices(state_prime))
     pb_logits[~backward_mask] = -float("inf")
     pb = torch.softmax(pb_logits, dim=0)
     for i in range(env.n_actions - 1):
         if backward_mask[i]:
             state = env.backward_step(state_prime, torch.tensor([i]))
-            s_index = preprocessor(state)[0].item()
-            pb_logits = logit_PB(preprocessor(state_prime))
+            s_index = env.get_states_indices(state)[0].item()
+            pb_logits = logit_PB(env.get_states_indices(state_prime))
             F_edge[s_index, i] = F_s_prime * pb[i].item()
             F_state[s_index] = F_state[s_index] + F_edge[s_index, i]
             if all(
                 [
-                    preprocessor(env.step(state, torch.tensor([j])))[0].item() in Y
+                    env.get_states_indices(env.step(state, torch.tensor([j])))[0].item()
+                    in Y
                     for j in range(env.n_actions - 1)
                     if state.forward_masks[0, j]
                 ]
@@ -83,7 +82,7 @@
 logF_edge_module = Tabular(env, output_dim=env.n_actions - 1)
 logF_edge_module.logits = logF_edge[:, :-1]
 logF_edge_estimator = LogEdgeFlowEstimator(
-    preprocessor=preprocessor, module=logF_edge_module
+    env=env, module_name="Tabular", module=logF_edge_module
 )
 parametrization = FMParametrization(logF=logF_edge_estimator)
 print(validate(env, parametrization, n_validation_samples=10000))
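
Taken together, these hunks drop the standalone EnumPreprocessor in favor of index lookups on the environment itself, and LogEdgeFlowEstimator is now built from the env plus a module name rather than a preprocessor instance. A minimal sketch of the post-commit pattern, using only names visible in this diff; the HyperGrid argument values are illustrative assumptions, not defaults confirmed here:

    from gfn.envs import HyperGrid
    from gfn.estimators import LogEdgeFlowEstimator
    from gfn.modules import Tabular

    # ndim/height values are assumptions for illustration.
    env = HyperGrid(ndim=2, height=8)

    # Before this commit: preprocessor = EnumPreprocessor(env); preprocessor(states)
    # After: the environment exposes the index lookup directly.
    indices = env.get_states_indices(env.all_states)

    # Estimators now take the env and a module name instead of a preprocessor.
    module = Tabular(env, output_dim=env.n_actions - 1)
    estimator = LogEdgeFlowEstimator(env=env, module_name="Tabular", module=module)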
12 changes: 7 additions & 5 deletions gfn/configs/env.py

@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 from typing import Literal
 
-from simple_parsing import subgroups
+from simple_parsing import choice, subgroups
 from simple_parsing.helpers import JsonSerializable
 
 from gfn.envs import Env, HyperGrid
@@ -20,29 +20,31 @@ class HyperGridConfig(BaseEnvConfig):
     R1: float = 0.5
     R2: float = 2.0
     reward_cos: bool = False
+    preprocessor_name: str = choice("KHot", "OneHot", "Identity", default="KHot")
 
-    def parse(self, device: Literal["cpu", "cuda"]) -> Env:
+    def parse(self, device_str: Literal["cpu", "cuda"]) -> Env:
         return HyperGrid(
             ndim=self.ndim,
             height=self.height,
             R0=self.R0,
             R1=self.R1,
             R2=self.R2,
             reward_cos=self.reward_cos,
-            device=device,
+            device_str=device_str,
+            preprocessor_name=self.preprocessor_name,
         )
 
 
 @dataclass
-class BitSequenceConfig(BaseEnvConfig):
+class MoleculesConfig(BaseEnvConfig):
     def parse(self, device: Literal["cpu", "cuda"]) -> Env:
         raise NotImplementedError("Not implemented yet")
 
 
 @dataclass
 class EnvConfig(JsonSerializable):
     env: BaseEnvConfig = subgroups(
-        {"HyperGrid": HyperGridConfig, "BitSequence": BitSequenceConfig},
+        {"HyperGrid": HyperGridConfig, "Molecules": MoleculesConfig},
         default=HyperGridConfig(),
     )
 
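This config change threads the new preprocessor_name choice, and the rename of device to device_str, through to the HyperGrid constructor. A hedged usage sketch of the updated dataclass; the field values are illustrative assumptions, not defaults from the diff:

    from gfn.configs.env import HyperGridConfig

    # Field values here are assumptions for illustration.
    config = HyperGridConfig(ndim=2, height=8, preprocessor_name="OneHot")
    env = config.parse(device_str="cpu")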
60 changes: 0 additions & 60 deletions gfn/configs/module.py

This file was deleted.
