Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add number of suggest() attempts in Algo wrapper #883

Merged
merged 9 commits into from
Apr 26, 2022
Merged
97 changes: 59 additions & 38 deletions src/orion/core/worker/primary_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def __init__(self, space: Space, algorithm: AlgoType):
original_registry=self.registry,
transformed_registry=self.algorithm.registry,
)
self.max_suggest_attempts = 100

@property
def original_space(self) -> Space:
Expand Down Expand Up @@ -116,7 +117,7 @@ def set_state(self, state_dict: dict) -> None:
self.registry.set_state(state_dict["registry"])
self.registry_mapping.set_state(state_dict["registry_mapping"])

def suggest(self, num: int) -> list[Trial] | None:
def suggest(self, num: int) -> list[Trial]:
"""Suggest a `num` of new sets of parameters.

Parameters
Expand All @@ -137,48 +138,68 @@ def suggest(self, num: int) -> list[Trial] | None:
New parameters must be compliant with the problem's domain `orion.algo.space.Space`.

"""
transformed_trials = self.algorithm.suggest(num)

if transformed_trials is None:
return None

trials: list[Trial] = []
for transformed_trial in transformed_trials:
if transformed_trial not in self.transformed_space:
raise ValueError(
f"Trial {transformed_trial.id} not contained in space:\n"
f"Params: {transformed_trial.params}\n"
f"Space: {self.transformed_space}"
)
original = self.transformed_space.reverse(transformed_trial)
if original in self.registry:
logger.debug(
"Already have a trial that matches %s in the registry.", original
)
# We already have a trial that is equivalent to this one.
# Fetch the actual trial (with the status and possibly results)
original = self.registry.get_existing(original)
logger.debug("Matching trial (with results/status): %s", original)

# Copy over the status and results from the original to the transformed trial
# and observe it.
transformed_trial = _copy_status_and_results(
original_trial=original, transformed_trial=transformed_trial
)
for suggest_attempt in range(1, self.max_suggest_attempts + 1):
bouthilx marked this conversation as resolved.
Show resolved Hide resolved
transformed_trials: list[Trial] | None = self.algorithm.suggest(num)
transformed_trials = transformed_trials or []

for transformed_trial in transformed_trials:
if transformed_trial not in self.transformed_space:
raise ValueError(
f"Trial {transformed_trial.id} not contained in space:\n"
f"Params: {transformed_trial.params}\n"
f"Space: {self.transformed_space}"
)
original = self.transformed_space.reverse(transformed_trial)
if original in self.registry:
logger.debug(
"Already have a trial that matches %s in the registry.",
original,
)
# We already have a trial that is equivalent to this one.
# Fetch the actual trial (with the status and possibly results)
original = self.registry.get_existing(original)
logger.debug("Matching trial (with results/status): %s", original)

# Copy over the status and results from the original to the transformed trial
# and observe it.
transformed_trial = _copy_status_and_results(
original_trial=original, transformed_trial=transformed_trial
)
logger.debug(
"Transformed trial (with results/status): %s", transformed_trial
)
self.algorithm.observe([transformed_trial])
else:
# We haven't seen this trial before. Register it.
self.registry.register(original)
trials.append(original)

# NOTE: Here we DON'T register the transformed trial, we let the algorithm do it
# itself in its `suggest`.
# Register the equivalence between these trials.
self.registry_mapping.register(original, transformed_trial)

if trials:
if suggest_attempt > 1:
logger.debug(
f"Succeeded in suggesting new trials after {suggest_attempt} attempts."
)
return trials

if self.is_done:
logger.debug(
"Transformed trial (with results/status): %s", transformed_trial
f"Algorithm is done! (after {suggest_attempt} sampling attempts)."
)
self.algorithm.observe([transformed_trial])
else:
# We haven't seen this trial before. Register it.
self.registry.register(original)
trials.append(original)

# NOTE: Here we DON'T register the transformed trial, we let the algorithm do it itself
# in its `suggest`.
# Register the equivalence between these trials.
self.registry_mapping.register(original, transformed_trial)
return trials
break

