Skip to content

Commit

Permalink
Add Ruff pre-commit hook (#44)
Browse files Browse the repository at this point in the history
* Ruff pre-commit; format all files

* Fix circular import of ProblemConfig

* Further ruff formatting
  • Loading branch information
TobyBoyne authored Jul 3, 2024
1 parent 835694e commit 2ba99d4
Show file tree
Hide file tree
Showing 28 changed files with 267 additions and 163 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
repos:
- repo: /~https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.5.0
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
2 changes: 1 addition & 1 deletion entmoot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from entmoot.problem_config import ProblemConfig
from entmoot.models.enting import Enting
from entmoot.models.model_params import (
EntingParams,
Expand All @@ -8,3 +7,4 @@
)
from entmoot.optimizers.gurobi_opt import GurobiOptimizer
from entmoot.optimizers.pyomo_opt import PyomoOptimizer
from entmoot.problem_config import ProblemConfig as ProblemConfig
10 changes: 6 additions & 4 deletions entmoot/benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,11 @@ def compute_objectives(xi: Sequence, no_cat=False):
f"Allowed values are 1 and 2"
)


def build_reals_only_problem(problem_config: ProblemConfig):
"""A problem containing only real values, as used to demonstrate the NChooseK
"""A problem containing only real values, as used to demonstrate the NChooseK
constraint.
The minimum is (1.0, 2.0, 3.0, ...)"""

problem_config.add_feature("real", (0.0, 5.0), name="x1")
Expand All @@ -165,10 +166,11 @@ def build_reals_only_problem(problem_config: ProblemConfig):
problem_config.add_feature("real", (0.0, 5.0), name="x5")
problem_config.add_min_objective()


def eval_reals_only_testfunc(X: ArrayLike):
"""The function (x1 - 1)**2 + (x2 - 2)**2 + ..."""
x = np.atleast_2d(X)
xbar = np.ones_like(x)
xbar *= (np.arange(x.shape[1]) + 1)[None, :]
y = np.sum((x - xbar)**2, axis=1)
return y.reshape(-1, 1)
y = np.sum((x - xbar) ** 2, axis=1)
return y.reshape(-1, 1)
4 changes: 3 additions & 1 deletion entmoot/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,9 @@ def apply_pyomo_constraints(
for constraint in self._constraints:
features = constraint._get_feature_vars(model, feat_list)
if not isinstance(constraint, ExpressionConstraint):
raise TypeError("Only ExpressionConstraints are supported in a constraint list")
raise TypeError(
"Only ExpressionConstraints are supported in a constraint list"
)

expr = constraint._get_expr(model, features)
pyo_constraint_list.add(expr)
Expand Down
26 changes: 18 additions & 8 deletions entmoot/models/enting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@

import numpy as np

from entmoot import ProblemConfig
from entmoot.models.base_model import BaseModel
from entmoot.models.mean_models.tree_ensemble import TreeEnsemble
from entmoot.models.model_params import EntingParams
from entmoot.models.uncertainty_models.distance_based_uncertainty import (
DistanceBasedUncertainty,
)
from entmoot.problem_config import ProblemConfig
from entmoot.utils import sample


Expand Down Expand Up @@ -52,7 +52,9 @@ class Enting(BaseModel):
X_opt_pyo, _, _ = opt_pyo.solve(enting)
"""

def __init__(self, problem_config: ProblemConfig, params: Union[EntingParams, dict, None]):
def __init__(
self, problem_config: ProblemConfig, params: Union[EntingParams, dict, None]
):
if params is None:
params = {}
if isinstance(params, dict):
Expand Down Expand Up @@ -128,9 +130,9 @@ def predict(self, X: np.ndarray, is_enc=False) -> list:
f"Expected '(num_samples, {len(self._problem_config.feat_list)})', got '{X.shape}'."
)

mean_pred = self.mean_model.predict(X) #.tolist()
mean_pred = self.mean_model.predict(X) # .tolist()
unc_pred = self.unc_model.predict(X)

mean_pred = self._problem_config.transform_objective(mean_pred)
mean_pred = mean_pred.tolist()

Expand All @@ -147,7 +149,9 @@ def predict_acq(self, X: np.ndarray, is_enc=False) -> list:
acq_pred.append(mean + self._beta * unc)
return acq_pred

