
Commit 18349a7
[Breaking] Fix custom metric for multi output. (#5954)
* Set output margin to true for custom metric, so the metric receives untransformed (raw) predictions. This fixes only R and Python.
trivialfis authored Jul 29, 2020
1 parent 75b8c22 commit 18349a7
Showing 7 changed files with 41 additions and 13 deletions.
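
In effect, a custom metric now receives untransformed margins instead of transformed predictions in both Python and R. A minimal sketch of the new contract from the caller's side, assuming a synthetic 3-class problem (the data and names below are illustrative, not part of the commit):

```python
import numpy as np
import xgboost as xgb

# Synthetic 3-class data, purely for illustration.
rng = np.random.RandomState(1994)
X = rng.randn(100, 10)
y = rng.randint(0, 3, size=100)
dtrain = xgb.DMatrix(X, label=y)

def misclassification(predt: np.ndarray, dtrain: xgb.DMatrix):
    # After this commit, predt holds untransformed margins of shape
    # (n_samples, n_classes); the metric takes the row-wise argmax itself.
    y = dtrain.get_label()
    yhat = np.argmax(predt, axis=1)
    return 'PyMError', float(np.mean(yhat != y))

xgb.train({'objective': 'multi:softprob', 'num_class': 3,
           'disable_default_eval_metric': 1},
          dtrain, num_boost_round=5,
          feval=misclassification, evals=[(dtrain, 'train')])
```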
R-package/R/utils.R (2 changes: 1 addition & 1 deletion)

@@ -173,7 +173,7 @@ xgb.iter.eval <- function(booster_handle, watchlist, iter, feval = NULL) {
   } else {
     res <- sapply(seq_along(watchlist), function(j) {
       w <- watchlist[[j]]
-      preds <- predict(booster_handle, w, ntreelimit = 0) # predict using all trees
+      preds <- predict(booster_handle, w, outputmargin = TRUE, ntreelimit = 0) # predict using all trees
       eval_res <- feval(preds, w)
       out <- eval_res$value
       names(out) <- paste0(evnames[j], "-", eval_res$metric)
R-package/tests/testthat/test_custom_objective.R (4 changes: 4 additions & 0 deletions)

@@ -79,6 +79,10 @@ test_that("custom objective with multi-class works", {
     hess <- rnorm(dim(as.matrix(preds))[1])
     return (list(grad = grad, hess = hess))
   }
+  fake_merror <- function(preds, dtrain) {
+    expect_equal(dim(data)[1] * nclasses, dim(as.matrix(preds))[1])
+  }
   param$objective <- fake_softprob
+  param$eval_metric <- fake_merror
   bst <- xgb.train(param, dtrain, 1, num_class = nclasses)
 })
demo/guide-python/custom_softmax.py (31 changes: 27 additions & 4 deletions)

@@ -75,7 +75,7 @@ def softprob_obj(predt: np.ndarray, data: xgb.DMatrix):
     return grad, hess


-def predict(booster, X):
+def predict(booster: xgb.Booster, X):
     '''A customized prediction function that converts raw prediction to
     target class.
@@ -93,15 +93,34 @@ def predict(booster, X):
     return out


+def merror(predt: np.ndarray, dtrain: xgb.DMatrix):
+    y = dtrain.get_label()
+    # Like the custom objective, predt is the untransformed leaf weight
+    assert predt.shape == (kRows, kClasses)
+    out = np.zeros(kRows)
+    for r in range(predt.shape[0]):
+        i = np.argmax(predt[r])
+        out[r] = i
+
+    assert y.shape == out.shape
+
+    errors = np.zeros(kRows)
+    errors[y != out] = 1.0
+    return 'PyMError', np.sum(errors) / kRows
+
+
 def plot_history(custom_results, native_results):
     fig, axs = plt.subplots(2, 1)
     ax0 = axs[0]
     ax1 = axs[1]

+    pymerror = custom_results['train']['PyMError']
+    merror = native_results['train']['merror']
+
     x = np.arange(0, kRounds, 1)
-    ax0.plot(x, custom_results['train']['merror'], label='Custom objective')
+    ax0.plot(x, pymerror, label='Custom objective')
     ax0.legend()
-    ax1.plot(x, native_results['train']['merror'], label='multi:softmax')
+    ax1.plot(x, merror, label='multi:softmax')
     ax1.legend()

     plt.show()
@@ -110,10 +129,12 @@ def plot_history(custom_results, native_results):
 def main(args):
     custom_results = {}
     # Use our custom objective function
-    booster_custom = xgb.train({'num_class': kClasses},
+    booster_custom = xgb.train({'num_class': kClasses,
+                                'disable_default_eval_metric': True},
                                m,
                                num_boost_round=kRounds,
                                obj=softprob_obj,
+                               feval=merror,
                                evals_result=custom_results,
                                evals=[(m, 'train')])

@@ -131,6 +152,8 @@ def main(args):
     # We are reimplementing the loss function in XGBoost, so it should
     # be the same for normal cases.
     assert np.all(predt_custom == predt_native)
+    np.testing.assert_allclose(custom_results['train']['PyMError'],
+                               native_results['train']['merror'])

     if args.plot != 0:
         plot_history(custom_results, native_results)
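The demo's `PyMError` can track the native `merror` exactly because the softmax transform never changes the row-wise winner. A short self-contained sketch of that relationship, assuming a booster trained with the native `multi:softprob` objective (data and names hypothetical):

```python
import numpy as np
import xgboost as xgb

# Hypothetical small setup mirroring the demo.
rng = np.random.RandomState(0)
m = xgb.DMatrix(rng.randn(64, 8), label=rng.randint(0, 3, size=64))
booster = xgb.train({'objective': 'multi:softprob', 'num_class': 3}, m,
                    num_boost_round=4)

raw = booster.predict(m, output_margin=True)  # untransformed margins, (n, num_class)
prob = booster.predict(m)                     # softmax-transformed probabilities

# softmax is strictly increasing within each row, so the predicted class agrees:
assert (raw.argmax(axis=1) == prob.argmax(axis=1)).all()
```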
doc/parameter.rst (4 changes: 2 additions & 2 deletions)

@@ -40,9 +40,9 @@ General Parameters

 - Number of parallel threads used to run XGBoost

-* ``disable_default_eval_metric`` [default=0]
+* ``disable_default_eval_metric`` [default=``false``]

-  - Flag to disable default metric. Set to >0 to disable.
+  - Flag to disable default metric. Set to 1 or ``true`` to disable.

 * ``num_pbuffer`` [set automatically by XGBoost, no need to be set by user]

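The parameter above is typically paired with a custom `feval`, which after this commit receives raw scores; a metric that needs probabilities must apply the softmax itself. A minimal sketch of a multi-class log-loss under the new contract (function names are illustrative, not part of the library):

```python
import numpy as np
import xgboost as xgb

def softmax(x: np.ndarray) -> np.ndarray:
    # Row-wise softmax with max-subtraction for numerical stability.
    e = np.exp(x - x.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def mlogloss(predt: np.ndarray, dtrain: xgb.DMatrix):
    # predt arrives as untransformed margins of shape (n, num_class).
    y = dtrain.get_label().astype(int)
    prob = softmax(predt)
    picked = np.clip(prob[np.arange(len(y)), y], 1e-15, 1.0)
    return 'PyMLogLoss', float(-np.log(picked).mean())
```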
python-package/xgboost/core.py (3 changes: 2 additions & 1 deletion)

@@ -1228,7 +1228,8 @@ def eval_set(self, evals, iteration=0, feval=None):
         res = msg.value.decode()
         if feval is not None:
             for dmat, evname in evals:
-                feval_ret = feval(self.predict(dmat, training=False), dmat)
+                feval_ret = feval(self.predict(dmat, training=False,
+                                               output_margin=True), dmat)
                 if isinstance(feval_ret, list):
                     for name, val in feval_ret:
                         res += '\t%s-%s:%f' % (evname, name, val)
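The same contract applies to binary objectives: with `output_margin=True` the metric receives logits, so a probability-based metric applies the sigmoid first, exactly as the `test_early_stopping.py` change below does. A minimal sketch with illustrative names:

```python
import numpy as np
import xgboost as xgb

def binary_error(predt: np.ndarray, dtrain: xgb.DMatrix):
    # predt holds raw logits under the new contract; convert to
    # probabilities before thresholding at 0.5.
    y = dtrain.get_label()
    prob = 1.0 / (1.0 + np.exp(-predt))
    return 'PyError', float(np.mean((prob > 0.5) != y))
```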
src/learner.cc (8 changes: 4 additions & 4 deletions)

@@ -154,9 +154,9 @@ LearnerModelParam::LearnerModelParam(

 struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
   // data split mode, can be row, col, or none.
-  DataSplitMode dsplit;
+  DataSplitMode dsplit {DataSplitMode::kAuto};
   // flag to disable default metric
-  int disable_default_eval_metric;
+  bool disable_default_eval_metric {false};
   // FIXME(trivialfis): The following parameters belong to model itself, but can be
   // specified by users. Move them to model parameter once we can get rid of binary IO.
   std::string booster;
@@ -171,7 +171,7 @@ struct LearnerTrainParam : public XGBoostParameter<LearnerTrainParam> {
.add_enum("row", DataSplitMode::kRow)
.describe("Data split mode for distributed training.");
DMLC_DECLARE_FIELD(disable_default_eval_metric)
.set_default(0)
.set_default(false)
.describe("Flag to disable default metric. Set to >0 to disable");
DMLC_DECLARE_FIELD(booster)
.set_default("gbtree")
@@ -253,7 +253,7 @@ class LearnerConfiguration : public Learner {
   void Configure() override {
     // Variant of double-checked locking
     if (!this->need_configuration_) { return; }
-    std::lock_guard<std::mutex> gard(config_lock_);
+    std::lock_guard<std::mutex> guard(config_lock_);
     if (!this->need_configuration_) { return; }

     monitor_.Start("Configure");
tests/python/test_early_stopping.py (2 changes: 1 addition & 1 deletion)

@@ -37,11 +37,11 @@ def test_early_stopping_nonparallel(self):
                               eval_set=[(X_test, y_test)])
         assert clf3.best_score == 1

-    @pytest.mark.skipif(**tm.no_sklearn())
     def evalerror(self, preds, dtrain):
         from sklearn.metrics import mean_squared_error

         labels = dtrain.get_label()
+        preds = 1.0 / (1.0 + np.exp(-preds))
         return 'rmse', mean_squared_error(labels, preds)

     @staticmethod