logger.warning(
f"Unable to sample a new trial from the algorithm, even after "
f"{self.max_suggest_attempts} attempts! Returning an empty list."
)
return []

def observe(self, trials: list[Trial]) -> None:
"""Observe evaluated trials.
Expand Down
7 changes: 6 additions & 1 deletion tests/functional/algos/test_algos.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,12 @@ def test_cardinality_stop_loguniform(algorithm):


@pytest.mark.parametrize(
"algorithm", algorithm_configs.values(), ids=list(algorithm_configs.keys())
"algorithm",
[
pytest.param(value, marks=pytest.mark.skipif(key == "tpe", reason="Flaky test"))
for key, value in algorithm_configs.items()
],
ids=list(algorithm_configs.keys()),
)
def test_with_fidelity(algorithm):
"""Test a scenario with fidelity."""
Expand Down
74 changes: 71 additions & 3 deletions tests/unittests/core/test_primary_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from __future__ import annotations

import copy
import logging
import typing
from typing import Any, ClassVar, TypeVar

import pytest
from pytest import MonkeyPatch

from orion.algo.base import BaseAlgorithm, algo_factory
from orion.algo.space import Space
Expand Down Expand Up @@ -147,6 +149,63 @@ def test_judge(
del fixed_suggestion._params[-1]
palgo.judge(fixed_suggestion, 8)

def test_insists_when_algo_doesnt_suggest_new_trials(
self,
algo_wrapper: SpaceTransformAlgoWrapper[StupidAlgo],
monkeypatch: MonkeyPatch,
):
"""Test that when the algo can't produce a new trial, the wrapper insists and asks again."""
calls: int = 0
algo_wrapper.max_suggest_attempts = 10

# Make the wrapper insist enough so that it actually
# gets a trial after asking enough times:

def _suggest(num: int) -> list[Trial]:
nonlocal calls
calls += 1
if calls < 5:
return []
return [algo_wrapper.algorithm.fixed_suggestion]

monkeypatch.setattr(algo_wrapper.algorithm, "suggest", _suggest)
trial = algo_wrapper.suggest(1)[0]
assert calls == 5
assert trial in algo_wrapper.space

def test_warns_when_unable_to_sample_new_trial(
self,
algo_wrapper: SpaceTransformAlgoWrapper[StupidAlgo],
caplog: pytest.LogCaptureFixture,
monkeypatch: MonkeyPatch,
):
"""Test that when the algo can't produce a new trial even after the max number of attempts,
a warning is logged and an empty list is returned.
"""

calls: int = 0

def _suggest(num: int) -> list[Trial]:
nonlocal calls
calls += 1
if calls < 5:
return []
return [algo_wrapper.algorithm.fixed_suggestion]

monkeypatch.setattr(algo_wrapper.algorithm, "suggest", _suggest)

algo_wrapper.max_suggest_attempts = 3

with caplog.at_level(logging.WARNING):
out = algo_wrapper.suggest(1)
assert calls == 3
assert out == []
assert len(caplog.record_tuples) == 1
log_record = caplog.record_tuples[0]
assert log_record[1] == logging.WARNING and log_record[2].startswith(
"Unable to sample a new trial"
)


class StupidAlgo(BaseAlgorithm):
"""A dumb algo that always returns the same trial."""
Expand All @@ -155,14 +214,23 @@ class StupidAlgo(BaseAlgorithm):
requires_shape: ClassVar[str | None] = "flattened"
requires_dist: ClassVar[str | None] = "linear"

def __init__(self, space: Space, fixed_suggestion: Trial):
def __init__(
self,
space: Space,
fixed_suggestion: Trial,
):
super().__init__(space)
self.fixed_suggestion = fixed_suggestion
assert fixed_suggestion in space

def suggest(self, num):
self.register(self.fixed_suggestion)
return [self.fixed_suggestion]
# NOTE: can't register the trial if it's already here. The fixed suggestion is always "new",
# but the algorithm actually observes it at some point. Therefore, we don't overwrite what's
# already in the registry.
if not self.has_suggested(self.fixed_suggestion):
self.register(self.fixed_suggestion)
return [self.fixed_suggestion]
return []


@pytest.fixture()
Expand Down