Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Knowledge Base + Multi-Task Warm-starting #933

Merged
merged 160 commits into from
Sep 30, 2022
Merged
Show file tree
Hide file tree
Changes from 143 commits
Commits
Show all changes
160 commits
Select commit Hold shift + click to select a range
978be98
Import files from old PR, start fresh
lebrice May 16, 2022
e579f97
Cleaning up algo_wrapper.py
lebrice May 16, 2022
0b0d78b
Minor typing improvements in utils/__init__.py
lebrice May 16, 2022
0bdf0c9
add transform-related methods to AlgoWrapper
lebrice May 16, 2022
e3d2e0e
Clean up multi_task_algo.py
lebrice May 16, 2022
f99af06
[big] move wrappers to folder, simplify wrappers
lebrice May 17, 2022
609ccc9
Add tests for the wrappers
lebrice May 18, 2022
591d085
Move algo wrappers to folder, split up tests a bit
lebrice May 19, 2022
74011a2
Fix cyclical import issue with knowledbe base
lebrice May 19, 2022
177f276
Remove accidentally added .pre-commit-config.yaml
lebrice May 19, 2022
5d9feef
"algorithms.algorithm" -> "algorithms.unwrapped"
lebrice May 19, 2022
cf7d3ac
Have wrappers return the algo's config, not theirs
lebrice May 19, 2022
d5b2ac8
Remove duplicated code from SpaceTransform
lebrice May 19, 2022
8cd3adb
Rename test files for algo wrappers
lebrice May 19, 2022
55702a8
Fix bug in test_gridsearch.py
lebrice May 19, 2022
c835e4c
Remove more duplicated code from SpaceTransform
lebrice May 19, 2022
87041d6
Husk of unit tests for multi-task wrapper
lebrice May 30, 2022
9fdfd84
Rename multi_task_wrapper.py -> multi_task.py
lebrice May 30, 2022
4b43796
Start to write body of the test
lebrice May 30, 2022
a0a2517
Remove other duplicate methods of SpaceTransform
lebrice May 30, 2022
926c05a
Add dataclasses for various objects
lebrice Oct 14, 2021
e556851
Adapting previous changes from config_dataclasses
lebrice May 30, 2022
8e84147
Adapting previous changes from config_dataclasses
lebrice May 30, 2022
1e72b4e
Merge branch 'knowledge-base' of https://www.github.com/lebrice/orion…
lebrice May 31, 2022
3ae6acb
Fix / reformat docstrings
lebrice Jun 22, 2022
16eb753
Fix typo in log.debug call
lebrice Jun 22, 2022
ac3b9d6
Fix outdated docstrings
lebrice Jun 22, 2022
c3a5266
Revert changes to ugly part of testing/__init__.py
lebrice Jun 22, 2022
5c89d67
Fix isort issues
lebrice Jun 22, 2022
8456f86
Merge branch 'develop' into knowledge-base
lebrice Jun 22, 2022
6488274
Fix isort / flake8 issues
lebrice Jun 22, 2022
9bf31d7
Add an intermediate TransformWrapper ABC
lebrice Jun 22, 2022
6aa0ac8
Add optional `max_trials` property to algorithm
lebrice Jun 22, 2022
3df7b35
Remove outdated/misleading comments
lebrice Jun 22, 2022
f590266
Fix various pylint/flake8 errors
lebrice Jun 22, 2022
b86efe8
Touchups on the types of primary_algo.py
lebrice Jul 4, 2022
c70a1c7
Slight refactoring in serializable.py
lebrice Jul 4, 2022
ed9d137
Rename AlgoType TypeVar to AlgoT
lebrice Jul 4, 2022
cf641a8
Start to add tests for ExperimentInfo dataclass
lebrice Jul 4, 2022
5add7df
Allow passing Algo type to create_experiment
lebrice Jul 4, 2022
94b3f33
Minor tweak in AlgoWrapper.__repr__
lebrice Jul 4, 2022
60b8f0f
Adding (failing) tests for warm-starting.
lebrice Jul 4, 2022
bf9e5bb
Add Knowledge Base arg to Experiment and Producer
lebrice Jul 6, 2022
0e78382
Add KnowledgeBase as argument to workon functions
lebrice Jul 6, 2022
f4c70f6
Move warm-start method to right wrapper class
lebrice Jul 6, 2022
a5c65b2
Add a simple test for warm starting
lebrice Jul 6, 2022
72f3765
Merge branch 'develop' into knowledge-base
lebrice Jul 6, 2022
035404d
Improve __repr__ of Registry
lebrice Jul 11, 2022
d5d0180
Type the `space` property of ExperimentClient
lebrice Jul 11, 2022
9fcc8b1
Remove line that caused bug in experiment_builder
lebrice Jul 11, 2022
28307b3
Add tests for KB + Warm-Starting
lebrice Jul 11, 2022
e735b85
(big, ugly commit) Add tests, clarify, refactor
lebrice Jul 13, 2022
da8c0c0
Merge branch 'develop' into knowledge-base
lebrice Jul 15, 2022
55ba3ab
Add repr for RegistryMapping
lebrice Jul 18, 2022
b6d4e6f
Type the _results and _params attributes of Trial
lebrice Jul 18, 2022
175765a
Move and Rename algo for unit tests
lebrice Jul 18, 2022
25632be
Add test for Multi-Task wrapper collisions
lebrice Jul 18, 2022
391f67a
Fix bug in AlgoWrapper (see desc.)
lebrice Jul 18, 2022
b805afd
Add test for suggest to always give task_id=0
lebrice Jul 18, 2022
92d09ae
Fix small bug in gridsearch.py
lebrice Jul 18, 2022
4347ee7
Simplification: kb is only attr of Experiment
lebrice Jul 18, 2022
8dd83ea
Moved "functional" tests to different file
lebrice Jul 18, 2022
64cabc6
Fix pylint errors
lebrice Jul 25, 2022
1827cb5
Add unwrap convenience method on AlgoWrapper
lebrice Jul 25, 2022
c3fb801
Misc changes (test cleanup, copy status)
lebrice Jul 25, 2022
c499c6c
Remove unused warm_start_mode context manager
lebrice Jul 25, 2022
7cef578
Use the new max_trials property on algo
lebrice Jul 25, 2022
cc6ecd6
Add tests for setting max_trials and n_observed
lebrice Jul 25, 2022
a545922
Fix renaming of SpaceTransform algo wrapper
lebrice Jul 25, 2022
dc72b15
Don't register trials from other tasks in algo
lebrice Jul 25, 2022
8c88802
Fix randomness of flaky-ish test
lebrice Jul 25, 2022
a4b6ec4
Remove dict[k, v] type annotation for python < 3.9
lebrice Jul 25, 2022
1b29a38
Move / Simplify the Config classes -> TypedDicts
lebrice Jul 25, 2022
4399741
Type out the Random algo, fix a test
lebrice Jul 25, 2022
627aae2
Fixing some broken tests
lebrice Jul 25, 2022
4bd0e66
Fix experiment_builder tests
lebrice Jul 25, 2022
278bfcf
Fix broken test in test_experiment.py
lebrice Jul 25, 2022
f04add1
Add test for ExperimentConfig fields
lebrice Jul 26, 2022
41b64eb
Minor touchups in experiment_config.py
lebrice Jul 26, 2022
8f48a0e
Removed unused typeddicts in experiment_config.py
lebrice Jul 26, 2022
b0f9de3
Remove outdated todo
lebrice Jul 26, 2022
03dc654
[breaking] Knowledge Base implementation
lebrice Jul 26, 2022
08d1294
Remove unused knowledge_base argument to workon
lebrice Jul 26, 2022
413b3e8
Fix instantiation of KB in exp builder
lebrice Jul 26, 2022
ff031fc
[nit] Fix typing of ExperimentStats fields
lebrice Jul 26, 2022
7877eea
[optional] Type out storage/base.py and misc types
lebrice Jul 26, 2022
c4d6323
Adapting the MultiTask wrapper tests, add stubs
lebrice Jul 26, 2022
c406ff7
Move test_experiment_config to reflect src
lebrice Jul 27, 2022
589fe27
Add test for KB
lebrice Jul 27, 2022
dceba63
Add more tests for the KnowledgeBase
lebrice Jul 28, 2022
2adf96b
Fix incorrect type for experiment id in docstrings
lebrice Jul 28, 2022
ba76f56
Fix error in fixture, adjust docstrings
lebrice Jul 28, 2022
55f7652
Pass experiment config instead of experiment obj
lebrice Jul 28, 2022
7ea2ef5
Fix docstrings, use Unpack[] to type **kwargs
lebrice Jul 28, 2022
a9e2b01
Move and update functional tests
lebrice Jul 28, 2022
aceaae5
Remove unused code block in test
lebrice Jul 28, 2022
7c3a319
Add note about the Unpack[] annotation
lebrice Jul 28, 2022
7d22340
Add PartialExperimentConfig typeddict
lebrice Jul 28, 2022
af467fc
Trying to make functional tests pass...
lebrice Jul 28, 2022
3682e01
Fix pylint error in experiment.py
lebrice Jul 28, 2022
1329547
Merge branch 'develop' into knowledge-base
lebrice Jul 29, 2022
a18cb51
Fix bugs in test_knowledge_base
lebrice Jul 29, 2022
f3993dd
Pass kb to instantiate_algo, minor typing stuffs
lebrice Jul 29, 2022
4ec56d6
Clarify potential issue in register, touchups
lebrice Jul 29, 2022
cba0d1e
Minor typing touchups
lebrice Jul 29, 2022
32e201c
Simplify functional tests for warm_starting
lebrice Jul 29, 2022
d41ab49
Add tests for how to pass the algorithm
lebrice Jul 29, 2022
e250605
Minor typing improvements to experiment_builder.py
lebrice Jul 29, 2022
6b1d4a7
Fix assignment of max_trials in exp client
lebrice Jul 29, 2022
b534749
Use KnowledgeBase in functional test, remove todos
lebrice Jul 29, 2022
dbdd28e
Fix pylint error
lebrice Jul 29, 2022
cdf5eee
Fix import error for TypedDict
lebrice Jul 29, 2022
bad79fc
Re-introduce fix from #964 (?)
lebrice Jul 29, 2022
24d71b8
Fix bug in algo creation logic
lebrice Jul 29, 2022
97f8e89
Add missing types in BaseAlgorithm
lebrice Jul 29, 2022
cb6b9c4
Add some of the missing types in exp builder
lebrice Jul 29, 2022
4293737
Remove leftover todo in BaseAlgorithm
lebrice Jul 29, 2022
dfd62de
Fix type annotation on create_experiment
lebrice Jul 29, 2022
626ddb8
Remove extra type annotation on create_experiment
lebrice Jul 29, 2022
f854652
Fix bug in test_tpe (added wrapper)
lebrice Jul 30, 2022
4f02082
Fix value in DumbAlgo
lebrice Jul 30, 2022
67099d2
Fix test condition
lebrice Aug 10, 2022
b792d7e
Remove fixme comment
lebrice Aug 10, 2022
6f810fb
Merge branch 'develop' into knowledge-base
lebrice Aug 10, 2022
07f8854
Fix tests
lebrice Aug 10, 2022
7836f6b
Fix most PBT tests
lebrice Aug 10, 2022
fe01d9e
Misc changes
lebrice Aug 10, 2022
12427d5
Misc typing changes
lebrice Aug 10, 2022
4e11b74
Debugging the PBT Errors
lebrice Aug 10, 2022
2c84f4e
Fix broken AlgoWrapper tests
lebrice Aug 10, 2022
affee2e
Add missing register method on AlgoWrapper
lebrice Aug 10, 2022
e800c41
Remove unused `original_space` property
lebrice Aug 10, 2022
e21783e
Fix docstring and warning in BaseAlgorithm.get_id
lebrice Aug 10, 2022
7c0d985
Fix bug introduced previously
lebrice Aug 18, 2022
22518cf
Minor improvements to TransformWrapper.suggest
lebrice Aug 18, 2022
043bf7d
Use self.reverse_transform in get_original_parent
lebrice Aug 18, 2022
8711e89
Ugly temporary fix to bug that affected PBT (desc)
lebrice Aug 18, 2022
7237ade
Remove hacky fix, fix bug source (hopefully)
lebrice Aug 18, 2022
95cf645
Make _get_original_parent "static" again
lebrice Aug 18, 2022
157fe38
Add todo for later (copying attributes explicitly)
lebrice Aug 18, 2022
3c19783
Fix minor bug in DumbAlgo-related test
lebrice Aug 18, 2022
604beaa
Greatly reduce number of warnings
lebrice Aug 19, 2022
85bc9a8
Remove Registry from InsistSuggestWrapper
lebrice Aug 29, 2022
3c7141e
Add tests to increase coverage of Registry class
lebrice Sep 20, 2022
34965b4
Add more coverage for _instantiate_knowledge_base
lebrice Sep 20, 2022
cb616d2
Add a bit more coverage for AlgoWrapper
lebrice Sep 20, 2022
84de58b
Fix bug in _instantiate_knowledge_base
lebrice Sep 20, 2022
5040c07
Add tests for _instantiate_algo
lebrice Sep 20, 2022
bc02974
Add generic test class for AlgoWrappers
lebrice Sep 20, 2022
57d561b
Remove redundant fixture in AlgoWrapper tests
lebrice Sep 20, 2022
a007076
Clean / minor changes to test_knowledge_base.py
lebrice Sep 20, 2022
8f23392
Use asserts to avoid writing useless tests
lebrice Sep 20, 2022
3555408
Add test case for no compatible trials found
lebrice Sep 20, 2022
9977d87
Clean up testing utility function a bit
lebrice Sep 20, 2022
aff8c7b
Add test for not warm-starting twice
lebrice Sep 20, 2022
47d8a95
Add test for "space already has task_id" error
lebrice Sep 20, 2022
0d25b26
Add tests for is_warmstarteable function
lebrice Sep 20, 2022
5bd51af
Standardize the imports of ExperimentConfig
lebrice Sep 20, 2022
3f35bba
Add more tests for _instantiate_kb
lebrice Sep 23, 2022
5a61b59
Remove redundante else clause in _instantiate_algo
lebrice Sep 23, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
---
repos:
- repo: /~https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0 # Use the ref you want to point at
hooks:
- id: check-merge-conflict
- repo: /~https://github.com/python/black
rev: 22.6.0
hooks:
Expand Down
61 changes: 32 additions & 29 deletions src/orion/algo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import inspect
import logging
from abc import abstractmethod
from typing import Any