def add_to_gurobipy_model(self, core_model, weights: Optional[tuple[float, ...]] = None) -> None:
def add_to_gurobipy_model(
self, core_model, weights: Optional[tuple[float, ...]] = None
) -> None:
"""
Enriches the core model by adding variables and constraints based on information
from the tree model.
Expand All @@ -173,7 +177,9 @@ def add_to_gurobipy_model(self, core_model, weights: Optional[tuple[float, ...]]
if weights is not None:
moo_weights = weights
else:
moo_weights = sample(len(self._problem_config.obj_list), 1, self._problem_config.rng)[0]
moo_weights = sample(
len(self._problem_config.obj_list), 1, self._problem_config.rng
)[0]

for idx, obj in enumerate(self._problem_config.obj_list):
core_model.addConstr(
Expand All @@ -184,7 +190,9 @@ def add_to_gurobipy_model(self, core_model, weights: Optional[tuple[float, ...]]
core_model.setObjective(core_model._mu + self._beta * core_model._unc)
core_model.update()

def add_to_pyomo_model(self, core_model, weights: Optional[tuple[float, ...]] = None) -> None:
def add_to_pyomo_model(
self, core_model, weights: Optional[tuple[float, ...]] = None
) -> None:
"""
Enriches the core model by adding variables and constraints based on information
from the tree model.
Expand Down Expand Up @@ -212,7 +220,9 @@ def add_to_pyomo_model(self, core_model, weights: Optional[tuple[float, ...]] =
if weights is not None:
moo_weights = weights
else:
moo_weights = sample(len(self._problem_config.obj_list), 1, self._problem_config.rng)[0]
moo_weights = sample(
len(self._problem_config.obj_list), 1, self._problem_config.rng
)[0]

objectives_position_name = list(
enumerate([obj.name for obj in self._problem_config.obj_list])
Expand Down
1 change: 0 additions & 1 deletion entmoot/models/mean_models/lgbm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ def read_lgbm_tree_model_dict(tree_model_dict, cat_idx):
ordered_tree_list = []

for tree in tree_model_dict["tree_info"]:

# generate list of nodes in tree
root_node = [tree["tree_structure"]]
node_list = []
Expand Down
13 changes: 9 additions & 4 deletions entmoot/models/mean_models/tree_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@
from entmoot.models.mean_models.lgbm_utils import read_lgbm_tree_model_dict
from entmoot.models.mean_models.meta_tree_ensemble import MetaTreeModel
from entmoot.models.model_params import TreeTrainParams
from entmoot.problem_config import ProblemConfig, Categorical
from entmoot.problem_config import Categorical, ProblemConfig


class TreeEnsemble(BaseModel):
def __init__(self, problem_config: ProblemConfig, params: Union[TreeTrainParams, dict, None] = None):
def __init__(
self,
problem_config: ProblemConfig,
params: Union[TreeTrainParams, dict, None] = None,
):
if params is None:
params = {}
if isinstance(params, dict):
Expand Down Expand Up @@ -98,14 +102,15 @@ def _train_lgbm(self, X, y):
tree_model = lgb.train(
self._train_params,
train_data,
#verbose_eval=False,
# verbose_eval=False,
)
else:
# train for non-categorical vars
train_data = lgb.Dataset(X, label=y, params={"verbose": -1})

tree_model = lgb.train(
self._train_params, train_data#, verbose_eval=False
self._train_params,
train_data, # , verbose_eval=False
)
return tree_model

Expand Down
26 changes: 16 additions & 10 deletions entmoot/models/model_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@

class ParamValidationError(ValueError):
"""A model parameter takes an invalid value."""

pass


@dataclass
class UncParams:
"""
This dataclass contains all uncertainty parameters.
"""

#: weight for penalty/exploration part in objective function
beta: float = 1.96
#: the predictions of the GBT model are cut off, if their absolute value exceeds
Expand All @@ -35,19 +38,20 @@ def __post_init__(self):
raise ParamValidationError(
f"Value for 'beta' is {self.beta}; must be positive."
)

if self.acq_sense not in ("exploration", "penalty"):
raise ParamValidationError(
f"Value for 'acq_sense' is '{self.acq_sense}'; must be in ('exploration', 'penalty')."
)


@dataclass
class TrainParams:
"""
This dataclass contains all hyperparameters that are used by lightbm during training and
documented here https://lightgbm.readthedocs.io/en/latest/Parameters.html
"""

# lightgbm training hyperparameters
objective: str = "regression"
metric: str = "rmse"
Expand All @@ -64,13 +68,14 @@ class TreeTrainParams:
"""
This dataclass contains all parameters needed for the tree training.
"""
train_params: "TrainParams" = field(default_factory=dict) # type: ignore

train_params: "TrainParams" = field(default_factory=dict) # type: ignore
train_lib: Literal["lgbm"] = "lgbm"

def __post_init__(self):
if isinstance(self.train_params, dict):
self.train_params = TrainParams(**self.train_params)

if self.train_lib not in ("lgbm",):
raise ParamValidationError(
f"Value for 'train_lib' is {self.train_lib}; must be in ('lgbm',)"
Expand All @@ -80,15 +85,16 @@ def __post_init__(self):
@dataclass
class EntingParams:
"""Contains parameters for a mean and uncertainty model.
Provides a structured dataclass for the parameters of an Enting model,
Provides a structured dataclass for the parameters of an Enting model,
alongside default values and some light data validation."""
unc_params: "UncParams" = field(default_factory=dict) # type: ignore
tree_train_params: "TreeTrainParams" = field(default_factory=dict) # type: ignore


unc_params: "UncParams" = field(default_factory=dict) # type: ignore
tree_train_params: "TreeTrainParams" = field(default_factory=dict) # type: ignore

def __post_init__(self):
if isinstance(self.unc_params, dict):
self.unc_params = UncParams(**self.unc_params)

if isinstance(self.tree_train_params, dict):
self.tree_train_params = TreeTrainParams(**self.tree_train_params)
self.tree_train_params = TreeTrainParams(**self.tree_train_params)
5 changes: 2 additions & 3 deletions entmoot/models/uncertainty_models/base_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from entmoot.models.base_model import BaseModel
from entmoot.problem_config import ProblemConfig, Categorical
from entmoot.problem_config import Categorical, ProblemConfig


class NonCatDistance(BaseModel):
Expand Down Expand Up @@ -74,15 +74,14 @@ def get_gurobipy_model_constr(self, model_core):

def get_pyomo_model_constr(self, model_core):
raise NotImplementedError()

def get_gurobipy_model_constr_terms(self, model) -> list:
raise NotImplementedError()

def get_pyomo_model_constr_terms(self, model) -> list:
raise NotImplementedError()



class CatDistance(BaseModel):
def __init__(self, problem_config: ProblemConfig, acq_sense):
self._problem_config = problem_config
Expand Down
28 changes: 21 additions & 7 deletions entmoot/models/uncertainty_models/distance_based_uncertainty.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,20 @@


@overload
def distance_func_mapper(dist_name: str, cat: Literal[True]) -> type[CatDistance] | None: ...
def distance_func_mapper(
dist_name: str, cat: Literal[True]
) -> type[CatDistance] | None: ...


@overload
def distance_func_mapper(dist_name: str, cat: Literal[False]) -> type[NonCatDistance] | None: ...
def distance_func_mapper(
dist_name: str, cat: Literal[False]
) -> type[NonCatDistance] | None: ...


def distance_func_mapper(dist_name: str, cat: bool) -> type[CatDistance] | type[NonCatDistance] | None:
def distance_func_mapper(
dist_name: str, cat: bool
) -> type[CatDistance] | type[NonCatDistance] | None:
"""Given a string, return the distance function"""
non_cat_dists = {
"euclidean_squared": EuclideanSquaredDistance,
Expand All @@ -41,7 +49,9 @@ def distance_func_mapper(dist_name: str, cat: bool) -> type[CatDistance] | type[


class DistanceBasedUncertainty(BaseModel):
def __init__(self, problem_config: ProblemConfig, params: Union[UncParams, dict, None] = None):
def __init__(
self, problem_config: ProblemConfig, params: Union[UncParams, dict, None] = None
):
if params is None:
params = {}
if isinstance(params, dict):
Expand Down Expand Up @@ -100,7 +110,7 @@ def __init__(self, problem_config: ProblemConfig, params: Union[UncParams, dict,
)
else:
self.cat_unc_model: CatDistance = cat_distance(
problem_config=self._problem_config,
problem_config=self._problem_config,
acq_sense=params.acq_sense,
)

Expand Down Expand Up @@ -171,7 +181,7 @@ def add_to_gurobipy_model(self, model):
model.addVar(name=f"bin_penalty_{i}", vtype="B")
)

big_m_term = big_m * (1 - model._bin_penalty[-1]) # type: ignore
big_m_term = big_m * (1 - model._bin_penalty[-1]) # type: ignore

if self._dist_metric == "l2":
# take sqrt for l2 distance
Expand Down Expand Up @@ -265,7 +275,11 @@ def add_to_pyomo_model(self, model):

def constrs_bin_penalty_sum(model_obj):
return (
sum(model_obj._bin_penalty[k] for k in model.indices_constrs_cat_noncat_contr) == 1
sum(
model_obj._bin_penalty[k]
for k in model.indices_constrs_cat_noncat_contr
)
== 1
)

model.constrs_bin_penalty_sum = pyo.Constraint(rule=constrs_bin_penalty_sum)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ def _array_predict(self, X):
raise NotImplementedError()

def get_gurobipy_model_constr_terms(self, model):

from gurobipy import quicksum

features = model._all_feat
Expand All @@ -31,7 +30,6 @@ def get_gurobipy_model_constr_terms(self, model):
return constr_list

def get_pyomo_model_constr_terms(self, model):

features = model._all_feat

constr_list = []
Expand Down
1 change: 0 additions & 1 deletion entmoot/models/uncertainty_models/l2_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ def get_gurobipy_model_constr_terms(self, model):
return constr_list

def get_pyomo_model_constr_terms(self, model):

features = model._all_feat

constr_list = []
Expand Down
Loading

0 comments on commit 2ba99d4

Please sign in to comment.