Config deprecation update #138

Merged · 2 commits · Aug 18, 2023
24 changes: 14 additions & 10 deletions boa/config/__main__.py
@@ -1,13 +1,17 @@
import pathlib
# TODO: Move this to a docs only feature and move back to PYYAML
import pathlib # pragma: no cover

import click
from ruamel.yaml import YAML
from ruamel.yaml.compat import StringIO
import click # pragma: no cover
from ruamel.yaml import YAML # pragma: no cover
from ruamel.yaml.compat import StringIO # pragma: no cover

from boa.config.config import add_comment_recurse, generate_default_doc_config
from boa.config.config import ( # pragma: no cover
add_comment_recurse,
generate_default_doc_config,
)


class YAMLDumper(YAML):
class YAMLDumper(YAML): # pragma: no cover
def dump(self, data, stream=None, **kw):
inefficient = False
if stream is None:
@@ -18,14 +22,14 @@ def dump(self, data, stream=None, **kw):
return stream.getvalue()


@click.command()
@click.option(
@click.command() # pragma: no cover
@click.option( # pragma: no cover
"--output-path",
"-o",
type=click.Path(exists=False, file_okay=True, dir_okay=False, path_type=pathlib.Path),
default="default_config.yaml",
)
def main(output_path):
def main(output_path): # pragma: no cover
"""Generate a default config file with comments."""
d, c = generate_default_doc_config()
yaml = YAML()
@@ -35,5 +39,5 @@ def main(output_path):
yaml.dump(data, f)


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
main()
24 changes: 19 additions & 5 deletions boa/config/config.py
@@ -43,7 +43,7 @@
NL = "\n"


def strip_white_space(s: str, strip_all=True):
def strip_white_space(s: str, strip_all=True): # pragma: no cover # Used in docs
if not s:
return s
if strip_all:
Expand All @@ -61,7 +61,7 @@ def strip_white_space(s: str, strip_all=True):
class _Utils:
_filtered_dict_fields: ClassVar[list[str]] = None

def to_dict(self) -> dict:
def to_dict(self) -> dict: # pragma: no cover # Used in docs
def vs(inst, attrib, val):
if is_dataclass(val):
return dc_asdict(val)
@@ -180,6 +180,18 @@ class BOAMetric(_Utils):
wrapper functions."""
},
)
metric_func_kwargs: Optional[dict] = field(
default=None,
metadata={
"doc": """Additional keyword arguments to be passed to the metric function.
This is useful when you are setting up a metric and only want to pass the metric function
additional arguments.
Example: Passing `metric_func_kwargs={"squared": False}` to sklearn's mean_squared_error
to get the root mean squared error instead of the mean squared error
(though BOA already provides :class:`RMSE <.RMSE>`, built on sklearn, if needed).
"""
},
)

def __init__(self, *args, lower_is_better: Optional[bool] = None, **kwargs):
if lower_is_better is not None:
@@ -764,7 +776,7 @@ def boa_params_to_wpr(params: list[dict], mapping, from_trial=True):
return new_params


def generate_default_doc_config():
def generate_default_doc_config(): # pragma: no cover

config = BOAConfig(
**{
@@ -806,7 +818,9 @@ def set_metadata_default_doc_recurse(d: dict, config):
return d, config


def add_comment_recurse(d: ruamel.yaml.comments.CommentedMap, config=None, where="before", depth=0, indent=2):
def add_comment_recurse(
d: ruamel.yaml.comments.CommentedMap, config=None, where="before", depth=0, indent=2
): # pragma: no cover
fields = fields_dict(type(config)) if attr.has(config) else {}
if isinstance(d, dict):
for key in d:
@@ -835,7 +849,7 @@ def add_comment_recurse(d: ruamel.yaml.comments.CommentedMap, config=None, where
return d


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
from tests.conftest import TEST_CONFIG_DIR

c = BOAConfig.from_jsonlike(pathlib.Path(TEST_CONFIG_DIR / "test_config_generic.yaml"))
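
Side note on the new `metric_func_kwargs` field documented above: the docstring's sklearn example can be checked directly. The snippet below is a minimal standalone sketch (assuming scikit-learn is installed; it is not part of this PR) showing that `squared=False` turns mean squared error into root mean squared error.

from sklearn.metrics import mean_squared_error

y_true = [3.0, -0.5, 2.0, 7.0]
y_pred = [2.5, 0.0, 2.0, 8.0]

mse = mean_squared_error(y_true, y_pred)                  # 0.375
rmse = mean_squared_error(y_true, y_pred, squared=False)  # sqrt(0.375) ~= 0.612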
19 changes: 12 additions & 7 deletions boa/metrics/metrics.py
@@ -199,12 +199,14 @@ class RootMeanSquaredError(SklearnMetric):
def __init__(
self,
lower_is_better=True,
metric_func_kwargs=(("squared", False),),
metric_func_kwargs=None,
*args,
**kwargs,
):
if metric_func_kwargs == (("squared", False),):
metric_func_kwargs = dict((y, x) for x, y in metric_func_kwargs)
if isinstance(metric_func_kwargs, dict):
metric_func_kwargs.update({"squared": False})
else:
metric_func_kwargs = {"squared": False}
super().__init__(
lower_is_better=lower_is_better,
metric_func_kwargs=metric_func_kwargs,
@@ -297,15 +299,18 @@ def __init__(self, lower_is_better=True, *args, **kwargs):
def get_metric_from_config(config: BOAMetric, instantiate=True, **kwargs) -> ModularMetric:
kwargs["lower_is_better"] = config.minimize
kwargs["metric_name"] = config.metric
kw = {**config.to_dict(), **kwargs}
if kw.get("metric_func_kwargs") is None:
kw.pop("metric_func_kwargs")
if config.metric_type == MetricType.METRIC or config.metric_type == MetricType.BOA_METRIC:
metric = get_metric_by_class_name(instantiate=instantiate, **config.to_dict(), **kwargs)
metric = get_metric_by_class_name(instantiate=instantiate, **kw)
elif config.metric_type == MetricType.SKLEARN_METRIC:
kwargs["sklearn_"] = True
metric = get_metric_by_class_name(instantiate=instantiate, **config.to_dict(), **kwargs)
metric = get_metric_by_class_name(instantiate=instantiate, **kw)
elif config.metric_type == MetricType.SYNTHETIC_METRIC:
metric = setup_synthetic_metric(instantiate=instantiate, **config.to_dict(), **kwargs)
metric = setup_synthetic_metric(instantiate=instantiate, **kw)
elif config.metric_type == MetricType.PASSTHROUGH: # only name but no metric type
metric = PassThroughMetric(**config.to_dict(), **kwargs)
metric = PassThroughMetric(**kw)
else:
# TODO link to docs for configuration when it exists
raise KeyError("No valid configuration for metric found.")
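
For readers skimming the `get_metric_from_config` change: the new `kw` dict merges the config with any overrides and drops `metric_func_kwargs` when it is unset, so downstream constructors never receive an explicit None. A standalone sketch of that pattern (hypothetical values, not the project's real config):

config_dict = {"metric": "mean_squared_error", "metric_func_kwargs": None}
overrides = {"lower_is_better": True, "metric_name": "rmse"}

kw = {**config_dict, **overrides}          # overrides win on key collisions
if kw.get("metric_func_kwargs") is None:
    kw.pop("metric_func_kwargs")           # don't forward an explicit None
# kw == {"metric": "mean_squared_error", "lower_is_better": True, "metric_name": "rmse"}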
3 changes: 2 additions & 1 deletion boa/metrics/modular_metric.py
@@ -197,10 +197,11 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs):

def _evaluate(self, params: TParameterization, **kwargs) -> float:
kwargs.update(params.pop("kwargs"))

return self.f(**get_dictionary_from_callable(self.metric_to_eval, kwargs))

def f(self, *args, **kwargs):
if self.metric_func_kwargs: # always pass the metric_func_kwargs, don't fail silently
kwargs.update(self.metric_func_kwargs)
return self.metric_to_eval(*args, **kwargs)

def clone(self) -> "Metric":
13 changes: 8 additions & 5 deletions boa/scripts/moo.py
@@ -1,3 +1,4 @@
import tempfile
from pathlib import Path

import torch
@@ -30,11 +31,13 @@ def fetch_trial_data(self, trial, metric_properties, metric_name, *args, **kwarg


def main():
config_path = Path(__file__).resolve().parent / "moo.yaml"
wrapper = Wrapper(config_path=config_path)
controller = Controller(wrapper=wrapper)
controller.initialize_scheduler()
return controller.run()
with tempfile.TemporaryDirectory() as temp_dir:
experiment_dir = Path(temp_dir)
config_path = Path(__file__).resolve().parent / "moo.yaml"
wrapper = Wrapper(config_path=config_path, experiment_dir=experiment_dir)
controller = Controller(wrapper=wrapper)
controller.initialize_scheduler()
return controller.run()


if __name__ == "__main__":
33 changes: 17 additions & 16 deletions boa/scripts/moo.yaml
@@ -1,20 +1,18 @@
# MultiObjective Optimization config
optimization_options:
objective_options:
objective_thresholds:
- branin >= -18.0
- currin >= -6.0
objectives:
- name: branin
lower_is_better: False
noise_sd: 0
- name: currin
lower_is_better: False
noise_sd: 0
objective:
objective_thresholds:
- branin >= -18.0
- currin >= -6.0
metrics:
- name: branin
lower_is_better: False
noise_sd: 0
- name: currin
lower_is_better: False
noise_sd: 0

experiment:
name: "moo_run"
trials: 50
scheduler:
n_trials: 30

parameters:
x0:
@@ -24,4 +22,7 @@ parameters:
x1:
type: range
bounds: [0, 1]
value_type: float
value_type: float

script_options:
exp_name: "moo_run"
7 changes: 4 additions & 3 deletions boa/scripts/run_branin.py
@@ -1,5 +1,6 @@
import logging
import shutil
import tempfile
import time
from pathlib import Path

@@ -21,9 +22,9 @@


def main():
# with tempfile.TemporaryDirectory() as exp_dir:
exp_dir = "."
return run_opt(exp_dir)
with tempfile.TemporaryDirectory() as temp_dir:
experiment_dir = Path(temp_dir)
return run_opt(exp_dir=experiment_dir)


def run_opt(exp_dir):
32 changes: 16 additions & 16 deletions boa/scripts/synth_func_config.yaml
@@ -1,19 +1,17 @@
optimization_options:
objective_options: # can also use the key moo
objectives:
- name: rmse
metric: RMSE
noise_sd: .1
objective: # can also use the key moo
metrics:
- name: rmse
metric: RMSE
noise_sd: .1

generation_strategy:
steps:
- model: SOBOL
num_trials: 5
- model: GPEI
num_trials: -1
scheduler:
total_trials: 20
on_reload:
generation_strategy:
steps:
- model: SOBOL
num_trials: 5
- model: GPEI
num_trials: -1
scheduler:
total_trials: 20

parameters:
x0:
@@ -25,10 +23,12 @@ parameters:
'type': 'range'
'value_type': 'float'

# options only needed by the model and not BOA
# You can put anything here that your model might need
model_options:
input_size: 15

script_options:
wrapper_path: ./script_wrappers.py
wrapper_name: Wrapper
append_timestamp: True
append_timestamp: True
6 changes: 6 additions & 0 deletions boa/template.py
@@ -0,0 +1,6 @@
import jinja2


def render_template(template_name, **kwargs):
template = jinja2.Environment(loader=jinja2.FileSystemLoader("templates")).get_template(template_name)
return template.render(**kwargs)
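
A possible usage sketch for the new `render_template` helper (the template file name and variables here are hypothetical, and a `templates/` directory is assumed to exist relative to the working directory):

from boa.template import render_template

# Renders templates/default_config.yaml.jinja2 with the given variables.
rendered = render_template("default_config.yaml.jinja2", exp_name="demo", n_trials=20)
print(rendered)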
2 changes: 1 addition & 1 deletion environment.yml
@@ -21,4 +21,4 @@ dependencies:
- ax-platform==0.3.3
- ruamel.yaml
- attrs

- jinja2
1 change: 1 addition & 0 deletions environment_dev.yml
@@ -23,6 +23,7 @@ dependencies:
- ruamel.yaml
- domdfcoding::attr_utils
- attrs
- jinja2

## Jupyter and sphinx jupyter
- myst-nb
28 changes: 16 additions & 12 deletions tests/1unit_tests/test_config_deprecation_normalization.py
@@ -1,15 +1,19 @@
import pytest

from boa import BOAConfig


def test_config_deprecation_normalization(
synth_config,
metric_config,
gen_strat1_config,
soo_config,
moo_config,
pass_through_config,
scripts_moo,
scripts_synth_func,
):
for config in [synth_config, metric_config, gen_strat1_config, soo_config, moo_config, pass_through_config]:
assert isinstance(config, BOAConfig)
@pytest.mark.parametrize(
"config",
[
"synth_config_deprecated",
"metric_config_deprecated",
"gen_strat1_config_deprecated",
"soo_config_deprecated",
"moo_config_deprecated",
"pass_through_config_deprecated",
], # 1. pass fixture name as a string
)
def test_config_deprecation_normalization(config, request):
config = request.getfixturevalue(config)
assert isinstance(config, BOAConfig)
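
The rewritten test uses pytest's indirect-by-name pattern: fixture names are parametrized as strings and resolved at runtime with `request.getfixturevalue`. A minimal standalone illustration (hypothetical fixtures, not the project's real ones):

import pytest

@pytest.fixture
def small_config():
    return {"n_trials": 5}

@pytest.fixture
def large_config():
    return {"n_trials": 500}

@pytest.mark.parametrize("config_name", ["small_config", "large_config"])
def test_config_is_dict(config_name, request):
    config = request.getfixturevalue(config_name)  # resolve the fixture by name
    assert isinstance(config, dict)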
4 changes: 2 additions & 2 deletions tests/1unit_tests/test_generation_strategy.py
@@ -22,8 +22,8 @@ def test_gen_steps_from_config(gen_strat1_config):
assert gs1 == gs2


def test_auto_gen_use_saasbo(saasbo_config):
controller = Controller(config=saasbo_config, wrapper=ScriptWrapper(config=saasbo_config))
def test_auto_gen_use_saasbo(saasbo_config, tmp_path):
controller = Controller(config=saasbo_config, wrapper=ScriptWrapper(config=saasbo_config, experiment_dir=tmp_path))
exp = get_experiment(
config=controller.config, runner=WrappedJobRunner(wrapper=controller.wrapper), wrapper=controller.wrapper
)