from orion.algo.registry import Registry
from orion.algo.space import Space
from orion.core.utils import GenericFactory
from orion.core.worker.trial import Trial

Expand Down Expand Up @@ -101,7 +103,9 @@ def observe(self, points, results):
requires_shape = None
requires_dist = None

def __init__(self, space, **kwargs):
max_trials: int | None = None

def __init__(self, space: Space, **kwargs):
log.debug(
"Creating Algorithm object of %s type with parameters:\n%s",
type(self).__name__,
Expand Down Expand Up @@ -142,18 +146,19 @@ def state_dict(self):
"""Return a state dict that can be used to reset the state of the algorithm."""
return {"registry": self.registry.state_dict}

def set_state(self, state_dict):
def set_state(self, state_dict: dict):
"""Reset the state of the algorithm based on the given state_dict

:param state_dict: Dictionary representing state of an algorithm
"""
self.registry.set_state(state_dict["registry"])

def get_id(self, trial, ignore_fidelity=False, ignore_parent=False):
def get_id(
self, trial: Trial, ignore_fidelity: bool = False, ignore_parent: bool = False
) -> str:
"""Return unique hash for a trials based on params

The trial is assumed to be in the transformed space if the algorithm is working in a
transformed space.
The trial is assumed to be in the optimization space of the algorithm.

Parameters
----------
Expand All @@ -170,13 +175,12 @@ def get_id(self, trial, ignore_fidelity=False, ignore_parent=False):
return trial.compute_trial_hash(
trial,
ignore_fidelity=ignore_fidelity,
ignore_experiment=True,
ignore_lie=True,
ignore_parent=ignore_parent,
)

@property
def fidelity_index(self):
def fidelity_index(self) -> str | None:
"""Returns the name of the first fidelity dimension if there is one, otherwise `None`."""
fidelity_dims = [dim for dim in self.space.values() if dim.type == "fidelity"]
if fidelity_dims:
Expand Down Expand Up @@ -209,7 +213,7 @@ def suggest(self, num: int) -> list[Trial]:
has suggested/observed, and for the auto-generated unit-tests to pass.
"""

def observe(self, trials):
def observe(self, trials: list[Trial]) -> None:
"""Observe the `results` of the evaluation of the `trials` in the
process defined in user's script.

Expand All @@ -223,7 +227,7 @@ def observe(self, trials):
if not self.has_observed(trial):
self.register(trial)

def register(self, trial):
def register(self, trial: Trial) -> None:
"""Save the trial as one suggested or observed by the algorithm.

Parameters
Expand All @@ -234,16 +238,16 @@ def register(self, trial):
self.registry.register(trial)

@property
def n_suggested(self):
def n_suggested(self) -> int:
"""Number of trials suggested by the algorithm"""
return len(self.registry)

@property
def n_observed(self):
def n_observed(self) -> int:
"""Number of completed trials observed by the algorithm."""
return sum(self.has_observed(trial) for trial in self.registry)

def has_suggested(self, trial):
def has_suggested(self, trial: Trial) -> bool:
"""Whether the algorithm has suggested a given point.

Parameters
Expand All @@ -259,7 +263,7 @@ def has_suggested(self, trial):
"""
return self.registry.has_suggested(trial)

def has_observed(self, trial):
def has_observed(self, trial: Trial) -> bool:
"""Whether the algorithm has observed a given point objective.

This only counts observed completed trials.
Expand Down Expand Up @@ -312,15 +316,11 @@ def has_completed_max_trials(self) -> bool:
"""Returns True if the algorithm has a `max_trials` attribute, and has completed more trials
than its value.
"""
if not hasattr(self, "max_trials"):
return False
max_trials = getattr(self, "max_trials")
if max_trials is None:
if self.max_trials is None:
return False

fidelity_index = self.fidelity_index
max_fidelity_value = None

# When a fidelity dimension is present, we only count trials that have the maximum value.
if fidelity_index is not None:
_, max_fidelity_value = self.space[fidelity_index].interval()
Expand All @@ -333,10 +333,11 @@ def _is_completed(trial: Trial) -> bool:
and trial.params[fidelity_index] >= max_fidelity_value
)

return sum(map(_is_completed, self.registry)) >= max_trials
return sum(map(_is_completed, self.registry)) >= self.max_trials

def score(self, trial): # pylint:disable=no-self-use,unused-argument
"""Allow algorithm to evaluate `point` based on a prediction about
# pylint:disable=no-self-use,unused-argument
def score(self, trial: Trial) -> float:
"""Allow algorithm to evaluate `trial` based on a prediction about
this parameter set's performance.

By default, return the same score any parameter (no preference).
Expand All @@ -353,7 +354,8 @@ def score(self, trial): # pylint:disable=no-self-use,unused-argument
"""
return 0

def judge(self, trial, measurements): # pylint:disable=no-self-use,unused-argument
# pylint:disable=no-self-use,unused-argument
def judge(self, trial: Trial, measurements: Any) -> dict | None:
"""Inform an algorithm about online `measurements` of a running trial.

This method is to be used as a callback in a client-server communication
Expand Down Expand Up @@ -381,7 +383,7 @@ def judge(self, trial, measurements): # pylint:disable=no-self-use,unused-argum
"""
return None

def should_suspend(self, trial):
def should_suspend(self, trial: Trial) -> bool:
"""Allow algorithm to decide whether a particular running trial is still
worth to complete its evaluation, based on information provided by the
`judge` method.
Expand All @@ -390,10 +392,12 @@ def should_suspend(self, trial):
return False

@property
def configuration(self):
def configuration(self) -> dict[str, Any]:
"""Return tunable elements of this algorithm in a dictionary form
appropriate for saving.

By default, returns a dictionary containing the attributes of `self` which are also
constructor arguments.
"""
dict_form = dict()
for attrname in self._param_names:
Expand All @@ -404,14 +408,13 @@ def configuration(self):
return {self.__class__.__name__.lower(): dict_form}

@property
def space(self):
def space(self) -> Space:
"""Domain of problem associated with this algorithm's instance."""
return self._space

@space.setter
def space(self, space):
"""Set space."""
self._space = space
@property
def unwrapped(self):
return self


algo_factory = GenericFactory(BaseAlgorithm)
5 changes: 4 additions & 1 deletion src/orion/algo/gridsearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,11 @@ def __init__(
if not isinstance(n_values, dict)
else n_values
)
max_trials = 10_000 if self.max_trials is None else self.max_trials
self.grid = self.build_grid(
self.space, n_values_dict, getattr(self, "max_trials", 10000)
self.space,
n_values_dict,
max_trials=max_trials,
)
self.index = 0

Expand Down
2 changes: 2 additions & 0 deletions src/orion/algo/pbt/pb2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from orion.algo.pbt.pbt import PBT
from orion.core.utils.flatten import flatten
from orion.core.utils.random_state import RandomState, control_randomness
from orion.core.worker.transformer import ReshapedSpace, TransformedSpace
from orion.core.worker.trial import Trial

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -172,6 +173,7 @@ def _generate_offspring(self, trial):
]

