Skip to content

Commit

Permalink
merge.
Browse files Browse the repository at this point in the history
remove.
  • Loading branch information
trivialfis committed Nov 21, 2024
1 parent 1ee28f7 commit ac4824e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 31 deletions.
34 changes: 6 additions & 28 deletions demo/dask/dask_learning_to_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,39 +113,17 @@ def ranking_demo(client: Client, args: argparse.Namespace) -> None:
)


def no_group_split(client: Client, df: dd.DataFrame) -> dd.DataFrame:
"""A function to prevent query group from being scattered to different
workers. Please see the tutorial in the document for the implication for not having
partition boundary based on query groups.
"""

# The shuffle here is costly.
df["qid"] = df.qid.astype("category").cat.as_known().cat.codes
df = df.sort_values(by="qid")
cnt = df.groupby("qid").qid.count()
div = cnt.index.compute().values.tolist()
div = sorted(div)
div = tuple(div + [div[-1] + 1])

df = df.set_index(
"qid",
drop=False,
divisions=div,
).persist()

wait([df])
return df


def ranking_wo_split_demo(client: Client, args: argparse.Namespace) -> None:
"""Learning to rank with data partitioned according to query groups."""
from xgboost.dask import no_group_split

df_train, df_valid, df_test = load_mlsr_10k(args.device, args.data, args.cache)

df_train, df_valid, df_test = [
no_group_split(client, df) for df in (df_train, df_valid, df_test)
]
df_train, qid = no_group_split(client, df_train.drop(["qid"], axis=1), df_train.qid)
df_train["qid"] = qid

df_valid, qid = no_group_split(client, df_valid.drop(["qid"], axis=1), df_valid.qid)
df_valid["qid"] = qid

X = df_train[df_train.columns.difference(["y", "qid"])]
Xy_train = dxgb.DaskQuantileDMatrix(client, X, label=df_train.y, qid=df_train.qid)
Expand Down
5 changes: 2 additions & 3 deletions python-package/xgboost/dask/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,11 @@ def no_group_split(
qid_uid = str(uuid.uuid4())
while qid_uid in df.columns:
qid_uid = str(uuid.uuid4())
qid_uid = "qid"
df[qid_uid] = qid

df[qid_uid] = df.qid.astype("category").cat.as_known().cat.codes
df[qid_uid] = df[qid_uid].astype("category").cat.as_known().cat.codes
df = df.sort_values(by=qid_uid)
cnt = df.groupby(qid_uid).qid.count()
cnt = df.groupby(qid_uid)[qid_uid].count()
div = cnt.index.compute().values.tolist()
div = sorted(div)
div = tuple(div + [div[-1] + 1])
Expand Down

0 comments on commit ac4824e

Please sign in to comment.