Skip to content

Commit

Permalink
Fix bug where metric_func_kwargs weren't passed through
Browse files Browse the repository at this point in the history
  • Loading branch information
madeline-scyphers committed Aug 18, 2023
1 parent a1555c5 commit 3b8b5f9
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 35 deletions.
24 changes: 14 additions & 10 deletions boa/config/__main__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import pathlib
# TODO: Move this to a docs only feature and move back to PYYAML
import pathlib # pragma: no cover

import click
from ruamel.yaml import YAML
from ruamel.yaml.compat import StringIO
import click # pragma: no cover
from ruamel.yaml import YAML # pragma: no cover
from ruamel.yaml.compat import StringIO # pragma: no cover

from boa.config.config import add_comment_recurse, generate_default_doc_config
from boa.config.config import ( # pragma: no cover
add_comment_recurse,
generate_default_doc_config,
)


class YAMLDumper(YAML):
class YAMLDumper(YAML): # pragma: no cover
def dump(self, data, stream=None, **kw):
inefficient = False
if stream is None:
Expand All @@ -18,14 +22,14 @@ def dump(self, data, stream=None, **kw):
return stream.getvalue()


@click.command()
@click.option(
@click.command() # pragma: no cover
@click.option( # pragma: no cover
"--output-path",
"-o",
type=click.Path(exists=False, file_okay=True, dir_okay=False, path_type=pathlib.Path),
default="default_config.yaml",
)
def main(output_path):
def main(output_path): # pragma: no cover
"""Generate a default config file with comments."""
d, c = generate_default_doc_config()
yaml = YAML()
Expand All @@ -35,5 +39,5 @@ def main(output_path):
yaml.dump(data, f)


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
main()
24 changes: 19 additions & 5 deletions boa/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
NL = "\n"


def strip_white_space(s: str, strip_all=True):
def strip_white_space(s: str, strip_all=True): # pragma: no cover # Used in docs
if not s:
return s
if strip_all:
Expand All @@ -61,7 +61,7 @@ def strip_white_space(s: str, strip_all=True):
class _Utils:
_filtered_dict_fields: ClassVar[list[str]] = None

def to_dict(self) -> dict:
def to_dict(self) -> dict: # pragma: no cover # Used in docs
def vs(inst, attrib, val):
if is_dataclass(val):
return dc_asdict(val)
Expand Down Expand Up @@ -180,6 +180,18 @@ class BOAMetric(_Utils):
wrapper functions."""
},
)
metric_func_kwargs: Optional[dict] = field(
default=None,
metadata={
"doc": """Additional keyword arguments to be passed to the metric function.
This is useful when you are setting up a metric and only want to pass the metric function
additional arguments.
Example: Passing `metric_func_kwargs={"squared": False}` to sklearn mean_squared_error
to get the root mean squared error instead of the mean squared error
(Though BOA already has :class:`RMSE <.RMSE>` available from sklearn built in if needed).
"""
},
)

def __init__(self, *args, lower_is_better: Optional[bool] = None, **kwargs):
if lower_is_better is not None:
Expand Down Expand Up @@ -764,7 +776,7 @@ def boa_params_to_wpr(params: list[dict], mapping, from_trial=True):
return new_params


def generate_default_doc_config():
def generate_default_doc_config(): # pragma: no cover

config = BOAConfig(
**{
Expand Down Expand Up @@ -806,7 +818,9 @@ def set_metadata_default_doc_recurse(d: dict, config):
return d, config


def add_comment_recurse(d: ruamel.yaml.comments.CommentedMap, config=None, where="before", depth=0, indent=2):
def add_comment_recurse(
d: ruamel.yaml.comments.CommentedMap, config=None, where="before", depth=0, indent=2
): # pragma: no cover
fields = fields_dict(type(config)) if attr.has(config) else {}
if isinstance(d, dict):
for key in d:
Expand Down Expand Up @@ -835,7 +849,7 @@ def add_comment_recurse(d: ruamel.yaml.comments.CommentedMap, config=None, where
return d


if __name__ == "__main__":
if __name__ == "__main__": # pragma: no cover
from tests.conftest import TEST_CONFIG_DIR

c = BOAConfig.from_jsonlike(pathlib.Path(TEST_CONFIG_DIR / "test_config_generic.yaml"))
19 changes: 12 additions & 7 deletions boa/metrics/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,14 @@ class RootMeanSquaredError(SklearnMetric):
def __init__(
self,
lower_is_better=True,
metric_func_kwargs=(("squared", False),),
metric_func_kwargs=None,
*args,
**kwargs,
):
if metric_func_kwargs == (("squared", False),):
metric_func_kwargs = dict((y, x) for x, y in metric_func_kwargs)
if isinstance(metric_func_kwargs, dict):
metric_func_kwargs.update({"squared": False})
else:
metric_func_kwargs = {"squared": False}
super().__init__(
lower_is_better=lower_is_better,
metric_func_kwargs=metric_func_kwargs,
Expand Down Expand Up @@ -297,15 +299,18 @@ def __init__(self, lower_is_better=True, *args, **kwargs):
def get_metric_from_config(config: BOAMetric, instantiate=True, **kwargs) -> ModularMetric:
kwargs["lower_is_better"] = config.minimize
kwargs["metric_name"] = config.metric
kw = {**config.to_dict(), **kwargs}
if kw.get("metric_func_kwargs") is None:
kw.pop("metric_func_kwargs")
if config.metric_type == MetricType.METRIC or config.metric_type == MetricType.BOA_METRIC:
metric = get_metric_by_class_name(instantiate=instantiate, **config.to_dict(), **kwargs)
metric = get_metric_by_class_name(instantiate=instantiate, **kw)
elif config.metric_type == MetricType.SKLEARN_METRIC:
kwargs["sklearn_"] = True
metric = get_metric_by_class_name(instantiate=instantiate, **config.to_dict(), **kwargs)
metric = get_metric_by_class_name(instantiate=instantiate, **kw)
elif config.metric_type == MetricType.SYNTHETIC_METRIC:
metric = setup_synthetic_metric(instantiate=instantiate, **config.to_dict(), **kwargs)
metric = setup_synthetic_metric(instantiate=instantiate, **kw)
elif config.metric_type == MetricType.PASSTHROUGH: # only name but no metric type
metric = PassThroughMetric(**config.to_dict(), **kwargs)
metric = PassThroughMetric(**kw)
else:
# TODO link to docs for configuration when it exists
raise KeyError("No valid configuration for metric found.")
Expand Down
3 changes: 2 additions & 1 deletion boa/metrics/modular_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,10 +197,11 @@ def fetch_trial_data(self, trial: BaseTrial, **kwargs):

def _evaluate(self, params: TParameterization, **kwargs) -> float:
    """Evaluate the metric for one set of trial parameters.

    Merges the per-trial ``"kwargs"`` entry popped from *params* into
    *kwargs*, then delegates to :meth:`f` with only the arguments that
    the wrapped metric callable accepts.
    """
    # Pop "kwargs" so the extras are merged once and not passed through
    # twice. NOTE(review): assumes params always contains a "kwargs"
    # key — confirm upstream callers always set it.
    kwargs.update(params.pop("kwargs"))

    # get_dictionary_from_callable presumably filters kwargs down to the
    # signature of self.metric_to_eval before calling f() — verify helper.
    return self.f(**get_dictionary_from_callable(self.metric_to_eval, kwargs))

def f(self, *args, **kwargs):
    """Invoke the wrapped metric callable with configured kwargs merged in.

    Any ``metric_func_kwargs`` from the config are always forwarded to
    the metric function (overriding same-named call-time kwargs), so
    user-provided settings are never dropped silently.
    """
    extra = self.metric_func_kwargs
    call_kwargs = {**kwargs, **extra} if extra else kwargs
    return self.metric_to_eval(*args, **call_kwargs)

def clone(self) -> "Metric":
Expand Down
34 changes: 29 additions & 5 deletions tests/1unit_tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from boa import (
BaseWrapper,
BOAMetric,
Controller,
get_metric_by_class_name,
get_metric_from_config,
setup_sklearn_metric,
setup_synthetic_metric,
)

Expand Down Expand Up @@ -57,13 +59,21 @@ def test_load_metric_by_name():
assert metric_synth.name == "something"
assert metric_synth.metric_to_eval.name == "FromBotorch_Hartmann4"

metric_sklearn = get_metric_by_class_name("MSE")
assert metric_sklearn.name == "MSE"
assert metric_sklearn.metric_to_eval.__name__ == "mean_squared_error"
metric_boa = get_metric_by_class_name("MSE")
assert metric_boa.name == "MSE"
assert metric_boa.metric_to_eval.__name__ == "mean_squared_error"

metric_sklearn = get_metric_by_class_name("MSE", name="something")
metric_boa = get_metric_by_class_name("MSE", name="something")
assert metric_boa.name == "something"
assert metric_boa.metric_to_eval.__name__ == "mean_squared_error"

metric_sklearn = setup_sklearn_metric("median_absolute_error")
assert metric_sklearn.name == "median_absolute_error"
assert metric_sklearn.metric_to_eval.__name__ == "median_absolute_error"

metric_sklearn = setup_sklearn_metric("median_absolute_error", name="something")
assert metric_sklearn.name == "something"
assert metric_sklearn.metric_to_eval.__name__ == "mean_squared_error"
assert metric_sklearn.metric_to_eval.__name__ == "median_absolute_error"


def test_load_metric_from_config(synth_config, generic_config):
Expand Down Expand Up @@ -177,3 +187,17 @@ def test_pass_through_metric_passes_through_value(pass_through_config, tmp_path)
f_ret = metric.f(wrapper.fetch_trial_data(trial, {}, name))
assert f_ret == data.df["mean"].iloc[0]
assert f_ret == trial.index


def test_can_override_metric_func_kwargs():
    """NRMSE evaluated with different normalizers must give distinct values,
    proving metric_func_kwargs are passed through to the metric function."""
    xs = [1, 2, 3, 4, 5, 6]
    ys = [0.1 * v for v in reversed(xs)]
    normalizers = ["iqr", "std", "mean", "range"]
    results = []
    for norm in normalizers:
        config = BOAMetric(metric="NRMSE", metric_func_kwargs=dict(normalizer=norm))
        metric = get_metric_from_config(config)
        assert metric.metric_to_eval.__name__ == "normalized_root_mean_squared_error"
        results.append(metric.f(xs, ys))
    # Every normalizer should produce a different value, ensuring the kwargs
    # actually reached the underlying metric function.
    assert len(set(results)) == len(normalizers)
12 changes: 9 additions & 3 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def moo_main_run(tmp_path_factory, cd_to_root_and_back_session):
def stand_alone_opt_package_run(request, tmp_path_factory, cd_to_root_and_back_session):
# parametrize the test to pass in script options in config as relative and absolute paths
if getattr(request, "param", None) == "absolute":
temp_dir = tmp_path_factory.mktemp("temp_dir")
wrapper_path = (TEST_DIR / "scripts/stand_alone_opt_package/wrapper.py").resolve()
config = {
"objective": {"metrics": [{"metric": "mean", "name": "Mean"}, {"metric": "RMSE", "info_only": True}]},
Expand All @@ -167,16 +168,21 @@ def stand_alone_opt_package_run(request, tmp_path_factory, cd_to_root_and_back_s
{"bounds": [-5.0, 10.0], "name": "x0", "type": "range"},
{"bounds": [0.0, 15.0], "name": "x1", "type": "range"},
],
"script_options": {"wrapper_path": str(wrapper_path)},
"script_options": {
"wrapper_path": str(wrapper_path),
"output_dir": str(temp_dir),
"exp_name": "test_experiment",
},
}
temp_dir = tmp_path_factory.mktemp("temp_dir")
config_path = temp_dir / "config.yaml"
with open(Path(config_path), "w") as file:
json.dump(config, file)
args = f"--config-path {config_path}"
else:
config_path = TEST_DIR / "scripts/stand_alone_opt_package/stand_alone_pkg_config.yaml"
args = f"--config-path {config_path} -td"

yield dunder_main.main(split_shell_command(f"--config-path {config_path} -td"), standalone_mode=False)
yield dunder_main.main(split_shell_command(args), standalone_mode=False)


@pytest.fixture(scope="session")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@ optimization_options:
num_trials: 5
- model: GPEI
num_trials: -1
scheduler:
total_trials: 10
experiment:
name: "test_experiment"
trials: 10

model_options:
model_specific_options:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ optimization_options:
objective_thresholds: []
experiment:
name: "test_experiment"
scheduler:
total_trials: 10
n_trials: 10

parameters:
x1:
Expand Down
6 changes: 6 additions & 0 deletions tests/test_configs/test_config_gen_strat1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@ objective:
metrics:
- name: rmse
metric: RootMeanSquaredError
properties:
# You can add any property you want to the metric
# Use this to pass any information you want to the metric
# Through your wrapper
any_property: any_value

generation_strategy:
steps:
- model: SOBOL
Expand Down

1 comment on commit 3b8b5f9

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Coverage

Coverage Report
FileStmtsMissCoverMissing
__main__.py74395%108, 185, 193
ax_instantiation_utils.py47197%123
controller.py82890%72, 102–103, 105–107, 111, 180
definitions.py80100% 
instantiation_base.py40197%39
metaclasses.py47393%50–51, 57
plot.py11110%13, 15–16, 18, 21–22, 29–31, 34–35
plotting.py1352482%40, 74–75, 232–235, 283–287, 327, 411–414, 416–420, 422, 427
registry.py180100% 
runner.py45491%42, 76–78
scheduler.py481372%87–90, 97–98, 112, 119–120, 127–128, 217–218
storage.py114992%77, 80, 163, 181, 205–206, 209, 214, 217
template.py440%1, 4–6
utils.py852076%174, 188–189, 214, 224–228, 230–232, 234, 236, 240–245
config
   __main__.py00100% 
   config.py2702491%199, 205, 367, 370, 378–379, 387, 560, 562, 570–572, 574, 645, 678, 699, 706–707, 721, 731–732, 744, 769, 772
   converters.py75988%14, 22, 41, 51–52, 65, 67, 71, 79
metrics
   metric_funcs.py34488%58, 80–81, 83
   metrics.py991386%131, 308–309, 316, 322, 333–335, 339–343
   modular_metric.py1202083%39–42, 44–51, 65, 132, 143, 183, 233–234, 253–254
   synthetic_funcs.py39489%31, 35, 58, 65
scripts
   moo.py30196%44
   run_branin.py34197%56
   script_wrappers.py31293%57–58
   synth_func_cli.py210100% 
wrappers
   base_wrapper.py1461887%62–65, 67, 74, 82, 99–100, 113, 122, 140, 167–168, 219, 265, 330, 339
   script_wrapper.py72888%183, 194–195, 257, 308, 314, 319, 324
   wrapper_utils.py1231091%146–147, 149, 233, 282, 382–386
TOTAL185221588% 

Please sign in to comment.