diff --git a/python-package/lightgbm/engine.py b/python-package/lightgbm/engine.py index 3a0c93fba332..ca5428250b8f 100644 --- a/python-package/lightgbm/engine.py +++ b/python-package/lightgbm/engine.py @@ -11,7 +11,7 @@ from . import callback from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor, - _LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, + _LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction, _LGBM_EvalFunctionResultType, _LGBM_FeatureNameConfiguration, _log_warning) from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold @@ -22,9 +22,15 @@ ] -_LGBM_CustomMetricFunction = Callable[ - [np.ndarray, Dataset], - Union[Tuple[str, float, bool], List[Tuple[str, float, bool]]] +_LGBM_CustomMetricFunction = Union[ + Callable[ + [np.ndarray, Dataset], + _LGBM_EvalFunctionResultType, + ], + Callable[ + [np.ndarray, Dataset], + List[_LGBM_EvalFunctionResultType] + ], ] _LGBM_PreprocFunction = Callable[ diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index e5186415783e..8f42a6c47400 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -33,32 +33,50 @@ scipy.sparse.spmatrix ] _LGBM_ScikitCustomObjectiveFunction = Union[ + # f(labels, preds) Callable[ - [np.ndarray, np.ndarray], + [Optional[np.ndarray], np.ndarray], Tuple[np.ndarray, np.ndarray] ], + # f(labels, preds, weights) Callable[ - [np.ndarray, np.ndarray, np.ndarray], + [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]], Tuple[np.ndarray, np.ndarray] ], + # f(labels, preds, weights, group) Callable[ - [np.ndarray, np.ndarray, np.ndarray, np.ndarray], + [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]], Tuple[np.ndarray, np.ndarray] ], ] _LGBM_ScikitCustomEvalFunction = Union[ + # f(labels, preds) Callable[ - [np.ndarray, np.ndarray], - Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]] + [Optional[np.ndarray], np.ndarray], + _LGBM_EvalFunctionResultType ], Callable[ - [np.ndarray, np.ndarray, np.ndarray], - Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]] + [Optional[np.ndarray], np.ndarray], + List[_LGBM_EvalFunctionResultType] ], + # f(labels, preds, weights) Callable[ - [np.ndarray, np.ndarray, np.ndarray, np.ndarray], - Union[_LGBM_EvalFunctionResultType, List[_LGBM_EvalFunctionResultType]] + [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]], + _LGBM_EvalFunctionResultType ], + Callable[ + [Optional[np.ndarray], np.ndarray, Optional[np.ndarray]], + List[_LGBM_EvalFunctionResultType] + ], + # f(labels, preds, weights, group) + Callable[ + [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]], + _LGBM_EvalFunctionResultType + ], + Callable[ + [Optional[np.ndarray], np.ndarray, Optional[np.ndarray], Optional[np.ndarray]], + List[_LGBM_EvalFunctionResultType] + ] ] _LGBM_ScikitEvalMetricType = Union[ str, @@ -135,11 +153,11 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np. labels = dataset.get_label() argc = len(signature(self.func).parameters) if argc == 2: - grad, hess = self.func(labels, preds) + grad, hess = self.func(labels, preds) # type: ignore[call-arg] elif argc == 3: - grad, hess = self.func(labels, preds, dataset.get_weight()) + grad, hess = self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg] elif argc == 4: - grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) + grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore [call-arg] else: raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}") return grad, hess @@ -213,11 +231,11 @@ def __call__( labels = dataset.get_label() argc = len(signature(self.func).parameters) if argc == 2: - return self.func(labels, preds) + return self.func(labels, preds) # type: ignore[call-arg] elif argc == 3: - return self.func(labels, preds, dataset.get_weight()) + return self.func(labels, preds, dataset.get_weight()) # type: ignore[call-arg] elif argc == 4: - return self.func(labels, preds, dataset.get_weight(), dataset.get_group()) + return self.func(labels, preds, dataset.get_weight(), dataset.get_group()) # type: ignore[call-arg] else: raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}") @@ -819,7 +837,7 @@ def _get_meta_data(collection, name, i): num_boost_round=self.n_estimators, valid_sets=valid_sets, valid_names=eval_names, - feval=eval_metrics_callable, + feval=eval_metrics_callable, # type: ignore[arg-type] init_model=init_model, feature_name=feature_name, callbacks=callbacks