Skip to content

Commit

Permalink
Merge pull request #73 from y0z/feature/benchmarks
Browse files Browse the repository at this point in the history
Introduce the `optunahub.benchmarks` module
  • Loading branch information
nabenabe0928 authored Dec 11, 2024
2 parents a3f85ea + acec775 commit 66a0045
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 2 deletions.
7 changes: 7 additions & 0 deletions docs/source/_templates/custom_summary.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{{ fullname | escape | underline }}

.. autoclass:: {{ fullname }}
:members:
:special-members: __init__, __call__
:inherited-members:
:undoc-members:
11 changes: 11 additions & 0 deletions docs/source/benchmarks.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
.. module:: optunahub.benchmarks

optunahub.benchmarks
====================

.. autosummary::
:toctree: generated/
:nosignatures:
:template: custom_summary.rst

optunahub.benchmarks.BaseProblem
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
]

templates_path = ["_templates"]
Expand Down
3 changes: 2 additions & 1 deletion docs/source/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ Reference
:maxdepth: 1

optunahub
samplers
samplers
benchmarks
1 change: 1 addition & 0 deletions docs/source/samplers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ optunahub.samplers
.. autosummary::
:toctree: generated/
:nosignatures:
:template: custom_summary.rst

optunahub.samplers.SimpleBaseSampler
3 changes: 2 additions & 1 deletion optunahub/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from optunahub import benchmarks
from optunahub import samplers
from optunahub.hub import load_local_module
from optunahub.hub import load_module
from optunahub.version import __version__


__all__ = ["load_module", "load_local_module", "__version__", "samplers"]
__all__ = ["__version__", "benchmarks", "load_local_module", "load_module", "samplers"]
6 changes: 6 additions & 0 deletions optunahub/benchmarks/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from ._base_problem import BaseProblem


__all__ = [
"BaseProblem",
]
81 changes: 81 additions & 0 deletions optunahub/benchmarks/_base_problem.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from __future__ import annotations

from abc import ABCMeta
from abc import abstractmethod
from typing import Any
from typing import Sequence

import optuna


class BaseProblem(metaclass=ABCMeta):
"""Base class for optimization problems."""

def __call__(self, trial: optuna.Trial) -> float | Sequence[float]:
"""Objective function for Optuna. By default, this method calls :meth:`evaluate` with the parameters defined in :attr:`search_space`.
Args:
trial: Optuna trial object.
Returns:
The objective value or a sequence of the objective values for multi-objective optimization.
"""
params = {}
for name, dist in self.search_space.items():
params[name] = trial._suggest(name, dist)
trial._check_distribution(name, dist)
return self.evaluate(params)

def evaluate(self, params: dict[str, Any]) -> float | Sequence[float]:
"""Evaluate the objective function.
Args:
params: Dictionary of input parameters.
Returns:
The objective value or a sequence of the objective values for multi-objective optimization.
Example:
::
def evaluate(self, params: dict[str, Any]) -> float:
x = params["x"]
y = params["y"]
return x ** 2 + y
"""
raise NotImplementedError

@property
def search_space(self) -> dict[str, optuna.distributions.BaseDistribution]:
"""Return the search space.
Returns:
Dictionary of search space. Each dictionary element consists of the parameter name and distribution (see `optuna.distributions <https://optuna.readthedocs.io/en/stable/reference/distributions.html>`__).
Example:
::
@property
def search_space(self) -> dict[str, optuna.distributions.BaseDistribution]:
return {
"x": optuna.distributions.FloatDistribution(low=0, high=1),
"y": optuna.distributions.CategoricalDistribution(choices=[0, 1, 2]),
}
"""
raise NotImplementedError

@property
@abstractmethod
def directions(self) -> list[optuna.study.StudyDirection]:
"""Return the optimization directions.
Returns:
List of `optuna.study.direction <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.study.StudyDirection.html>`__.
Example:
::
@property
def directions(self) -> list[optuna.study.StudyDirection]:
return [optuna.study.StudyDirection.MINIMIZE]
"""
...
24 changes: 24 additions & 0 deletions tests/test_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

import optuna

import optunahub


def test_base_problem() -> None:
class TestProblem(optunahub.benchmarks.BaseProblem):
def evaluate(self, params: dict[str, float]) -> float:
x = params["x"]
return x**2

@property
def search_space(self) -> dict[str, optuna.distributions.BaseDistribution]:
return {"x": optuna.distributions.FloatDistribution(low=-1, high=1)}

@property
def directions(self) -> list[optuna.study.StudyDirection]:
return [optuna.study.StudyDirection.MINIMIZE]

problem = TestProblem()
study = optuna.create_study(directions=problem.directions)
study.optimize(problem, n_trials=20) # verify no error occurs

0 comments on commit 66a0045

Please sign in to comment.