new_trial = trial_to_branch.branch(params=new_params)
assert isinstance(self.space, (TransformedSpace, ReshapedSpace))
new_trial = self.space.transform(self.space.reverse(new_trial))

logger.debug("Attempt %s - Creating new trial %s", attempts, new_trial)
Expand Down
9 changes: 5 additions & 4 deletions src/orion/algo/pbt/pbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def __init__(
explore: dict | None = None,
fork_timeout: int = 60,
):
super().__init__(space)
if exploit is None:
exploit = {
"of_type": "PipelineExploit",
Expand Down Expand Up @@ -217,7 +218,6 @@ def __init__(

self.lineages = Lineages()

super().__init__(space)
self.seed = seed
self.population_size = population_size
self.generations = generations
Expand Down Expand Up @@ -318,14 +318,15 @@ def suggest(self, num: int) -> list[Trial]:
A list of trials representing values suggested by the algorithm.

"""

# Sample points until num is met, or population_size
num_random_samples = min(max(self.population_size - self._num_root, 0), num)
assert num > 0
logger.debug(
"PBT has %s pending or completed trials at root, %s broken trials.",
self._num_root,
len(self.lineages) - self._num_root,
)

# Sample points until num is met, or population_size
num_random_samples = min(max(self.population_size - self._num_root, 0), num)
logger.debug("Sampling %s new trials", num_random_samples)
trials = self._sample(num_random_samples)
logger.debug("Sampled %s new trials", len(trials))
Expand Down
13 changes: 10 additions & 3 deletions src/orion/algo/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
Draw and deliver samples from prior defined in problem's domain.

"""
from __future__ import annotations

from typing import Sequence

import numpy

from orion.algo.base import BaseAlgorithm
from orion.algo.space import Space


class Random(BaseAlgorithm):
Expand All @@ -23,10 +28,12 @@ class Random(BaseAlgorithm):

"""

def __init__(self, space, seed=None):
super().__init__(space, seed=seed)
def __init__(self, space: Space, seed: int | Sequence[int] | None = None):
super().__init__(space)
self.seed = seed
self.seed_rng(seed)

def seed_rng(self, seed):
def seed_rng(self, seed: int | Sequence[int] | None):
"""Seed the state of the random number generator.

:param seed: Integer seed for the random number generator.
Expand Down
27 changes: 22 additions & 5 deletions src/orion/algo/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,27 @@
import copy
from collections import defaultdict
from logging import getLogger as get_logger
from typing import Any, Container, Iterator, Mapping
from typing import Any, Container, Iterable, Iterator, Mapping

from orion.core.worker.trial import Trial, TrialCM

logger = get_logger(__name__)


class Registry(Container[Trial]):
"""In-memory container for the trials that the algorithm suggests/observes/etc."""
"""In-memory container for the trials that the algorithm suggests/observes/etc.

def __init__(self):
This behaves a bit like a managed dictionary, but the "keys" are trials ids, which
(at the time of writing) can vary depending on how we chose to compute them.
"""

def __init__(self, trials: Iterable[Trial] = ()):
self._trials: dict[str, Trial] = {}
for trial in trials:
self.register(trial)

def __repr__(self) -> str:
return f"{type(self).__qualname__}({list(iter(self))})"

def __contains__(self, trial_or_id: str | Trial | Any) -> bool:
if isinstance(trial_or_id, TrialCM):
Expand Down Expand Up @@ -94,7 +103,7 @@ def get_existing(self, trial: Trial) -> Trial:
class RegistryMapping(Mapping[Trial, "list[Trial]"]):
"""A map between the original and transformed registries.

This object is used in the `SpaceTransformAlgoWrapper` to check if a trial in the original space
This object is used in the `SpaceTransform` to check if a trial in the original space
has equivalent trials in the transformed space.

The goal is to make it so the algorithms don't have to care about the transforms/etc.
Expand Down Expand Up @@ -123,10 +132,15 @@ def set_state(self, statedict: dict):
self._mapping = copy.deepcopy(statedict["_mapping"])

def __iter__(self) -> Iterator[Trial]:
"""Iterate over the trials in the original registry."""
for trial_id in self._mapping:
yield self.original_registry[trial_id]

def __len__(self) -> int:
"""Give the number of trials in the mapping.

This should be the same as the number of trials in the original registry.
"""
return len(self._mapping)

def __contains__(self, trial: Trial):
Expand Down Expand Up @@ -160,11 +174,14 @@ def register(self, original_trial: Trial, transformed_trial: Trial) -> str:
self._mapping[original_trial_id].add(transformed_trial_id)
return original_trial_id

def __repr__(self) -> str:
return f"{type(self).__qualname__}({list((trial, self.get_trials(trial)) for trial in self)})"


def _get_id(trial: Trial) -> str:
"""Returns the unique identifier to be used to store the trial.

Only to be used internally in this module. This ignores the `experiment`
attribute of the trial.
"""
return Trial.compute_trial_hash(trial, ignore_experiment=True)
return Trial.compute_trial_hash(trial)
6 changes: 3 additions & 3 deletions src/orion/benchmark/benchmark_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,9 @@ def get_or_create_benchmark(
)

if input_configure and input_benchmark.configuration != benchmark.configuration:
logger.warn(
"Benchmark with same name is found but has different configuration, "
"which will be used for this creation.\n{}".format(benchmark.configuration)
logger.warning(
f"Benchmark with same name is found but has different configuration, "
f"which will be used for this creation.\n{benchmark.configuration}"
)

if benchmark_id is None:
Expand Down
Loading