reap gen_multiple and replace with gen_for_multiple_with_multiple #3436

Open · wants to merge 1 commit into base: main
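For orientation, here is a hedged usage sketch (not from the PR itself) of the consolidated entry point this change leaves behind: callers that went through the reaped `_gen_multiple` helper now get one trial's worth of generator runs from `_gen_with_multiple_nodes`. The strategy object, experiment, and arm count below are hypothetical.

```python
# Hypothetical sketch: `gs` is an existing GenerationStrategy and `experiment`
# an Ax Experiment, both assumed to be set up elsewhere.
generator_runs = gs._gen_with_multiple_nodes(
    experiment=experiment,
    n=6,  # total arms requested for the single trial being generated
)

# Assumed pattern (not shown in this diff): attach the runs to one BatchTrial.
trial = experiment.new_batch_trial()
for gr in generator_runs:
    trial.add_generator_run(gr)
```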
304 changes: 107 additions & 197 deletions ax/generation_strategy/generation_strategy.py
@@ -356,134 +356,6 @@ def gen(
)
return gr[0]

def _gen_with_multiple_nodes(
self,
experiment: Experiment,
data: Data | None = None,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
n: int | None = None,
fixed_features: ObservationFeatures | None = None,
arms_per_node: dict[str, int] | None = None,
) -> list[GeneratorRun]:
"""Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
``BatchTrial``, and if producing a ``BatchTrial``, allows for multiple
``GenerationNode``-s (and therefore models) to be used to generate
``GeneratorRun``-s for that trial.


Args:
experiment: Experiment, for which the generation strategy is producing
a new generator run in the course of `gen`, and to which that
generator run will be added as trial(s). Information stored on the
experiment (e.g., trial statuses) is used to determine which model
will be used to produce the generator run returned from this method.
data: Optional data to be passed to the underlying model's `gen`, which
is called within this method and actually produces the resulting
generator run. By default, data is all data on the `experiment`.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
n: Integer representing how many arms should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the `n` and produce a model-determined number of arms. In that
case this method will also output a generator run with number of
arms that can differ from `n`.
fixed_features: An optional set of ``ObservationFeatures`` that will be
passed down to the underlying models. Note: if provided this will
override any algorithmically determined fixed features so it is
important to specify all necessary fixed features.
arms_per_node: An optional map from node name to the number of arms to
generate from that node. If not provided, will default to the number
of arms specified in the node's ``InputConstructors`` or n if no
``InputConstructors`` are defined on the node. We expect either n or
arms_per_node to be provided, but not both, and this is an advanced
argument that should only be used by advanced users.

Returns:
A list of ``GeneratorRuns`` for a single trial.
"""
grs = []
continue_gen_for_trial = True
pending_observations = deepcopy(pending_observations) or {}
self.experiment = experiment
self._validate_arms_per_node(arms_per_node=arms_per_node)
pack_gs_gen_kwargs = self._initalize_gen_kwargs(
experiment=experiment,
grs_this_gen=grs,
data=data,
n=n,
fixed_features=fixed_features,
arms_per_node=arms_per_node,
pending_observations=pending_observations,
)
if self.optimization_complete:
raise GenerationStrategyCompleted(
f"Generation strategy {self} generated all the trials as "
"specified in its nodes."
)

while continue_gen_for_trial:
pack_gs_gen_kwargs["grs_this_gen"] = grs
should_transition, node_to_gen_from_name = (
self._curr.should_transition_to_next_node(
raise_data_required_error=False
)
)
node_to_gen_from = self.nodes_dict[node_to_gen_from_name]
if should_transition:
node_to_gen_from._previous_node_name = node_to_gen_from_name
# reset should skip as conditions may have changed, do not reset
# until now so node properites can be as up to date as possible
node_to_gen_from._should_skip = False
arms_from_node = self._determine_arms_from_node(
node_to_gen_from=node_to_gen_from,
n=n,
gen_kwargs=pack_gs_gen_kwargs,
)
fixed_features_from_node = self._determine_fixed_features_from_node(
node_to_gen_from=node_to_gen_from,
gen_kwargs=pack_gs_gen_kwargs,
)
sq_ft_from_node = self._determine_sq_features_from_node(
node_to_gen_from=node_to_gen_from, gen_kwargs=pack_gs_gen_kwargs
)
self._maybe_transition_to_next_node()
if node_to_gen_from._should_skip:
continue
self._fit_current_model(data=data, status_quo_features=sq_ft_from_node)
self._curr.generator_run_limit(raise_generation_errors=True)
if arms_from_node != 0:
try:
curr_node_gr = self._curr.gen(
n=arms_from_node,
pending_observations=pending_observations,
arms_by_signature_for_deduplication=(
experiment.arms_by_signature_for_deduplication
),
fixed_features=fixed_features_from_node,
)
except DataRequiredError as err:
# Model needs more data, so we log the error and return
# as many generator runs as we were able to produce, unless
# no trials were produced at all (in which case its safe to raise).
if len(grs) == 0:
raise
logger.debug(f"Model required more data: {err}.")
break
self._generator_runs.append(curr_node_gr)
grs.append(curr_node_gr)
# ensure that the points generated from each node are marked as pending
# points for future calls to gen
pending_observations = extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
# only pass in the most recent generator run to avoid unnecessary
# deduplication in extend_pending_observations
generator_runs=[grs[-1]],
)
continue_gen_for_trial = self._should_continue_gen_for_trial()
return grs
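
The docstring above describes `arms_per_node` as an advanced alternative to `n`: supply one or the other, not both. A small illustration of the two calling styles, not part of the diff; the node names are made up and only show the shape of the mapping.

```python
# Either pass a total `n` and let the nodes' InputConstructors split it ...
grs = gs._gen_with_multiple_nodes(experiment=experiment, n=8)

# ... or spell out a per-node allocation and omit `n` entirely.
grs = gs._gen_with_multiple_nodes(
    experiment=experiment,
    arms_per_node={"sobol_node": 2, "botorch_node": 6},  # hypothetical node names
)
```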

def gen_for_multiple_trials_with_multiple_models(
self,
experiment: Experiment,
@@ -750,7 +622,7 @@ def _step_repr(self, step_str_rep: str) -> str:
for step in self._nodes:
num_trials = remaining_trials
for criterion in step.transition_criteria:
# backwards compatility of num_trials with MinTrials criterion
# backwards compatibility of num_trials with MinTrials criterion
if (
criterion.criterion_class == "MinTrials"
and isinstance(criterion, TrialBasedCriterion)
Expand Down Expand Up @@ -819,27 +691,20 @@ def __repr__(self) -> str:
return gs_str

# ------------------------- Candidate generation helpers. -------------------------

def _gen_multiple(
def _gen_with_multiple_nodes(
self,
experiment: Experiment,
num_generator_runs: int,
data: Data | None = None,
n: int = 1,
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
status_quo_features: ObservationFeatures | None = None,
**model_gen_kwargs: Any,
n: int | None = None,
fixed_features: ObservationFeatures | None = None,
arms_per_node: dict[str, int] | None = None,
) -> list[GeneratorRun]:
"""Produce multiple generator runs at once, to be made into multiple
trials on the experiment.
"""Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
``BatchTrial``, and if producing a ``BatchTrial``, allows for multiple
``GenerationNode``-s (and therefore models) to be used to generate
``GeneratorRun``-s for that trial.

NOTE: This is used to ensure that maximum parallelism and number
of trials per node are not violated when producing many generator
runs from this generation strategy in a row. Without this function,
if one generates multiple generator runs without first making any
of them into running trials, generation strategy cannot enforce that it only
produces as many generator runs as are allowed by the parallelism
limit and the limit on number of trials in current node.

Args:
experiment: Experiment, for which the generation strategy is producing
@@ -850,64 +715,109 @@ def _gen_multiple(
data: Optional data to be passed to the underlying model's `gen`, which
is called within this method and actually produces the resulting
generator run. By default, data is all data on the `experiment`.
n: Integer representing how many arms should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the ``n`` and produce a model-determined number of arms. In that
case this method will also output a generator run with number of
arms that can differ from ``n``.
pending_observations: A map from metric name to pending
observations for that metric, used by some models to avoid
resuggesting points that are currently being evaluated.
model_gen_kwargs: Keyword arguments that are passed through to
``GenerationNode.gen``, which will pass them through to
``GeneratorSpec.gen``, which will pass them to ``Adapter.gen``.
status_quo_features: An ``ObservationFeature`` of the status quo arm,
needed by some models during fit to accomadate relative constraints.
Includes the status quo parameterization and target trial index.
n: Integer representing how many arms should be in the generator run
produced by this method. NOTE: Some underlying models may ignore
the `n` and produce a model-determined number of arms. In that
case this method will also output a generator run with number of
arms that can differ from `n`.
fixed_features: An optional set of ``ObservationFeatures`` that will be
passed down to the underlying models. Note: if provided this will
override any algorithmically determined fixed features so it is
important to specify all necessary fixed features.
arms_per_node: An optional map from node name to the number of arms to
generate from that node. If not provided, will default to the number
of arms specified in the node's ``InputConstructors`` or n if no
``InputConstructors`` are defined on the node. We expect either n or
arms_per_node to be provided, but not both, and this is an advanced
argument that should only be used by advanced users.

Returns:
A list of ``GeneratorRuns`` for a single trial.
"""
self.experiment = experiment
self._maybe_transition_to_next_node()
self._fit_current_model(data=data, status_quo_features=status_quo_features)
# Get GeneratorRun limit that respects the node's transition criterion that
# affect the number of generator runs that can be produced.
gr_limit = self._curr.generator_run_limit(raise_generation_errors=True)
if gr_limit == -1:
num_generator_runs = max(num_generator_runs, 1)
else:
num_generator_runs = max(min(num_generator_runs, gr_limit), 1)
generator_runs = []
if self.optimization_complete:
raise GenerationStrategyCompleted(
f"Generation strategy {self} generated all the trials as "
"specified in its nodes."
)
grs = []
continue_gen_for_trial = True
pending_observations = deepcopy(pending_observations) or {}
for _ in range(num_generator_runs):
try:
generator_run = self._curr.gen(
n=n,
pending_observations=pending_observations,
arms_by_signature_for_deduplication=(
experiment.arms_by_signature_for_deduplication
),
**model_gen_kwargs,
)
self._validate_arms_per_node(arms_per_node=arms_per_node)
pack_gs_gen_kwargs = self._initialize_gen_kwargs(
experiment=experiment,
grs_this_gen=grs,
data=data,
n=n,
fixed_features=fixed_features,
arms_per_node=arms_per_node,
pending_observations=pending_observations,
)

except DataRequiredError as err:
# Model needs more data, so we log the error and return
# as many generator runs as we were able to produce, unless
# no trials were produced at all (in which case its safe to raise).
if len(generator_runs) == 0:
raise
logger.debug(f"Model required more data: {err}.")
break

self._generator_runs.append(generator_run)
generator_runs.append(generator_run)

# Extend the `pending_observation` with newly generated point(s)
# to avoid repeating them.
pending_observations = extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
generator_runs=[generator_run],
while continue_gen_for_trial:
pack_gs_gen_kwargs["grs_this_gen"] = grs
should_transition, node_to_gen_from_name = (
self._curr.should_transition_to_next_node(
raise_data_required_error=False
)
)
node_to_gen_from = self.nodes_dict[node_to_gen_from_name]
if should_transition:
node_to_gen_from._previous_node_name = node_to_gen_from_name
# reset should skip as conditions may have changed, do not reset
# until now so node properties can be as up to date as possible
node_to_gen_from._should_skip = False
arms_from_node = self._determine_arms_from_node(
node_to_gen_from=node_to_gen_from,
n=n,
gen_kwargs=pack_gs_gen_kwargs,
)
fixed_features_from_node = self._determine_fixed_features_from_node(
node_to_gen_from=node_to_gen_from,
gen_kwargs=pack_gs_gen_kwargs,
)
return generator_runs
sq_ft_from_node = self._determine_sq_features_from_node(
node_to_gen_from=node_to_gen_from, gen_kwargs=pack_gs_gen_kwargs
)
self._maybe_transition_to_next_node()
if node_to_gen_from._should_skip:
continue
self._fit_current_model(data=data, status_quo_features=sq_ft_from_node)
self._curr.generator_run_limit(raise_generation_errors=True)
if arms_from_node != 0:
try:
curr_node_gr = self._curr.gen(
n=arms_from_node,
pending_observations=pending_observations,
arms_by_signature_for_deduplication=(
experiment.arms_by_signature_for_deduplication
),
fixed_features=fixed_features_from_node,
)
except DataRequiredError as err:
# Model needs more data, so we log the error and return
# as many generator runs as we were able to produce, unless
# no trials were produced at all (in which case its safe to raise).
if len(grs) == 0:
raise
logger.debug(f"Model required more data: {err}.")
break
self._generator_runs.append(curr_node_gr)
grs.append(curr_node_gr)
# ensure that the points generated from each node are marked as pending
# points for future calls to gen
pending_observations = extend_pending_observations(
experiment=experiment,
pending_observations=pending_observations,
# only pass in the most recent generator run to avoid unnecessary
# deduplication in extend_pending_observations
generator_runs=[grs[-1]],
)
continue_gen_for_trial = self._should_continue_gen_for_trial()
return grs
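
As a rough restatement of the loop above (a sketch, not the real implementation): each node's GeneratorRun is collected, its arms are folded into the pending observations, and generation for the trial continues until the strategy says to stop. Error handling and node transitions are omitted, and the local names are abbreviated.

```python
# Simplified sketch of the bookkeeping in `_gen_with_multiple_nodes`.
grs: list[GeneratorRun] = []
while continue_gen_for_trial:
    gr = current_node.gen(n=arms_from_node, pending_observations=pending_observations)
    grs.append(gr)
    # Only the newest run is passed in, so extend_pending_observations does not
    # redo deduplication for arms that were already marked pending.
    pending_observations = extend_pending_observations(
        experiment=experiment,
        pending_observations=pending_observations,
        generator_runs=[gr],
    )
    continue_gen_for_trial = should_continue_gen_for_trial()
```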

def _should_continue_gen_for_trial(self) -> bool:
"""Determine if we should continue generating for the current trial, or end
@@ -934,7 +844,7 @@ def _should_continue_gen_for_trial(self) -> bool:
for tc in self._curr.transition_edges[next_node]
)

def _initalize_gen_kwargs(
def _initialize_gen_kwargs(
self,
experiment: Experiment,
grs_this_gen: list[GeneratorRun],
@@ -1059,7 +969,7 @@ def _determine_arms_from_node(
gen_kwargs: The kwargs passed to the ``GenerationStrategy``'s
gen call, including arms_per_node: an optional map from node name to
the number of arms to generate from that node. If not provided, will
default to the numberof arms specified in the node's
default to the number of arms specified in the node's
``InputConstructors`` or n if no``InputConstructors`` are defined on
the node.

@@ -1103,7 +1013,7 @@ def _fit_current_model(
data: Optional ``Data`` to fit or update with; if not specified, generation
strategy will obtain the data via ``experiment.lookup_data``.
status_quo_features: An ``ObservationFeature`` of the status quo arm,
needed by some models during fit to accomadate relative constraints.
needed by some models during fit to accommodate relative constraints.
Includes the status quo parameterization and target trial index.
"""
data = self.experiment.lookup_data() if data is None else data
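
Finally, a hedged example of the `status_quo_features` argument documented in the last hunk: an `ObservationFeatures` carrying the status quo parameterization and the target trial index. The parameter names and index below are invented for illustration, and `_fit_current_model` is internal; this only shows the argument's shape.

```python
from ax.core.observation import ObservationFeatures

# Hypothetical status quo arm: its parameterization plus the trial index the
# relative constraints should be referenced against.
sq_features = ObservationFeatures(
    parameters={"lr": 0.01, "batch_size": 32},
    trial_index=0,
)
gs._fit_current_model(data=None, status_quo_features=sq_features)
```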