diff --git a/demo/dask/dask_learning_to_rank.py b/demo/dask/dask_learning_to_rank.py new file mode 100644 index 000000000000..3567176ba0bf --- /dev/null +++ b/demo/dask/dask_learning_to_rank.py @@ -0,0 +1,186 @@ +""" +Learning to rank with the Dask Interface +======================================== + +This is a demonstration of using XGBoost for learning to rank tasks using the +MSLR_10k_letor dataset. For more infomation about the dataset, please visit its +`description page `_. + +""" + +from __future__ import annotations + +import argparse +import os +from contextlib import contextmanager +from typing import Generator + +import dask +import numpy as np +from dask import array as da +from dask import dataframe as dd +from distributed import Client, LocalCluster, wait +from sklearn.datasets import load_svmlight_file + +from xgboost import dask as dxgb + + +def load_mlsr_10k( + device: str, data_path: str, cache_path: str +) -> tuple[dd.DataFrame, dd.DataFrame, dd.DataFrame]: + """Load the MSLR10k dataset from data_path and save parquet files in the cache_path.""" + root_path = os.path.expanduser(args.data) + cache_path = os.path.expanduser(args.cache) + + # Use only the Fold1 for demo: + # Train, Valid, Test + # {S1,S2,S3}, S4, S5 + fold = 1 + + if not os.path.exists(cache_path): + os.mkdir(cache_path) + fold_path = os.path.join(root_path, f"Fold{fold}") + train_path = os.path.join(fold_path, "train.txt") + valid_path = os.path.join(fold_path, "vali.txt") + test_path = os.path.join(fold_path, "test.txt") + + X_train, y_train, qid_train = load_svmlight_file( + train_path, query_id=True, dtype=np.float32 + ) + columns = [f"f{i}" for i in range(X_train.shape[1])] + X_train = dd.from_array(X_train.toarray(), columns=columns) + y_train = y_train.astype(np.int32) + qid_train = qid_train.astype(np.int32) + + X_train["y"] = dd.from_array(y_train) + X_train["qid"] = dd.from_array(qid_train) + X_train.to_parquet(os.path.join(cache_path, "train"), engine="pyarrow") + + X_valid, y_valid, qid_valid = load_svmlight_file( + valid_path, query_id=True, dtype=np.float32 + ) + X_valid = dd.from_array(X_valid.toarray(), columns=columns) + y_valid = y_valid.astype(np.int32) + qid_valid = qid_valid.astype(np.int32) + + X_valid["y"] = dd.from_array(y_valid) + X_valid["qid"] = dd.from_array(qid_valid) + X_valid.to_parquet(os.path.join(cache_path, "valid"), engine="pyarrow") + + X_test, y_test, qid_test = load_svmlight_file( + test_path, query_id=True, dtype=np.float32 + ) + + X_test = dd.from_array(X_test.toarray(), columns=columns) + y_test = y_test.astype(np.int32) + qid_test = qid_test.astype(np.int32) + + X_test["y"] = dd.from_array(y_test) + X_test["qid"] = dd.from_array(qid_test) + X_test.to_parquet(os.path.join(cache_path, "test"), engine="pyarrow") + + df_train = dd.read_parquet( + os.path.join(cache_path, "train"), calculate_divisions=True + ) + df_valid = dd.read_parquet( + os.path.join(cache_path, "valid"), calculate_divisions=True + ) + df_test = dd.read_parquet( + os.path.join(cache_path, "test"), calculate_divisions=True + ) + + return df_train, df_valid, df_test + + +def ranking_demo(client: Client, args: argparse.Namespace) -> None: + df_train, df_valid, df_test = load_mlsr_10k(args.device, args.data, args.cache) + + X_train: dd.DataFrame = df_train[df_train.columns.difference(["y", "qid"])] + y_train = df_train[["y", "qid"]] + Xy_train = dxgb.DaskQuantileDMatrix(client, X_train, y_train.y, qid=y_train.qid) + + X_valid: dd.DataFrame = df_valid[df_valid.columns.difference(["y", "qid"])] + y_valid 
= df_valid[["y", "qid"]] + Xy_valid = dxgb.DaskQuantileDMatrix( + client, X_valid, y_valid.y, qid=y_valid.qid, ref=Xy_train + ) + + dxgb.train( + client, + {"objective": "rank:ndcg", "device": args.device}, + Xy_train, + evals=[(Xy_train, "Train"), (Xy_valid, "Valid")], + ) + + +def ranking_wo_split_demo(client: Client, args: argparse.Namespace) -> None: + """Learning to rank with data partitioned according to query groups.""" + df_tr, df_va, df_te = load_mlsr_10k(args.device, args.data, args.cache) + + X_tr = df_tr[df_tr.columns.difference(["y", "qid"])] + X_va = df_va[df_va.columns.difference(["y", "qid"])] + + ltr = dxgb.DaskXGBRanker(allow_group_split=False) + ltr.client = client + ltr = ltr.fit( + X_tr, + df_tr.y, + qid=df_tr.qid, + eval_set=[(X_tr, df_tr.y), (X_va, df_va.y)], + eval_qid=[df_tr.qid, df_va.qid], + ) + + df_te = df_te.persist() + wait([df_te]) + X_te = df_te[df_te.columns.difference(["y", "qid"])] + predt = ltr.predict(X_te).compute() + y = client.compute(df_te.y) + + +@contextmanager +def gen_client(device: str) -> Generator[Client, None, None]: + match device: + case "cuda": + from dask_cuda import LocalCUDACluster + + with LocalCUDACluster() as cluster: + with Client(cluster) as client: + with dask.config.set( + {"array.backend": "cupy", "dataframe.backend": "cudf"} + ): + yield client + case "cpu": + with LocalCluster() as cluster: + with Client(cluster) as client: + yield client + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Demonstration of learning to rank using XGBoost." + ) + parser.add_argument( + "--data", + type=str, + help="Root directory of the MSLR-WEB10K data.", + required=True, + ) + parser.add_argument( + "--cache", + type=str, + help="Directory for caching processed data.", + required=True, + ) + parser.add_argument("--device", choices=["cpu", "cuda"], default="cpu") + parser.add_argument( + "--no-split", + action="store_true", + help="Flag to indicate query groups should not be split.", + ) + args = parser.parse_args() + + with gen_client(args.device) as client: + if args.no_split: + ranking_wo_split_demo(client, args) + else: + ranking_demo(client, args) diff --git a/demo/rank/README.md b/demo/rank/README.md deleted file mode 100644 index 1f112b4cbe7d..000000000000 --- a/demo/rank/README.md +++ /dev/null @@ -1,41 +0,0 @@ -Learning to rank -==== -XGBoost supports accomplishing ranking tasks. In ranking scenario, data are often grouped and we need the [group information file](../../doc/tutorials/input_format.rst#group-input-format) to specify ranking tasks. The model used in XGBoost for ranking is the LambdaRank. See [parameters](../../doc/parameter.rst) for supported metrics. - -### Parameters -The configuration setting is similar to the regression and binary classification setting, except user need to specify the objectives: - -``` -... -objective="rank:pairwise" -... -``` -For more usage details please refer to the [binary classification demo](../binary_classification), - -Instructions -==== -The dataset for ranking demo is from LETOR04 MQ2008 fold1. -Before running the examples, you need to get the data by running: - -``` -./wgetdata.sh -``` - -### Command Line -Run the example: -``` -./runexp.sh -``` - -### Python -There are two ways of doing ranking in python. 
- -Run the example using `xgboost.train`: -``` -python rank.py -``` - -Run the example using `XGBRanker`: -``` -python rank_sklearn.py -``` diff --git a/demo/rank/mq2008.conf b/demo/rank/mq2008.conf deleted file mode 100644 index de2d2121dda3..000000000000 --- a/demo/rank/mq2008.conf +++ /dev/null @@ -1,26 +0,0 @@ -# General Parameters, see comment for each definition - -# specify objective -objective="rank:pairwise" - -# Tree Booster Parameters -# step size shrinkage -eta = 0.1 -# minimum loss reduction required to make a further partition -gamma = 1.0 -# minimum sum of instance weight(hessian) needed in a child -min_child_weight = 0.1 -# maximum depth of a tree -max_depth = 6 - -# Task parameters -# the number of round to do boosting -num_round = 4 -# 0 means do not save any model except the final round model -save_period = 0 -# The path of training data -data = "mq2008.train" -# The path of validation data, used to monitor training process, here [test] sets name of the validation set -eval[test] = "mq2008.vali" -# The path of test data -test:data = "mq2008.test" diff --git a/demo/rank/rank.py b/demo/rank/rank.py deleted file mode 100644 index 57cf04245342..000000000000 --- a/demo/rank/rank.py +++ /dev/null @@ -1,41 +0,0 @@ -#!/usr/bin/python -from sklearn.datasets import load_svmlight_file - -import xgboost as xgb -from xgboost import DMatrix - -# This script demonstrate how to do ranking with xgboost.train -x_train, y_train = load_svmlight_file("mq2008.train") -x_valid, y_valid = load_svmlight_file("mq2008.vali") -x_test, y_test = load_svmlight_file("mq2008.test") - -group_train = [] -with open("mq2008.train.group", "r") as f: - data = f.readlines() - for line in data: - group_train.append(int(line.split("\n")[0])) - -group_valid = [] -with open("mq2008.vali.group", "r") as f: - data = f.readlines() - for line in data: - group_valid.append(int(line.split("\n")[0])) - -group_test = [] -with open("mq2008.test.group", "r") as f: - data = f.readlines() - for line in data: - group_test.append(int(line.split("\n")[0])) - -train_dmatrix = DMatrix(x_train, y_train) -valid_dmatrix = DMatrix(x_valid, y_valid) -test_dmatrix = DMatrix(x_test) - -train_dmatrix.set_group(group_train) -valid_dmatrix.set_group(group_valid) - -params = {'objective': 'rank:ndcg', 'eta': 0.1, 'gamma': 1.0, - 'min_child_weight': 0.1, 'max_depth': 6} -xgb_model = xgb.train(params, train_dmatrix, num_boost_round=4, - evals=[(valid_dmatrix, 'validation')]) -pred = xgb_model.predict(test_dmatrix) diff --git a/demo/rank/rank_sklearn.py b/demo/rank/rank_sklearn.py deleted file mode 100644 index fe2635f379ea..000000000000 --- a/demo/rank/rank_sklearn.py +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/python -from sklearn.datasets import load_svmlight_file - -import xgboost as xgb - -# This script demonstrate how to do ranking with XGBRanker -x_train, y_train = load_svmlight_file("mq2008.train") -x_valid, y_valid = load_svmlight_file("mq2008.vali") -x_test, y_test = load_svmlight_file("mq2008.test") - -group_train = [] -with open("mq2008.train.group", "r") as f: - data = f.readlines() - for line in data: - group_train.append(int(line.split("\n")[0])) - -group_valid = [] -with open("mq2008.vali.group", "r") as f: - data = f.readlines() - for line in data: - group_valid.append(int(line.split("\n")[0])) - -group_test = [] -with open("mq2008.test.group", "r") as f: - data = f.readlines() - for line in data: - group_test.append(int(line.split("\n")[0])) - -params = {'objective': 'rank:ndcg', 'learning_rate': 0.1, - 'gamma': 1.0, 
'min_child_weight': 0.1, - 'max_depth': 6, 'n_estimators': 4} -model = xgb.sklearn.XGBRanker(**params) -model.fit(x_train, y_train, group_train, verbose=True, - eval_set=[(x_valid, y_valid)], eval_group=[group_valid]) -pred = model.predict(x_test) diff --git a/demo/rank/runexp.sh b/demo/rank/runexp.sh deleted file mode 100755 index a5ed5d1e0b9a..000000000000 --- a/demo/rank/runexp.sh +++ /dev/null @@ -1,4 +0,0 @@ -#!/bin/bash - -../../xgboost mq2008.conf -../../xgboost mq2008.conf task=pred model_in=0004.model diff --git a/demo/rank/trans_data.py b/demo/rank/trans_data.py deleted file mode 100644 index a93cf48ca718..000000000000 --- a/demo/rank/trans_data.py +++ /dev/null @@ -1,41 +0,0 @@ -import sys - - -def save_data(group_data,output_feature,output_group): - if len(group_data) == 0: - return - - output_group.write(str(len(group_data))+"\n") - for data in group_data: - # only include nonzero features - feats = [ p for p in data[2:] if float(p.split(':')[1]) != 0.0 ] - output_feature.write(data[0] + " " + " ".join(feats) + "\n") - -if __name__ == "__main__": - if len(sys.argv) != 4: - print ("Usage: python trans_data.py [Ranksvm Format Input] [Output Feature File] [Output Group File]") - sys.exit(0) - - fi = open(sys.argv[1]) - output_feature = open(sys.argv[2],"w") - output_group = open(sys.argv[3],"w") - - group_data = [] - group = "" - for line in fi: - if not line: - break - if "#" in line: - line = line[:line.index("#")] - splits = line.strip().split(" ") - if splits[1] != group: - save_data(group_data,output_feature,output_group) - group_data = [] - group = splits[1] - group_data.append(splits) - - save_data(group_data,output_feature,output_group) - - fi.close() - output_feature.close() - output_group.close() diff --git a/demo/rank/wgetdata.sh b/demo/rank/wgetdata.sh deleted file mode 100755 index 613d0183c4d3..000000000000 --- a/demo/rank/wgetdata.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash -if [ -f MQ2008.rar ] -then - echo "Use downloaded data to run experiment." -else - echo "Downloading data." - wget https://s3-us-west-2.amazonaws.com/xgboost-examples/MQ2008.rar - unrar x MQ2008.rar - mv -f MQ2008/Fold1/*.txt . -fi - -python trans_data.py train.txt mq2008.train mq2008.train.group - -python trans_data.py test.txt mq2008.test mq2008.test.group - -python trans_data.py vali.txt mq2008.vali mq2008.vali.group diff --git a/doc/tutorials/dask.rst b/doc/tutorials/dask.rst index 6e68d83a0083..e039e8e5aa11 100644 --- a/doc/tutorials/dask.rst +++ b/doc/tutorials/dask.rst @@ -355,15 +355,18 @@ Working with asyncio .. versionadded:: 1.2.0 -XGBoost's dask interface supports the new ``asyncio`` in Python and can be integrated into -asynchronous workflows. For using dask with asynchronous operations, please refer to -`this dask example `_ and document in -`distributed `_. To use XGBoost's -dask interface asynchronously, the ``client`` which is passed as an argument for training and -prediction must be operating in asynchronous mode by specifying ``asynchronous=True`` when the -``client`` is created (example below). All functions (including ``DaskDMatrix``) provided -by the functional interface will then return coroutines which can then be awaited to retrieve -their result. +XGBoost's dask interface supports the new :py:mod:`asyncio` in Python and can be +integrated into asynchronous workflows. For using dask with asynchronous operations, +please refer to `this dask example +`_ and document in `distributed +`_. 
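A minimal sketch of the asynchronous pattern described in this paragraph, assuming a local cluster and random placeholder data (the tutorial's own functional-interface example below remains the reference):

```python
import asyncio

from dask import array as da
from distributed import Client, LocalCluster

from xgboost import dask as dxgb


async def main() -> None:
    # Both the cluster and the client are created in asynchronous mode.
    async with LocalCluster(asynchronous=True, processes=False) as cluster:
        async with Client(cluster, asynchronous=True) as client:
            # Placeholder data; any dask collection works here.
            X = da.random.random((1000, 10), chunks=(100, 10))
            y = da.random.random(1000, chunks=100)
            # With an asynchronous client, the functional interface returns
            # awaitables, including the DaskDMatrix constructor.
            Xy = await dxgb.DaskDMatrix(client, X, y)
            output = await dxgb.train(client, {"tree_method": "hist"}, Xy, num_boost_round=4)
            predt = await dxgb.predict(client, output, X)
            print(predt)


if __name__ == "__main__":
    asyncio.run(main())
```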
To use XGBoost's Dask +interface asynchronously, the ``client`` which is passed as an argument for training and +prediction must be operating in asynchronous mode by specifying ``asynchronous=True`` when +the ``client`` is created (example below). All functions (including ``DaskDMatrix``) +provided by the functional interface will then return coroutines which can be awaited +to retrieve their result. Please note that XGBoost is a compute-bound application, where +parallelism is more important than concurrency. The support for `asyncio` is more about +compatibility than about performance gain. Functional interface: @@ -526,6 +529,28 @@ See /~https://github.com/coiled/dask-xgboost-nyctaxi for a set of examples of usin with dask and optuna. +**************** +Learning to Rank +**************** + + .. versionadded:: 3.0.0 + + .. note:: + + Position debiasing is not yet supported. + +Similar to the (Py)Spark interface, the XGBoost Dask interface can automatically sort and +group the samples based on the input query ID since version 3.0. However, the automatic +grouping in the Dask interface has some caveats that one needs to be aware of, namely it +increases memory usage and it groups only worker-local data. This behavior is similar to +the other interfaces. + +Regarding memory usage, XGBoost first checks whether the query ID is already sorted before +actually sorting the samples. If you don't want XGBoost to sort and group the data during +training, one solution is to sort it beforehand. See :ref:`ltr-dist` for more information about +the implications of worker-local grouping, and +:ref:`sphx_glr_python_dask-examples_dask_learning_to_rank.py` for a worked example. .. _tracker-ip: ***************
diff --git a/doc/tutorials/learning_to_rank.rst b/doc/tutorials/learning_to_rank.rst index 4d2cbad4aa47..756f5089df23 100644 --- a/doc/tutorials/learning_to_rank.rst +++ b/doc/tutorials/learning_to_rank.rst @@ -165,10 +165,24 @@ On the other hand, if you have comparatively small amount of training data: For any method chosen, you can modify ``lambdarank_num_pair_per_sample`` to control the amount of pairs generated. +.. _ltr-dist: + ******************** Distributed Training ******************** -XGBoost implements distributed learning-to-rank with integration of multiple frameworks including Dask, Spark, and PySpark. The interface is similar to the single-node counterpart. Please refer to document of the respective XGBoost interface for details. Scattering a query group onto multiple workers is theoretically sound but can affect the model accuracy. For most of the use cases, the small discrepancy is not an issue, as the amount of training data is usually large when distributed training is used. As a result, users don't need to partition the data based on query groups. As long as each data partition is correctly sorted by query IDs, XGBoost can aggregate sample gradients accordingly. + +XGBoost implements distributed learning-to-rank with integration of multiple frameworks +including :doc:`Dask `, :doc:`Spark `, and +:doc:`PySpark `. The interface is similar to the single-node +counterpart. Please refer to the documentation of the respective XGBoost interface for details. + +.. warning:: + + Position-debiasing is not yet supported for existing distributed interfaces. + +XGBoost works with collective operations, which means data is scattered to multiple workers. We can divide the data partitions by query group and ensure no query group is split among workers, as sketched below.
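A minimal sketch of the two approaches described above, using synthetic data (the feature names, parameters, and sizes are placeholders; ``allow_group_split`` is the option added in this patch):

```python
import numpy as np
import pandas as pd
from dask import dataframe as dd
from distributed import Client, LocalCluster

from xgboost import dask as dxgb

with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
    # Synthetic ranking data: features, integer relevance labels, and query IDs.
    rng = np.random.default_rng(0)
    pdf = pd.DataFrame(rng.random((1000, 4)), columns=[f"f{i}" for i in range(4)])
    pdf["qid"] = rng.integers(0, 20, size=1000)
    pdf["y"] = rng.integers(0, 4, size=1000)
    df = dd.from_pandas(pdf, npartitions=4)

    # Option 1: sort by query ID up front so that XGBoost does not need to sort
    # and group the worker-local data itself.
    df = df.sort_values("qid")
    X, y, qid = df.drop(columns=["y", "qid"]), df["y"], df["qid"]
    Xy = dxgb.DaskQuantileDMatrix(client, X, y, qid=qid)
    dxgb.train(client, {"objective": "rank:ndcg"}, Xy, num_boost_round=10)

    # Option 2: keep every query group on a single worker; this requires dask
    # dataframe/series inputs and the `allow_group_split=False` flag.
    ltr = dxgb.DaskXGBRanker(allow_group_split=False, n_estimators=10)
    ltr.client = client
    ltr = ltr.fit(X, y, qid=qid)
```

The first variant avoids the worker-side sort entirely, while the second trades a one-time shuffle for the guarantee that every query group stays on a single worker.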
However, this requires a costly sort and groupby operation and might only be necessary for selected use cases. Splitting and scattering a query group to multiple workers is theoretically sound but can affect the model's accuracy. For most use cases, the small discrepancy is not an issue, as the amount of training data is usually large when distributed training is used. For a longer explanation, assuming the pairwise ranking method is used, we calculate the gradient based on relevance degree by constructing pairs within a query group. If a single query group is split among workers and we use worker-local data for gradient calculation, then we are simply sampling pairs from a smaller group for each worker to calculate the gradient and the evaluation metric. The comparison between each pair doesn't change just because a group is split into sub-groups; what changes is the number of total and effective pairs and normalizers like `IDCG`. One can generate more pairs from one large group than from two smaller subgroups. As a result, the obtained gradient is still valid from a theoretical standpoint but might not be optimal. + +Unless there's a very small number of query groups, we don't need to partition the data based on query groups. As long as the data partitions within a worker are correctly sorted by query IDs, XGBoost can aggregate sample gradients accordingly. Both the (Py)Spark interface and the Dask interface can sort the data according to query ID; please see the respective tutorials for more information. ******************* Reproducible Result
diff --git a/python-package/pyproject.toml b/python-package/pyproject.toml index 367f77147f1b..4deff3ec7b6d 100644 --- a/python-package/pyproject.toml +++ b/python-package/pyproject.toml @@ -21,6 +21,7 @@ classifiers = [ "License :: OSI Approved :: Apache Software License", "Development Status :: 5 - Production/Stable", "Operating System :: OS Independent", + "Typing :: Typed", "Programming Language :: Python", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10",
diff --git a/python-package/xgboost/compat.py b/python-package/xgboost/compat.py index 3cffcaa2585c..ac9e8bbf67c7 100644 --- a/python-package/xgboost/compat.py +++ b/python-package/xgboost/compat.py @@ -32,15 +32,12 @@ def lazy_isinstance(instance: Any, module: str, name: str) -> bool: # pandas try: - from pandas import DataFrame, MultiIndex, Series - from pandas import concat as pandas_concat + from pandas import DataFrame, Series PANDAS_INSTALLED = True except ImportError: - MultiIndex = object DataFrame = object Series = object - pandas_concat = None PANDAS_INSTALLED = False @@ -132,7 +129,9 @@ def concat(value: Sequence[_T]) -> _T: # pylint: disable=too-many-return-statem # other sparse format will be converted to CSR.
return scipy_sparse.vstack(value, format="csr") if PANDAS_INSTALLED and isinstance(value[0], (DataFrame, Series)): - return pandas_concat(value, axis=0) + from pandas import concat as pd_concat + + return pd_concat(value, axis=0) if lazy_isinstance(value[0], "cudf.core.dataframe", "DataFrame") or lazy_isinstance( value[0], "cudf.core.series", "Series" ): diff --git a/python-package/xgboost/dask/__init__.py b/python-package/xgboost/dask/__init__.py index f7329de71887..ecf3a20603ea 100644 --- a/python-package/xgboost/dask/__init__.py +++ b/python-package/xgboost/dask/__init__.py @@ -54,7 +54,6 @@ dask.config.set({"xgboost.scheduler_address": "192.0.0.100:12345"}) """ -import collections import logging from collections import defaultdict from contextlib import contextmanager @@ -75,6 +74,7 @@ Tuple, TypeAlias, TypedDict, + TypeGuard, TypeVar, Union, ) @@ -87,19 +87,17 @@ from dask import dataframe as dd from .. import collective, config -from .._typing import _T, FeatureNames, FeatureTypes, IterationRange +from .._typing import FeatureNames, FeatureTypes, IterationRange from ..callback import TrainingCallback from ..collective import Config as CollConfig from ..collective import _Args as CollArgs from ..collective import _ArgVals as CollArgsVals -from ..compat import DataFrame, concat, lazy_isinstance +from ..compat import DataFrame, lazy_isinstance from ..core import ( Booster, - DataIter, DMatrix, Metric, Objective, - QuantileDMatrix, XGBoostError, _check_distributed_params, _deprecate_positional_args, @@ -122,6 +120,7 @@ ) from ..tracker import RabitTracker from ..training import train as worker_train +from .data import _create_dmatrix, _create_quantile_dmatrix, no_group_split from .utils import get_address_from_user, get_n_threads _DaskCollection: TypeAlias = Union[da.Array, dd.DataFrame, dd.Series] @@ -249,14 +248,6 @@ def __init__(self, **args: CollArgsVals) -> None: self.args["DMLC_TASK_ID"] = f"[xgboost.dask-{wid}]:" + str(worker.address) -def dconcat(value: Sequence[_T]) -> _T: - """Concatenate sequence of partitions.""" - try: - return concat(value) - except TypeError: - return dd.multi.concat(list(value), axis=0) - - def _get_client(client: Optional["distributed.Client"]) -> "distributed.Client": """Simple wrapper around testing None.""" if not isinstance(client, (type(distributed.get_client()), type(None))): @@ -597,111 +588,6 @@ def first_valid(results: Iterable[Optional[_MapRetT]]) -> Optional[_MapRetT]: return result -_DataParts = List[Dict[str, Any]] - - -def _get_worker_parts(list_of_parts: _DataParts) -> Dict[str, List[Any]]: - assert isinstance(list_of_parts, list) - result: Dict[str, List[Any]] = {} - - def append(i: int, name: str) -> None: - if name in list_of_parts[i]: - part = list_of_parts[i][name] - else: - part = None - if part is not None: - if name not in result: - result[name] = [] - result[name].append(part) - - for i, _ in enumerate(list_of_parts): - append(i, "data") - append(i, "label") - append(i, "weight") - append(i, "base_margin") - append(i, "qid") - append(i, "label_lower_bound") - append(i, "label_upper_bound") - - return result - - -class DaskPartitionIter(DataIter): # pylint: disable=R0902 - """A data iterator for `DaskQuantileDMatrix`.""" - - def __init__( - self, - data: List[Any], - label: Optional[List[Any]] = None, - *, - weight: Optional[List[Any]] = None, - base_margin: Optional[List[Any]] = None, - qid: Optional[List[Any]] = None, - label_lower_bound: Optional[List[Any]] = None, - label_upper_bound: Optional[List[Any]] = None, - 
feature_names: Optional[FeatureNames] = None, - feature_types: Optional[Union[Any, List[Any]]] = None, - feature_weights: Optional[Any] = None, - ) -> None: - self._data = data - self._label = label - self._weight = weight - self._base_margin = base_margin - self._qid = qid - self._label_lower_bound = label_lower_bound - self._label_upper_bound = label_upper_bound - self._feature_names = feature_names - self._feature_types = feature_types - self._feature_weights = feature_weights - - assert isinstance(self._data, collections.abc.Sequence) - - types = (collections.abc.Sequence, type(None)) - assert isinstance(self._label, types) - assert isinstance(self._weight, types) - assert isinstance(self._base_margin, types) - assert isinstance(self._label_lower_bound, types) - assert isinstance(self._label_upper_bound, types) - - self._iter = 0 # set iterator to 0 - super().__init__(release_data=True) - - def _get(self, attr: str) -> Optional[Any]: - if getattr(self, attr) is not None: - return getattr(self, attr)[self._iter] - return None - - def data(self) -> Any: - """Utility function for obtaining current batch of data.""" - return self._data[self._iter] - - def reset(self) -> None: - """Reset the iterator""" - self._iter = 0 - - def next(self, input_data: Callable) -> bool: - """Yield next batch of data""" - if self._iter == len(self._data): - # Return False when there's no more batch. - return False - - input_data( - data=self.data(), - label=self._get("_label"), - weight=self._get("_weight"), - group=None, - qid=self._get("_qid"), - base_margin=self._get("_base_margin"), - label_lower_bound=self._get("_label_lower_bound"), - label_upper_bound=self._get("_label_upper_bound"), - feature_names=self._feature_names, - feature_types=self._feature_types, - feature_weights=self._feature_weights, - ) - self._iter += 1 - return True - - class DaskQuantileDMatrix(DaskDMatrix): """A dask version of :py:class:`QuantileDMatrix`.""" @@ -759,110 +645,6 @@ def _create_fn_args(self, worker_addr: str) -> Dict[str, Any]: return args -def _create_quantile_dmatrix( - *, - feature_names: Optional[FeatureNames], - feature_types: Optional[Union[Any, List[Any]]], - feature_weights: Optional[Any], - missing: float, - nthread: int, - parts: Optional[_DataParts], - max_bin: int, - enable_categorical: bool, - max_quantile_batches: Optional[int], - ref: Optional[DMatrix] = None, -) -> QuantileDMatrix: - worker = distributed.get_worker() - if parts is None: - msg = f"worker {worker.address} has an empty DMatrix." - LOGGER.warning(msg) - - d = QuantileDMatrix( - numpy.empty((0, 0)), - feature_names=feature_names, - feature_types=feature_types, - max_bin=max_bin, - ref=ref, - enable_categorical=enable_categorical, - max_quantile_batches=max_quantile_batches, - ) - return d - - unzipped_dict = _get_worker_parts(parts) - it = DaskPartitionIter( - **unzipped_dict, - feature_types=feature_types, - feature_names=feature_names, - feature_weights=feature_weights, - ) - - dmatrix = QuantileDMatrix( - it, - missing=missing, - nthread=nthread, - max_bin=max_bin, - ref=ref, - enable_categorical=enable_categorical, - max_quantile_batches=max_quantile_batches, - ) - return dmatrix - - -def _create_dmatrix( - *, - feature_names: Optional[FeatureNames], - feature_types: Optional[Union[Any, List[Any]]], - feature_weights: Optional[Any], - missing: float, - nthread: int, - enable_categorical: bool, - parts: Optional[_DataParts], -) -> DMatrix: - """Get data that local to worker from DaskDMatrix. - - Returns - ------- - A DMatrix object. 
- - """ - worker = distributed.get_worker() - list_of_parts = parts - if list_of_parts is None: - msg = f"worker {worker.address} has an empty DMatrix." - LOGGER.warning(msg) - d = DMatrix( - numpy.empty((0, 0)), - feature_names=feature_names, - feature_types=feature_types, - enable_categorical=enable_categorical, - ) - return d - - T = TypeVar("T") - - def concat_or_none(data: Sequence[Optional[T]]) -> Optional[T]: - if any(part is None for part in data): - return None - return dconcat(data) - - unzipped_dict = _get_worker_parts(list_of_parts) - concated_dict: Dict[str, Any] = {} - for key, value in unzipped_dict.items(): - v = concat_or_none(value) - concated_dict[key] = v - - dmatrix = DMatrix( - **concated_dict, - missing=missing, - feature_names=feature_names, - feature_types=feature_types, - nthread=nthread, - enable_categorical=enable_categorical, - feature_weights=feature_weights, - ) - return dmatrix - - def _dmatrix_from_list_of_parts(is_quantile: bool, **kwargs: Any) -> DMatrix: if is_quantile: return _create_quantile_dmatrix(**kwargs) @@ -2119,6 +1901,20 @@ def _argmax(x: Any) -> Any: """, ["estimators", "model"], + extra_parameters=""" + allow_group_split : + + .. versionadded:: 3.0.0 + + Whether a query group can be split among multiple workers. When set to `False`, + inputs must be Dask dataframes or series. + + .. warning:: + + GPU is not yet supported when the `dask-expr` is enabled. In addition, async + environment may not work. + +""", end_note=""" .. note:: @@ -2131,36 +1927,36 @@ def __init__( self, *, objective: str = "rank:pairwise", + allow_group_split: bool = True, coll_cfg: Optional[CollConfig] = None, **kwargs: Any, ) -> None: if callable(objective): raise ValueError("Custom objective function not supported by XGBRanker.") + self.allow_group_split = allow_group_split super().__init__(objective=objective, coll_cfg=coll_cfg, **kwargs) + def _wrapper_params(self) -> Set[str]: + params = super()._wrapper_params() + params.add("allow_group_split") + return params + async def _fit_async( self, X: _DataT, y: _DaskCollection, *, - group: Optional[_DaskCollection], qid: Optional[_DaskCollection], sample_weight: Optional[_DaskCollection], base_margin: Optional[_DaskCollection], eval_set: Optional[Sequence[Tuple[_DaskCollection, _DaskCollection]]], sample_weight_eval_set: Optional[Sequence[_DaskCollection]], base_margin_eval_set: Optional[Sequence[_DaskCollection]], - eval_group: Optional[Sequence[_DaskCollection]], eval_qid: Optional[Sequence[_DaskCollection]], verbose: Union[int, bool], xgb_model: Optional[Union[XGBModel, Booster]], feature_weights: Optional[_DaskCollection], ) -> "DaskXGBRanker": - msg = "Use the `qid` instead of the `group` with the dask interface." - if not (group is None and eval_group is None): - raise ValueError(msg) - if qid is None: - raise ValueError("`qid` is required for ranking.") params = self.get_xgb_params() dtrain, evals = await _async_wrap_evaluation_matrices( self.client, @@ -2227,8 +2023,105 @@ def fit( base_margin_eval_set: Optional[Sequence[_DaskCollection]] = None, feature_weights: Optional[_DaskCollection] = None, ) -> "DaskXGBRanker": - args = {k: v for k, v in locals().items() if k not in ("self", "__class__")} - return self._client_sync(self._fit_async, **args) + msg = "Use the `qid` instead of the `group` with the dask interface." 
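+        # The Dask ranker infers group boundaries from `qid` only; the sklearn-style
+        # `group`/`eval_group` arguments are rejected up front, before any data is
+        # dispatched to the workers.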
+ if not (group is None and eval_group is None): + raise ValueError(msg) + if qid is None: + raise ValueError("`qid` is required for ranking.") + + def check_df(X: _DaskCollection) -> TypeGuard[dd.DataFrame]: + if not isinstance(X, dd.DataFrame): + raise TypeError( + "When `allow_group_split` is set to False, X is required to be" + " a dataframe." + ) + return True + + def check_ser( + qid: Optional[_DaskCollection], name: str + ) -> TypeGuard[Optional[dd.Series]]: + if not isinstance(qid, dd.Series) and qid is not None: + raise TypeError( + f"When `allow_group_split` is set to False, {name} is required to be" + " a series." + ) + return True + + if not self.allow_group_split: + assert ( + check_df(X) + and check_ser(qid, "qid") + and check_ser(y, "y") + and check_ser(sample_weight, "sample_weight") + and check_ser(base_margin, "base_margin") + ) + assert qid is not None and y is not None + X_id = id(X) + X, qid, y, sample_weight, base_margin = no_group_split( + X, + qid, + y=y, + sample_weight=sample_weight, + base_margin=base_margin, + ) + + if eval_set is not None: + new_eval_set = [] + new_eval_qid = [] + new_sample_weight_eval_set = [] + new_base_margin_eval_set = [] + assert eval_qid + for i, (Xe, ye) in enumerate(eval_set): + we = sample_weight_eval_set[i] if sample_weight_eval_set else None + be = base_margin_eval_set[i] if base_margin_eval_set else None + assert check_df(Xe) + assert eval_qid + qe = eval_qid[i] + assert ( + eval_qid + and check_ser(qe, "qid") + and check_ser(ye, "y") + and check_ser(we, "sample_weight") + and check_ser(be, "base_margin") + ) + assert qe is not None and ye is not None + if id(Xe) != X_id: + Xe, qe, ye, we, be = no_group_split(Xe, qe, ye, we, be) + else: + Xe, qe, ye, we, be = X, qid, y, sample_weight, base_margin + + new_eval_set.append((Xe, ye)) + new_eval_qid.append(qe) + + if we is not None: + new_sample_weight_eval_set.append(we) + if be is not None: + new_base_margin_eval_set.append(be) + + eval_set = new_eval_set + eval_qid = new_eval_qid + sample_weight_eval_set = ( + new_sample_weight_eval_set if new_sample_weight_eval_set else None + ) + base_margin_eval_set = ( + new_base_margin_eval_set if new_base_margin_eval_set else None + ) + + return self._client_sync( + self._fit_async, + X=X, + y=y, + qid=qid, + sample_weight=sample_weight, + base_margin=base_margin, + eval_set=eval_set, + eval_qid=eval_qid, + verbose=verbose, + xgb_model=xgb_model, + sample_weight_eval_set=sample_weight_eval_set, + base_margin_eval_set=base_margin_eval_set, + feature_weights=feature_weights, + ) # FIXME(trivialfis): arguments differ due to additional parameters like group and # qid. diff --git a/python-package/xgboost/dask/data.py b/python-package/xgboost/dask/data.py new file mode 100644 index 000000000000..edc4ea49c42a --- /dev/null +++ b/python-package/xgboost/dask/data.py @@ -0,0 +1,368 @@ +# pylint: disable=too-many-arguments +"""Copyright 2019-2024, XGBoost contributors""" + +import logging +from collections.abc import Sequence +from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union, overload + +import distributed +import numpy as np +import pandas as pd +from dask import dataframe as dd + +from .. 
import collective as coll +from .._typing import _T, FeatureNames +from ..compat import concat, import_cupy +from ..core import DataIter, DMatrix, QuantileDMatrix +from ..data import is_on_cuda + +LOGGER = logging.getLogger("[xgboost.dask]") + +_DataParts = List[Dict[str, Any]] + + +def dconcat(value: Sequence[_T]) -> _T: + """Concatenate sequence of partitions.""" + try: + return concat(value) + except TypeError: + return dd.multi.concat(list(value), axis=0) + + +meta = [ + "label", + "weight", + "base_margin", + "qid", + "label_lower_bound", + "label_upper_bound", +] + + +class DaskPartitionIter(DataIter): # pylint: disable=R0902 + """A data iterator for the `DaskQuantileDMatrix`.""" + + def __init__( + self, + data: List[Any], + feature_names: Optional[FeatureNames] = None, + feature_types: Optional[Union[Any, List[Any]]] = None, + feature_weights: Optional[Any] = None, + **kwargs: Optional[List[Any]], + ) -> None: + types = (Sequence, type(None)) + # Samples + self._data = data + for k in meta: + setattr(self, k, kwargs.get(k, None)) + assert isinstance(getattr(self, k), types) + + # Feature info + self._feature_names = feature_names + self._feature_types = feature_types + self._feature_weights = feature_weights + + assert isinstance(self._data, Sequence) + + self._iter = 0 # set iterator to 0 + super().__init__(release_data=True) + + def _get(self, attr: str) -> Optional[Any]: + if getattr(self, attr) is not None: + return getattr(self, attr)[self._iter] + return None + + def data(self) -> Any: + """Utility function for obtaining current batch of data.""" + return self._data[self._iter] + + def reset(self) -> None: + """Reset the iterator""" + self._iter = 0 + + def next(self, input_data: Callable) -> bool: + """Yield next batch of data""" + if self._iter == len(self._data): + # Return False when there's no more batch. + return False + + kwargs = {k: self._get(k) for k in meta} + input_data( + data=self.data(), + group=None, + feature_names=self._feature_names, + feature_types=self._feature_types, + feature_weights=self._feature_weights, + **kwargs, + ) + self._iter += 1 + return True + + +@overload +def _add_column(df: dd.DataFrame, col: dd.Series) -> Tuple[dd.DataFrame, str]: ... + + +@overload +def _add_column(df: dd.DataFrame, col: None) -> Tuple[dd.DataFrame, None]: ... + + +def _add_column( + df: dd.DataFrame, col: Optional[dd.Series] +) -> Tuple[dd.DataFrame, Optional[str]]: + if col is None: + return df, col + + trails = 0 + uid = f"{col.name}_{trails}" + while uid in df.columns: + trails += 1 + uid = f"{col.name}_{trails}" + + df = df.assign(**{uid: col}) + return df, uid + + +def no_group_split( + df: dd.DataFrame, + qid: dd.Series, + y: dd.Series, + sample_weight: Optional[dd.Series], + base_margin: Optional[dd.Series], +) -> Tuple[ + dd.DataFrame, dd.Series, dd.Series, Optional[dd.Series], Optional[dd.Series] +]: + """A function to prevent query group from being scattered to different + workers. Please see the tutorial in the document for the implication for not having + partition boundary based on query groups. + + """ + + df, qid_uid = _add_column(df, qid) + df, y_uid = _add_column(df, y) + df, w_uid = _add_column(df, sample_weight) + df, bm_uid = _add_column(df, base_margin) + + df = df.persist() + df[qid_uid] = df[qid_uid].astype("category").cat.as_known().cat.codes + # The shuffle here is costly. 
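+    # The steps below sort the frame by the encoded qid, count rows per qid to build
+    # partition divisions, and then set the index to qid with those divisions so that
+    # each query group ends up in a single partition.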
+ df = df.sort_values(by=qid_uid) + cnt = df.groupby(qid_uid)[qid_uid].count() + div = cnt.index.compute().values.tolist() + div = sorted(div) + div = tuple(div + [div[-1] + 1]) + + df = df.set_index( + qid_uid, + drop=False, + divisions=div, + ).persist() + + q, y, w, bm = [ + df[uid] if uid is not None else None for uid in [qid_uid, y_uid, w_uid, bm_uid] + ] + + uids = [uid for uid in [qid_uid, y_uid, w_uid, bm_uid] if uid is not None] + df = df.drop(uids, axis=1).persist() + return df, q, y, w, bm + + +def sort_data_by_qid(**kwargs: List[Any]) -> Dict[str, List[Any]]: + """Sort worker-local data by query ID for learning to rank tasks.""" + data_parts = kwargs.get("data") + assert data_parts is not None + n_parts = len(data_parts) + + if is_on_cuda(data_parts[0]): + from cudf import DataFrame + else: + from pandas import DataFrame + + def get_dict(i: int) -> dict: + def _get(attr: Optional[List[Any]]) -> Optional[Any]: + if attr is not None: + return attr[i] + return None + + data = {k: _get(kwargs.get(k, None)) for k in meta} + data = {k: v for k, v in data.items() if v is not None} + return data + + # This function was created for the `dd.from_map` constructor for sorting with a + # Dask DF. We did not proceed with that route but kept some of the utilities. It + # might be necessary to try again in the future since concatenating and sorting are + # extremely expensive in terms of memory usage. + def map_fn(i: int) -> pd.DataFrame: + data = get_dict(i) + return DataFrame(data) + + qid_parts = [map_fn(i) for i in range(n_parts)] + dfq = concat(qid_parts) + if dfq.qid.is_monotonic_increasing: + return kwargs + + LOGGER.warning( + "[r%d]: Sorting data with %d partitions for ranking. " + "This is a costly operation and will increase the memory usage significantly. " + "To avoid this warning, sort the data based on qid before passing it into " + "XGBoost. Alternatively, you can set the `allow_group_split` to False.", + coll.get_rank(), + n_parts, + ) + # I tried to construct a new dask DF to perform the sort, but it's quite difficult + # to get the partition alignment right. Along with the still maturing shuffle + # implementation and GPU compatibility, a simple concat is used. + # + # In case it might become useful one day, I managed to get a CPU version working, + # albeit quite slow (much slower than concatenated sort). The implementation merges + # everything into a single Dask DF and runs `DF.sort_values`, then retrieves the + # individual X, y, qid, ... from the calculated partition values `client.compute([p for p + # in df.partitions])`. It was to avoid creating mismatched partitions.
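+    # The simple approach used below: concatenate the worker-local parts, then reorder
+    # both the meta frame and the feature data by the argsort of qid so that rows on
+    # this worker are grouped by query.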
+ dfx = concat(data_parts) + + if is_on_cuda(dfq): + cp = import_cupy() + sorted_idx = cp.argsort(dfq.qid) + else: + sorted_idx = np.argsort(dfq.qid) + dfq = dfq.iloc[sorted_idx, :] + + if hasattr(dfx, "iloc"): + dfx = dfx.iloc[sorted_idx, :] + else: + dfx = dfx[sorted_idx, :] + + kwargs.update({"data": [dfx]}) + for i, c in enumerate(dfq.columns): + assert c in kwargs + kwargs.update({c: [dfq[c]]}) + + return kwargs + + +def _get_worker_parts(list_of_parts: _DataParts) -> Dict[str, List[Any]]: + assert isinstance(list_of_parts, list) + result: Dict[str, List[Any]] = {} + + def append(i: int, name: str) -> None: + if name in list_of_parts[i]: + part = list_of_parts[i][name] + else: + part = None + if part is not None: + if name not in result: + result[name] = [] + result[name].append(part) + + for i, _ in enumerate(list_of_parts): + append(i, "data") + for k in meta: + append(i, k) + + qid = result.get("qid", None) + if qid is not None: + result = sort_data_by_qid(**result) + return result + + +def _create_quantile_dmatrix( + *, + feature_names: Optional[FeatureNames], + feature_types: Optional[Union[Any, List[Any]]], + feature_weights: Optional[Any], + missing: float, + nthread: int, + parts: Optional[_DataParts], + max_bin: int, + enable_categorical: bool, + max_quantile_batches: Optional[int], + ref: Optional[DMatrix] = None, +) -> QuantileDMatrix: + worker = distributed.get_worker() + if parts is None: + msg = f"Worker {worker.address} has an empty DMatrix." + LOGGER.warning(msg) + + Xy = QuantileDMatrix( + np.empty((0, 0)), + feature_names=feature_names, + feature_types=feature_types, + max_bin=max_bin, + ref=ref, + enable_categorical=enable_categorical, + max_quantile_batches=max_quantile_batches, + ) + return Xy + + unzipped_dict = _get_worker_parts(parts) + it = DaskPartitionIter( + **unzipped_dict, + feature_types=feature_types, + feature_names=feature_names, + feature_weights=feature_weights, + ) + Xy = QuantileDMatrix( + it, + missing=missing, + nthread=nthread, + max_bin=max_bin, + ref=ref, + enable_categorical=enable_categorical, + max_quantile_batches=max_quantile_batches, + ) + return Xy + + +def _create_dmatrix( # pylint: disable=too-many-locals + *, + feature_names: Optional[FeatureNames], + feature_types: Optional[Union[Any, List[Any]]], + feature_weights: Optional[Any], + missing: float, + nthread: int, + enable_categorical: bool, + parts: Optional[_DataParts], +) -> DMatrix: + """Get data that local to worker from DaskDMatrix. + + Returns + ------- + A DMatrix object. + + """ + worker = distributed.get_worker() + list_of_parts = parts + if list_of_parts is None: + msg = f"Worker {worker.address} has an empty DMatrix." 
+ LOGGER.warning(msg) + Xy = DMatrix( + np.empty((0, 0)), + feature_names=feature_names, + feature_types=feature_types, + enable_categorical=enable_categorical, + ) + return Xy + + T = TypeVar("T") + + def concat_or_none(data: Sequence[Optional[T]]) -> Optional[T]: + if any(part is None for part in data): + return None + return dconcat(data) + + unzipped_dict = _get_worker_parts(list_of_parts) + concated_dict: Dict[str, Any] = {} + for key, value in unzipped_dict.items(): + v = concat_or_none(value) + concated_dict[key] = v + + Xy = DMatrix( + **concated_dict, + missing=missing, + feature_names=feature_names, + feature_types=feature_types, + nthread=nthread, + enable_categorical=enable_categorical, + feature_weights=feature_weights, + ) + return Xy diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 5239aa238502..e400e5c57f44 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -426,7 +426,7 @@ def is_pd_cat_dtype(dtype: PandasDType) -> bool: return isinstance(dtype, CategoricalDtype) - from pandas.api.types import is_categorical_dtype + from pandas.api.types import is_categorical_dtype # type: ignore return is_categorical_dtype(dtype) @@ -442,7 +442,7 @@ def is_pd_sparse_dtype(dtype: PandasDType) -> bool: return isinstance(dtype, SparseDtype) - from pandas.api.types import is_sparse + from pandas.api.types import is_sparse # type: ignore return is_sparse(dtype) @@ -455,7 +455,7 @@ def pandas_pa_type(ser: Any) -> np.ndarray: # No copy, callstack: # pandas.core.internals.managers.SingleBlockManager.array_values() # pandas.core.internals.blocks.EABackedBlock.values - d_array: pd.arrays.ArrowExtensionArray = ser.array + d_array: pd.arrays.ArrowExtensionArray = ser.array # type: ignore # no copy in __arrow_array__ # ArrowExtensionArray._data is a chunked array aa: pa.ChunkedArray = d_array.__arrow_array__() @@ -1517,6 +1517,11 @@ def _proxy_transform( raise TypeError("Value type is not supported for data iterator:" + str(type(data))) +def is_on_cuda(data: Any) -> bool: + """Whether the data is a CUDA-based data structure.""" + return any(p(data) for p in (_is_cudf_df, _is_cudf_ser, _is_cupy_alike, _is_dlpack)) + + def dispatch_proxy_set_data( proxy: _ProxyDMatrix, data: DataType, diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 63448bf1458d..25448657c8ad 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -14,6 +14,7 @@ Optional, Protocol, Sequence, + Set, Tuple, Type, TypeVar, @@ -29,7 +30,13 @@ # Do not use class names on scikit-learn directly. Re-define the classes on # .compat to guarantee the behavior without scikit-learn -from .compat import SKLEARN_INSTALLED, XGBClassifierBase, XGBModelBase, XGBRegressorBase +from .compat import ( + SKLEARN_INSTALLED, + XGBClassifierBase, + XGBModelBase, + XGBRegressorBase, + import_cupy, +) from .config import config_context from .core import ( Booster, @@ -827,6 +834,19 @@ def _doc_link_template(self) -> str: base = "https://xgboost.readthedocs.io/en" return f"{base}/{rel}/python/python_api.html#{module}.{name}" + def _wrapper_params(self) -> Set[str]: + wrapper_specific = { + "importance_type", + "kwargs", + "missing", + "n_estimators", + "enable_categorical", + "early_stopping_rounds", + "callbacks", + "feature_types", + } + return wrapper_specific + def get_booster(self) -> Booster: """Get the underlying xgboost Booster of this model. 
@@ -905,16 +925,7 @@ def get_xgb_params(self) -> Dict[str, Any]: params: Dict[str, Any] = self.get_params() # Parameters that should not go into native learner. - wrapper_specific = { - "importance_type", - "kwargs", - "missing", - "n_estimators", - "enable_categorical", - "early_stopping_rounds", - "callbacks", - "feature_types", - } + wrapper_specific = self._wrapper_params() filtered = {} for k, v in params.items(): if k not in wrapper_specific and not callable(v): @@ -1231,9 +1242,9 @@ def predict( validate_features=validate_features, ) if _is_cupy_alike(predts): - import cupy # pylint: disable=import-error + cp = import_cupy() - predts = cupy.asnumpy(predts) # ensure numpy array is used. + predts = cp.asnumpy(predts) # ensure numpy array is used. return predts except TypeError: # coo, csc, dt @@ -1508,13 +1519,13 @@ def fit( # booster in a Python property. This way we can have efficient and # thread-safe prediction. if _is_cudf_df(y) or _is_cudf_ser(y): - import cupy as cp # pylint: disable=E0401 + cp = import_cupy() classes = cp.unique(y.values) self.n_classes_ = len(classes) expected_classes = cp.array(self.classes_) elif _is_cupy_alike(y): - import cupy as cp # pylint: disable=E0401 + cp = import_cupy() classes = cp.unique(y) self.n_classes_ = len(classes) diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py index c30cf947dae0..5f2d9f5226ad 100644 --- a/python-package/xgboost/spark/core.py +++ b/python-package/xgboost/spark/core.py @@ -62,7 +62,7 @@ from .._typing import ArrayLike from ..collective import Config -from ..compat import is_cudf_available, is_cupy_available +from ..compat import import_cupy, is_cudf_available, is_cupy_available from ..config import config_context from ..core import Booster, _check_distributed_params, _py_version from ..sklearn import DEFAULT_N_ESTIMATORS, XGBClassifier, XGBModel, _can_use_qdm @@ -1073,6 +1073,8 @@ def _fit(self, dataset: DataFrame) -> "_SparkXGBModel": log_level = get_logger_level(_LOG_TAG) + use_rmm = get_config()["use_rmm"] + def _train_booster( pandas_df_iter: Iterator[pd.DataFrame], ) -> Iterator[pd.DataFrame]: @@ -1132,9 +1134,9 @@ def _train_booster( _rabit_args = json.loads(messages[0])["rabit_msg"] evals_result: Dict[str, Any] = {} - with config_context(verbosity=verbosity), CommunicatorContext( - context, **_rabit_args - ): + with config_context( + verbosity=verbosity, use_rmm=use_rmm + ), CommunicatorContext(context, **_rabit_args): dtrain, dvalid = create_dmatrix_from_partitions( iterator=pandas_df_iter, feature_cols=feature_prop.features_cols_names, @@ -1448,7 +1450,7 @@ def predict_udf(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.Series]: if run_on_gpu: if is_cudf_available() and is_cupy_available(): if is_local: - import cupy as cp # pylint: disable=import-error + cp = import_cupy() total_gpus = cp.cuda.runtime.getDeviceCount() if total_gpus > 0: @@ -1490,7 +1492,7 @@ def to_gpu_if_possible(data: ArrayLike) -> ArrayLike: X = _read_csr_matrix_from_unwrapped_spark_vec(data) else: if feature_col_names is not None: - tmp = data[feature_col_names] + tmp: ArrayLike = data[feature_col_names] else: tmp = stack_series(data[alias.data]) X = to_gpu_if_possible(tmp) diff --git a/python-package/xgboost/testing/__init__.py b/python-package/xgboost/testing/__init__.py index 78ec99f64b37..eb5b6717dd1a 100644 --- a/python-package/xgboost/testing/__init__.py +++ b/python-package/xgboost/testing/__init__.py @@ -457,7 +457,11 @@ def make_categorical( def make_ltr( - n_samples: int, n_features: int, 
n_query_groups: int, max_rel: int + n_samples: int, + n_features: int, + n_query_groups: int, + max_rel: int, + sort_qid: bool = True, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: """Make a dataset for testing LTR.""" rng = np.random.default_rng(1994) @@ -470,7 +474,8 @@ def make_ltr( w = rng.normal(0, 1.0, size=n_query_groups) w -= np.min(w) w /= np.max(w) - qid = np.sort(qid) + if sort_qid: + qid = np.sort(qid) return X, y, qid, w @@ -834,12 +839,6 @@ def setup_rmm_pool(_: Any, pytestconfig: pytest.Config) -> None: ) -def get_client_workers(client: Any) -> List[str]: - "Get workers from a dask client." - workers = client.scheduler_info()["workers"] - return list(workers.keys()) - - def demo_dir(path: str) -> str: """Look for the demo directory based on the test file name.""" path = normpath(os.path.dirname(path)) diff --git a/python-package/xgboost/testing/dask.py b/python-package/xgboost/testing/dask.py index 93514a97fbfd..69ae82ff06cd 100644 --- a/python-package/xgboost/testing/dask.py +++ b/python-package/xgboost/testing/dask.py @@ -1,12 +1,13 @@ """Tests for dask shared by different test modules.""" -from typing import Any, List, Literal, cast +from typing import Any, List, Literal, Tuple, cast import numpy as np import pandas as pd from dask import array as da from dask import dataframe as dd from distributed import Client, get_worker +from sklearn.datasets import make_classification import xgboost as xgb import xgboost.testing as tm @@ -21,8 +22,6 @@ def check_init_estimation_clf( tree_method: str, device: Literal["cpu", "cuda"], client: Client ) -> None: """Test init estimation for classsifier.""" - from sklearn.datasets import make_classification - X, y = make_classification(n_samples=4096 * 2, n_features=32, random_state=1994) clf = xgb.XGBClassifier( n_estimators=1, max_depth=1, tree_method=tree_method, device=device @@ -174,3 +173,47 @@ def check_external_memory( # pylint: disable=too-many-locals def get_rabit_args(client: Client, n_workers: int) -> Any: """Get RABIT collective communicator arguments for tests.""" return client.sync(_get_rabit_args, client, n_workers) + + +def get_client_workers(client: Any) -> List[str]: + "Get workers from a dask client." 
+ workers = client.scheduler_info()["workers"] + return list(workers.keys()) + + +def make_ltr( + client: Client, n_samples: int, n_features: int, n_query_groups: int, max_rel: int +) -> Tuple[dd.DataFrame, dd.Series, dd.Series]: + """Synthetic dataset for learning to rank.""" + workers = get_client_workers(client) + n_samples_per_worker = n_samples // len(workers) + + def make(n: int, seed: int) -> pd.DataFrame: + rng = np.random.default_rng(seed) + X, y = make_classification( + n, n_features, n_informative=n_features, n_redundant=0, n_classes=max_rel + ) + qid = rng.integers(size=(n,), low=0, high=n_query_groups) + df = pd.DataFrame(X, columns=[f"f{i}" for i in range(n_features)]) + df["qid"] = qid + df["y"] = y + return df + + futures = [] + i = 0 + for k in range(0, n_samples, n_samples_per_worker): + fut = client.submit( + make, n=n_samples_per_worker, seed=k, workers=[workers[i % len(workers)]] + ) + futures.append(fut) + i += 1 + + last = n_samples - (n_samples_per_worker * len(workers)) + if last != 0: + fut = client.submit(make, n=last, seed=n_samples_per_worker * len(workers)) + futures.append(fut) + + meta = make(1, 0) + df = dd.from_delayed(futures, meta=meta) + assert isinstance(df, dd.DataFrame) + return df.drop(["qid", "y"], axis=1), df.y, df.qid diff --git a/python-package/xgboost/testing/data_iter.py b/python-package/xgboost/testing/data_iter.py index bd612e2c3e84..1a8f769b88e3 100644 --- a/python-package/xgboost/testing/data_iter.py +++ b/python-package/xgboost/testing/data_iter.py @@ -6,6 +6,7 @@ from xgboost import testing as tm +from ..compat import import_cupy from ..core import DataIter, DMatrix, ExtMemQuantileDMatrix, QuantileDMatrix @@ -21,7 +22,7 @@ def run_mixed_sparsity(device: str) -> None: y = [y_0, y_1, y_2] if device.startswith("cuda"): - import cupy as cp # pylint: disable=import-error + cp = import_cupy() X = [cp.array(batch) for batch in X] diff --git a/src/data/data.cc b/src/data/data.cc index 47836bb5134b..713ad4a1a514 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -539,7 +539,9 @@ void MetaInfo::SetInfoFromHost(Context const* ctx, StringView key, Json arr) { } else if (key == "label") { CopyTensorInfoImpl(ctx, arr, &this->labels); if (this->num_row_ != 0 && this->labels.Shape(0) != this->num_row_) { - CHECK_EQ(this->labels.Size() % this->num_row_, 0) << "Incorrect size for labels."; + CHECK_EQ(this->labels.Size() % this->num_row_, 0) + << "Incorrect size for labels: (" << this->labels.Shape(0) << "," << this->labels.Shape(1) + << ") v.s. " << this->num_row_; size_t n_targets = this->labels.Size() / this->num_row_; this->labels.Reshape(this->num_row_, n_targets); } diff --git a/src/objective/lambdarank_obj.cc b/src/objective/lambdarank_obj.cc index 94acf5a238d9..c50a55b3a17c 100644 --- a/src/objective/lambdarank_obj.cc +++ b/src/objective/lambdarank_obj.cc @@ -1,5 +1,5 @@ /** - * Copyright (c) 2023, XGBoost contributors + * Copyright 2023-2024, XGBoost contributors */ #include "lambdarank_obj.h" @@ -23,7 +23,6 @@ #include "../common/optional_weight.h" // for MakeOptionalWeights, OptionalWeights #include "../common/ranking_utils.h" // for RankingCache, LambdaRankParam, MAPCache, NDCGC... #include "../common/threading_utils.h" // for ParallelFor, Sched -#include "../common/transform_iterator.h" // for IndexTransformIter #include "init_estimation.h" // for FitIntercept #include "xgboost/base.h" // for bst_group_t, GradientPair, kRtEps, GradientPai... 
#include "xgboost/context.h" // for Context diff --git a/src/objective/lambdarank_obj.cu b/src/objective/lambdarank_obj.cu index c8da43748ae8..7863094c808c 100644 --- a/src/objective/lambdarank_obj.cu +++ b/src/objective/lambdarank_obj.cu @@ -1,5 +1,5 @@ /** - * Copyright 2015-2023 by XGBoost contributors + * Copyright 2015-2024, XGBoost contributors * * \brief CUDA implementation of lambdarank. */ diff --git a/src/objective/lambdarank_obj.cuh b/src/objective/lambdarank_obj.cuh index 2e5724f7f1fd..296dd18368b7 100644 --- a/src/objective/lambdarank_obj.cuh +++ b/src/objective/lambdarank_obj.cuh @@ -71,13 +71,13 @@ struct KernelInputs { std::int32_t iter; }; /** - * \brief Functor for generating pairs + * @brief Functor for generating pairs */ template struct MakePairsOp { KernelInputs args; /** - * \brief Make pair for the topk pair method. + * @brief Make pair for the topk pair method. */ [[nodiscard]] XGBOOST_DEVICE std::tuple WithTruncation( std::size_t idx, bst_group_t g) const { @@ -86,9 +86,6 @@ struct MakePairsOp { auto data_group_begin = static_cast(args.d_group_ptr[g]); std::size_t n_data = args.d_group_ptr[g + 1] - data_group_begin; - // obtain group segment data. - auto g_label = args.labels.Slice(linalg::Range(data_group_begin, data_group_begin + n_data), 0); - auto g_sorted_idx = args.d_sorted_idx.subspan(data_group_begin, n_data); std::size_t i = 0, j = 0; common::UnravelTrapeziodIdx(idx_in_thread_group, n_data, &i, &j); @@ -97,7 +94,7 @@ struct MakePairsOp { return std::make_tuple(rank_high, rank_low); } /** - * \brief Make pair for the mean pair method + * @brief Make pair for the mean pair method */ XGBOOST_DEVICE std::tuple WithSampling(std::size_t idx, bst_group_t g) const { diff --git a/tests/ci_build/lint_python.py b/tests/ci_build/lint_python.py index e97b13f2c465..76860d9d1e35 100644 --- a/tests/ci_build/lint_python.py +++ b/tests/ci_build/lint_python.py @@ -107,6 +107,7 @@ class LintersPaths: "tests/python/test_model_io.py", "tests/test_distributed/test_federated/", "tests/test_distributed/test_gpu_federated/", + "tests/test_distributed/test_with_dask/test_ranking.py", "tests/test_distributed/test_with_dask/test_external_memory.py", "tests/test_distributed/test_with_spark/test_data.py", "tests/test_distributed/test_gpu_with_spark/test_data.py", diff --git a/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py b/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py index 3bc7d46eb721..100c7861d55f 100644 --- a/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py +++ b/tests/test_distributed/test_gpu_with_dask/test_gpu_with_dask.py @@ -469,7 +469,7 @@ def test_empty_partition(self, local_cuda_client: Client) -> None: np.testing.assert_allclose(predt, in_predt) def test_empty_dmatrix_auc(self, local_cuda_client: Client) -> None: - n_workers = len(tm.get_client_workers(local_cuda_client)) + n_workers = len(tm.dask.get_client_workers(local_cuda_client)) run_empty_dmatrix_auc(local_cuda_client, "cuda", n_workers) def test_auc(self, local_cuda_client: Client) -> None: @@ -494,7 +494,7 @@ def test_data_initialization(self, local_cuda_client: Client) -> None: fw = fw - fw.min() m = dxgb.DaskDMatrix(local_cuda_client, X, y, feature_weights=fw) - workers = tm.get_client_workers(local_cuda_client) + workers = tm.dask.get_client_workers(local_cuda_client) rabit_args = get_rabit_args(local_cuda_client, len(workers)) def worker_fn(worker_addr: str, data_ref: Dict) -> None: @@ -595,7 +595,7 @@ def test_with_asyncio(local_cuda_client: Client) -> None: 
 )
 def test_invalid_nccl(local_cuda_client: Client) -> None:
     client = local_cuda_client
-    workers = tm.get_client_workers(client)
+    workers = tm.dask.get_client_workers(client)
     args = get_rabit_args(client, len(workers))
 
     def run(wid: int) -> None:
@@ -634,7 +634,7 @@ def make_model() -> None:
             assert err.getvalue().find("NCCL") == -1
 
     client = local_cuda_client
-    workers = tm.get_client_workers(client)
+    workers = tm.dask.get_client_workers(client)
     args = get_rabit_args(client, len(workers))
 
     # nccl is loaded
diff --git a/tests/test_distributed/test_with_dask/test_external_memory.py b/tests/test_distributed/test_with_dask/test_external_memory.py
index 7643f7305e27..ccd5740618e7 100644
--- a/tests/test_distributed/test_with_dask/test_external_memory.py
+++ b/tests/test_distributed/test_with_dask/test_external_memory.py
@@ -13,7 +13,7 @@ async def test_external_memory(
     client: Client, s: Scheduler, a: Worker, b: Worker, is_qdm: bool
 ) -> None:
-    workers = tm.get_client_workers(client)
+    workers = tm.dask.get_client_workers(client)
     n_workers = len(workers)
     args = await get_rabit_args(client, n_workers)
diff --git a/tests/test_distributed/test_with_dask/test_ranking.py b/tests/test_distributed/test_with_dask/test_ranking.py
new file mode 100644
index 000000000000..82247786f813
--- /dev/null
+++ b/tests/test_distributed/test_with_dask/test_ranking.py
@@ -0,0 +1,91 @@
+"""Copyright 2019-2024, XGBoost contributors"""
+
+import os
+from typing import Generator
+
+import numpy as np
+import pytest
+import scipy.sparse
+from dask import dataframe as dd
+from distributed import Client, LocalCluster, Scheduler, Worker
+from distributed.utils_test import gen_cluster
+
+from xgboost import dask as dxgb
+from xgboost import testing as tm
+from xgboost.testing import dask as dtm
+
+
+@pytest.fixture(scope="module")
+def cluster() -> Generator:
+    n_threads = os.cpu_count()
+    assert n_threads is not None
+    with LocalCluster(
+        n_workers=2, threads_per_worker=n_threads // 2, dashboard_address=":0"
+    ) as dask_cluster:
+        yield dask_cluster
+
+
+@pytest.fixture
+def client(cluster: LocalCluster) -> Generator:
+    with Client(cluster) as dask_client:
+        yield dask_client
+
+
+def test_dask_ranking(client: Client) -> None:
+    dpath = "demo/rank/"
+    mq2008 = tm.data.get_mq2008(dpath)
+    data = []
+    for d in mq2008:
+        if isinstance(d, scipy.sparse.csr_matrix):
+            d[d == 0] = np.inf
+            d = d.toarray()
+            d[d == 0] = np.nan
+            d[np.isinf(d)] = 0
+            data.append(dd.from_array(d, chunksize=32))
+        else:
+            data.append(dd.from_array(d, chunksize=32))
+
+    (
+        x_train,
+        y_train,
+        qid_train,
+        x_test,
+        y_test,
+        qid_test,
+        x_valid,
+        y_valid,
+        qid_valid,
+    ) = data
+    qid_train = qid_train.astype(np.uint32)
+    qid_valid = qid_valid.astype(np.uint32)
+    qid_test = qid_test.astype(np.uint32)
+
+    rank = dxgb.DaskXGBRanker(
+        n_estimators=2500, eval_metric=["ndcg"], early_stopping_rounds=10
+    )
+    rank.fit(
+        x_train,
+        y_train,
+        qid=qid_train,
+        eval_set=[(x_test, y_test), (x_train, y_train)],
+        eval_qid=[qid_test, qid_train],
+        verbose=True,
+    )
+    assert rank.n_features_in_ == 46
+    assert rank.best_score > 0.98
+
+
+def test_no_group_split(client: Client) -> None:
+    X_tr, q_tr, y_tr = dtm.make_ltr(client, 4096, 128, 4, 5)
+    X_va, q_va, y_va = dtm.make_ltr(client, 1024, 128, 4, 5)
+
+    ltr = dxgb.DaskXGBRanker(allow_group_split=False, n_estimators=32)
+    ltr.fit(
+        X_tr,
+        y_tr,
+        qid=q_tr,
+        eval_set=[(X_tr, y_tr), (X_va, y_va)],
+        eval_qid=[q_tr, q_va],
+        verbose=True,
+    )
+    print(X_tr.shape, X_tr.columns)
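For context, a minimal usage sketch of the interface exercised by test_no_group_split above; it is not part of the patch. The toy data, chunk sizes, and hyperparameters are made up, and allow_group_split is assumed to be the new DaskXGBRanker flag introduced by this change.

# Sketch only: synthetic query-grouped data fed to DaskXGBRanker.
import numpy as np
from dask import dataframe as dd
from distributed import Client, LocalCluster

from xgboost import dask as dxgb

if __name__ == "__main__":
    with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
        with Client(cluster) as client:
            rng = np.random.default_rng(2024)
            n, m = 1024, 16
            # Features and graded relevance labels (0-4) as Dask collections.
            X = dd.from_array(rng.normal(size=(n, m)), chunksize=128)
            y = dd.from_array(rng.integers(0, 5, size=n), chunksize=128)
            # Sort the query ids so rows of the same query stay contiguous.
            qid = dd.from_array(np.sort(rng.integers(0, 32, size=n)), chunksize=128)

            # allow_group_split=False keeps each query group on a single
            # worker instead of letting partitioning split it (per this PR).
            ranker = dxgb.DaskXGBRanker(n_estimators=8, allow_group_split=False)
            ranker.fit(X, y, qid=qid)
            predt = ranker.predict(X)
            print(predt.compute()[:4])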
diff --git a/tests/test_distributed/test_with_dask/test_with_dask.py b/tests/test_distributed/test_with_dask/test_with_dask.py
index dac0860babf8..77db640c2a78 100644
--- a/tests/test_distributed/test_with_dask/test_with_dask.py
+++ b/tests/test_distributed/test_with_dask/test_with_dask.py
@@ -78,7 +78,7 @@ def make_categorical(
     n_categories: int,
     onehot: bool = False,
 ) -> Tuple[dd.DataFrame, dd.Series]:
-    workers = tm.get_client_workers(client)
+    workers = tm.dask.get_client_workers(client)
     n_workers = len(workers)
     dfs = []
 
@@ -1200,50 +1200,6 @@ def test_dask_aft_survival() -> None:
         run_aft_survival(client, DaskDMatrix)
 
 
-def test_dask_ranking(client: "Client") -> None:
-    dpath = "demo/rank/"
-    mq2008 = tm.data.get_mq2008(dpath)
-    data = []
-    for d in mq2008:
-        if isinstance(d, scipy.sparse.csr_matrix):
-            d[d == 0] = np.inf
-            d = d.toarray()
-            d[d == 0] = np.nan
-            d[np.isinf(d)] = 0
-            data.append(dd.from_array(d, chunksize=32))
-        else:
-            data.append(dd.from_array(d, chunksize=32))
-
-    (
-        x_train,
-        y_train,
-        qid_train,
-        x_test,
-        y_test,
-        qid_test,
-        x_valid,
-        y_valid,
-        qid_valid,
-    ) = data
-    qid_train = qid_train.astype(np.uint32)
-    qid_valid = qid_valid.astype(np.uint32)
-    qid_test = qid_test.astype(np.uint32)
-
-    rank = dxgb.DaskXGBRanker(
-        n_estimators=2500, eval_metric=["ndcg"], early_stopping_rounds=10
-    )
-    rank.fit(
-        x_train,
-        y_train,
-        qid=qid_train,
-        eval_set=[(x_test, y_test), (x_train, y_train)],
-        eval_qid=[qid_test, qid_train],
-        verbose=True,
-    )
-    assert rank.n_features_in_ == 46
-    assert rank.best_score > 0.98
-
-
 @pytest.mark.parametrize("booster", ["dart", "gbtree"])
 def test_dask_predict_leaf(booster: str, client: "Client") -> None:
     from sklearn.datasets import load_digits
@@ -1379,7 +1335,7 @@ def load_dmatrix(rabit_args: Dict[str, Union[int, str]], tmpdir: str) -> None:
         assert Xy.num_col() == 4
 
     with tempfile.TemporaryDirectory() as tmpdir:
-        workers = tm.get_client_workers(client)
+        workers = tm.dask.get_client_workers(client)
         rabit_args = get_rabit_args(client, len(workers))
         futures = []
         for w in workers:
@@ -1635,7 +1591,7 @@ def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
 
         with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
             with Client(cluster) as client:
-                workers = tm.get_client_workers(client)
+                workers = tm.dask.get_client_workers(client)
                 rabit_args = get_rabit_args(client, len(workers))
                 futures = []
                 for i, _ in enumerate(workers):
@@ -1648,7 +1604,7 @@ def local_test(rabit_args: Dict[str, Union[int, str]], worker_id: int) -> bool:
     def test_n_workers(self) -> None:
         with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
             with Client(cluster) as client:
-                workers = tm.get_client_workers(client)
+                workers = tm.dask.get_client_workers(client)
                 from sklearn.datasets import load_breast_cancer
 
                 X, y = load_breast_cancer(return_X_y=True)
@@ -1771,7 +1727,7 @@ def test_no_duplicated_partition(self) -> None:
             X, y, _ = generate_array()
             n_partitions = X.npartitions
             m = dxgb.DaskDMatrix(client, X, y)
-            workers = tm.get_client_workers(client)
+            workers = tm.dask.get_client_workers(client)
             rabit_args = get_rabit_args(client, len(workers))
             n_workers = len(workers)
 
@@ -1991,7 +1947,7 @@ def test_parallel_submits(client: "Client") -> None:
     from sklearn.datasets import load_digits
 
     futures = []
-    workers = tm.get_client_workers(client)
+    workers = tm.dask.get_client_workers(client)
     n_submits = len(workers)
     for i in range(n_submits):
         X_, y_ = load_digits(return_X_y=True)
@@ -2078,7 +2034,7 @@ def test_parallel_submit_multi_clients() -> None:
     with LocalCluster(n_workers=4, dashboard_address=":0") as cluster:
         with Client(cluster) as client:
-            workers = tm.get_client_workers(client)
+            workers = tm.dask.get_client_workers(client)
             n_submits = len(workers)
             assert n_submits == 4
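The remaining hunks above consistently rename the test helper from tm.get_client_workers to tm.dask.get_client_workers. A small sketch of the per-worker fan-out pattern these tests build on follows; it is not part of the patch, and the task body, seed, and worker count are illustrative.

# Sketch only: list worker addresses and pin one task to each worker.
from distributed import Client, LocalCluster

from xgboost import testing as tm


def report(task_id: int) -> str:
    # Any per-worker task; the tests run collective training here instead.
    return f"hello from task {task_id}"


if __name__ == "__main__":
    with LocalCluster(n_workers=2, dashboard_address=":0") as cluster:
        with Client(cluster) as client:
            workers = tm.dask.get_client_workers(client)
            futures = []
            for i, addr in enumerate(workers):
                # workers=[addr] restricts the task to that worker address.
                futures.append(client.submit(report, i, workers=[addr]))
            print(client.gather(futures))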