Allow JVM-Package to access inplace predict method #9167

Merged
merged 24 commits into from
Sep 11, 2023
Changes from 9 commits
Commits
c5d448e
Logging in cpu predictor (+23 squashed commits)
StephanTLavavej May 12, 2023
ee0a029
Clean up some comments (+5 squashed commits)
yoquinjo Mar 23, 2023
5ce8d65
Additional documentation (+3 squashed commits)
yoquinjo May 17, 2023
058aec3
Formatting
yoquinjo May 18, 2023
79ce356
Adjust boosterimpltest import statments
yoquinjo May 18, 2023
86e38d4
Mimic assertion params from DMatrixTest
yoquinjo May 19, 2023
0c5b2df
Clean up some comments in BoosterImplTest
yoquinjo May 23, 2023
312c874
update attribution of authorship
yoquinjo May 23, 2023
bed3035
Update BoosterImplTest.java format
yoquinjo Jun 1, 2023
57d88da
Start working on inplace prediction.
trivialfis Aug 7, 2023
b923c95
Replace the implementation.
trivialfis Aug 7, 2023
3b88aaf
jni
trivialfis Aug 7, 2023
e132136
test.
trivialfis Aug 7, 2023
d623163
cleanup.
trivialfis Aug 7, 2023
e4f5ef8
win.
trivialfis Aug 7, 2023
bccd539
cleanup.
trivialfis Aug 9, 2023
8b4885d
Merge pull request #9 from trivialfis/jvm-inplace-predict
yoquinjo Aug 21, 2023
30ed3ca
Merge branch 'master' into sovrn-inplace-predict-java
yoquinjo Aug 22, 2023
509c880
EXDS-35- Cleaned up inplace_predict test. Refactored into two separat…
ByteSizedJoe Aug 30, 2023
272ed03
EXDS-35- Modified iteration range to use [0,10] since the number of r…
ByteSizedJoe Aug 30, 2023
cc8e79b
EXDS-35 - Added a few clarifying comments, explicitly defined iterati…
ByteSizedJoe Aug 31, 2023
81c23c2
EXDS-35 - Brought in trivial from dmlc's changes which include settin…
ByteSizedJoe Aug 31, 2023
dec1375
Merge remote-tracking branch 'dmlc/master' into sovrn-inplace-predict…
ByteSizedJoe Sep 1, 2023
e91026a
Change bst_d_ordinal_t kCpuId to use DeviceOrd
ByteSizedJoe Sep 8, 2023
27 changes: 27 additions & 0 deletions include/xgboost/c_api.h
@@ -999,6 +999,33 @@ XGB_DLL int XGBoosterPredict(BoosterHandle handle,
bst_ulong *out_len,
const float **out_result);

/*!
* \brief make prediction based on dmat (deprecated, use `XGBoosterPredictFromDMatrix` instead)
* \param handle handle
* \param dmat data matrix
* \param missing value in the input data that is to be treated as missing
* \param option_mask bit-mask of options taken in prediction, possible values
* 0:normal prediction
* 1:output margin instead of transformed value
* 2:output leaf index of trees instead of leaf value, note leaf index is unique per tree
* 4:output feature contributions to individual predictions
* \param ntree_limit limit number of trees used for prediction, this is only valid for boosted trees
* when the parameter is set to 0, we will use all the trees
* \param out_len used to store length of returning result
* \param out_result used to set a pointer to array
* \return 0 when success, -1 when failure happens
*/
XGB_DLL int XGBoosterInplacePredict(BoosterHandle handle,
const float *data,
size_t num_rows,
size_t num_features,
DMatrixHandle d_matrix_handle,
float missing,
int option_mask,
int ntree_limit,
const bst_ulong **len,
const float **out_result);

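The header comment above describes `option_mask` as a bit-mask with the values 0, 1, 2 and 4. As a minimal sketch of how such flags are usually combined and tested, the following uses bitwise OR and AND. The class, constant, and method names are hypothetical and not part of XGBoost, and this diff does not say whether the native side accepts more than one flag at a time.

```java
// Hypothetical helper; names are illustrative and not part of XGBoost.
final class OptionMaskSketch {
    static final int OUTPUT_MARGIN = 1; // output margin instead of transformed value
    static final int PRED_LEAF = 2;     // output leaf index of trees instead of leaf value
    static final int PRED_CONTRIBS = 4; // output feature contributions

    // Combine the requested options into a single bit-mask using bitwise OR.
    static int build(boolean outputMargin, boolean predLeaf, boolean predContribs) {
        int mask = 0;
        if (outputMargin) {
            mask |= OUTPUT_MARGIN;
        }
        if (predLeaf) {
            mask |= PRED_LEAF;
        }
        if (predContribs) {
            mask |= PRED_CONTRIBS;
        }
        return mask;
    }

    // Test whether a given flag is set in the mask.
    static boolean has(int mask, int flag) {
        return (mask & flag) != 0;
    }

    public static void main(String[] args) {
        int mask = build(true, false, true);
        System.out.println(mask);                     // prints 5
        System.out.println(has(mask, PRED_CONTRIBS)); // prints true
    }
}
```

With plain assignment instead of `|=`, only the last enabled option would survive, which matters only if combined options are meant to be supported.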
/*!
* \brief Make prediction from DMatrix, replacing \ref XGBoosterPredict.
*
@@ -316,6 +316,159 @@ private synchronized float[][] predict(DMatrix data,
return predicts;
}

/**
* Perform thread-safe prediction. Calls
* <code>inplace_predict(data, num_rows, num_features, Float.NaN, false, 0, false, false)</code>.
*
* @param data Flattened input matrix of features for prediction
* @param num_rows The number of predictions to make (count of input matrix rows)
* @param num_features The number of features in the model (count of input matrix columns)
*
* @return predict Result matrix
*
* @see #inplace_predict(float[] data, int num_rows, int num_features, float missing,
* boolean outputMargin, int treeLimit, boolean predLeaf,
* boolean predContribs)
*/
public float[][] inplace_predict(float[] data,
int num_rows,
int num_features) throws XGBoostError {
return this.inplace_predict(data, num_rows, num_features,
Float.NaN, false, 0, false, false);
}
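The `data` parameter is described as a flattened input matrix with `num_rows` rows and `num_features` columns. A minimal sketch of producing such an array from a two-dimensional one follows. It assumes a row-major layout, which this diff does not state explicitly, and the class and method names are hypothetical.

```java
// Hypothetical helper; assumes a row-major layout for the flattened matrix.
final class FlattenSketch {
    // Copy each row of `rows` into one contiguous array, row after row.
    static float[] flatten(float[][] rows, int numFeatures) {
        float[] flat = new float[rows.length * numFeatures];
        for (int r = 0; r < rows.length; r++) {
            if (rows[r].length != numFeatures) {
                throw new IllegalArgumentException(
                    "row " + r + " has " + rows[r].length + " features, expected " + numFeatures);
            }
            System.arraycopy(rows[r], 0, flat, r * numFeatures, numFeatures);
        }
        return flat;
    }

    public static void main(String[] args) {
        float[][] rows = {{1f, 2f, 3f}, {4f, 5f, 6f}};
        float[] flat = flatten(rows, 3);
        System.out.println(flat.length); // prints 6
        System.out.println(flat[3]);     // prints 4.0
    }
}
```

The resulting array, together with the row and feature counts, matches the shape of the arguments taken by the `inplace_predict` overloads in this diff.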

/**
* Perform thread-safe prediction. Calls
* <code>inplace_predict(data, num_rows, num_features, missing, false, 0, false, false)</code>.
*
* @param data Flattened input matrix of features for prediction
* @param num_rows The number of predictions to make (count of input matrix rows)
* @param num_features The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
*
* @return predict Result matrix
*
* @see #inplace_predict(float[] data, int num_rows, int num_features, float missing,
* boolean outputMargin, int treeLimit, boolean predLeaf,
* boolean predContribs)
*/
public float[][] inplace_predict(float[] data,
Member
@wbo4958 I just learned about the existence of BigDenseMatrix, do you think we should return the prediction in that?

int num_rows,
int num_features,
float missing) throws XGBoostError {
return this.inplace_predict(data, num_rows, num_features,
missing, false, 0, false, false);
}
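The `missing` parameter marks which value in `data` stands for a missing element, and the shortest overload passes `Float.NaN`. As a small sketch of why a NaN marker needs special handling on the caller side, the following counts missing cells in a flattened array. The names are hypothetical, and how the native predictor itself compares against `missing` is an assumption not shown in this diff.

```java
// Hypothetical helper; illustrates comparing against a missing-value marker.
final class MissingSketch {
    // NaN never equals itself, so a NaN marker must be checked with Float.isNaN.
    static boolean isMissing(float value, float missing) {
        return Float.isNaN(missing) ? Float.isNaN(value) : value == missing;
    }

    // Count how many cells of the flattened matrix are marked as missing.
    static int countMissing(float[] data, float missing) {
        int count = 0;
        for (float v : data) {
            if (isMissing(v, missing)) {
                count++;
            }
        }
        return count;
    }

    public static void main(String[] args) {
        float[] data = {1f, Float.NaN, 3f, Float.NaN};
        System.out.println(countMissing(data, Float.NaN)); // prints 2
    }
}
```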

/**
* Perform thread-safe prediction. Calls
* <code>inplace_predict(data, num_rows, num_features, missing,
* outputMargin, 0, false, false)</code>.
*
* @param data Flattened input matrix of features for prediction
* @param num_rows The number of predictions to make (count of input matrix rows)
* @param num_features The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
* @param outputMargin Whether to only predict margin value instead of transformed prediction
*
* @return predict Result matrix
*
* @see #inplace_predict(float[] data, int num_rows, int num_features, float missing,
* boolean outputMargin, int treeLimit, boolean predLeaf,
* boolean predContribs)
*/

public float[][] inplace_predict(float[] data,
int num_rows,
int num_features,
float missing,
boolean outputMargin) throws XGBoostError {
return this.inplace_predict(data, num_rows, num_features, missing,
outputMargin, 0, false, false);
}

/**
* Perform thread-safe prediction. Calls
* <code>inplace_predict(data, num_rows, num_features, missing,
* outputMargin, treeLimit, false, false)</code>.
*
* @param data Flattened input matrix of features for prediction
* @param num_rows The number of predictions to make (count of input matrix rows)
* @param num_features The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
* @param outputMargin Whether to only predict margin value instead of transformed prediction
* @param treeLimit limit number of trees, 0 means all trees.
*
* @return predict Result matrix
*
* @see #inplace_predict(float[] data, int num_rows, int num_features, float missing,
* boolean outputMargin, int treeLimit, boolean predLeaf,
* boolean predContribs)
*/
public float[][] inplace_predict(float[] data,
int num_rows,
int num_features,
float missing,
boolean outputMargin,
int treeLimit) throws XGBoostError {
return this.inplace_predict(data, num_rows, num_features, missing,
outputMargin, treeLimit, false, false);
}

/**
* Perform thread-safe prediction.
*
* @param data Flattened input matrix of features for prediction
* @param num_rows The number of predictions to make (count of input matrix rows)
* @param num_features The number of features in the model (count of input matrix columns)
* @param missing Value indicating missing element in the <code>data</code> input matrix
* @param outputMargin Whether to only predict margin value instead of transformed prediction
* @param treeLimit limit number of trees, 0 means all trees.
* @param predLeaf Whether to output the leaf index of each tree instead of the leaf value
* @param predContribs Whether to output feature contributions to individual predictions
*
* @return predict Result matrix
*/
public float[][] inplace_predict(float[] data,
int num_rows,
int num_features,
float missing,
boolean outputMargin,
int treeLimit,
boolean predLeaf,
boolean predContribs) throws XGBoostError {
int optionMask = 0;
if (outputMargin) {
optionMask = 1;
}
if (predLeaf) {
optionMask = 2;
}
if (predContribs) {
optionMask = 4;
}
DMatrix d_mat = new DMatrix(data, num_rows, num_features, missing);
Member
Hmm, inplace_predict doesn't need the DMatrix, which is the primary motivation of developing this method.

Contributor Author
Sorry for the delay - I'm looking through the python code that calls predictFromDense to see how that works - should we be implementing and using proxy DMatrix on the Java side of things? I'll start looking at making that change if you can just let me know if that would be the correct approach.

Historically, we didn't have DMatrix here when we initially made our changes in 1.4 - this addition of the DMatrix here was due to changes in cpu_predictor.cc that now requires a DMatrix, we were originally using the parameter dmlc::any const &x in the old version of inplacePredict there and sending a null pointer for the DMatrix

Member
Thank you for sharing! We unified the interface under the proxy DMatrix class. Let me do some more investigation into the JVM implementation first. I think I can help with the C++ code.

float[][] rawPredicts = new float[1][];
XGBoostJNI.checkCall(XGBoostJNI.XGBoosterInplacePredict(handle, data, num_rows, num_features,
d_mat.getHandle(), missing,
optionMask, treeLimit, rawPredicts)); // pass missing and treelimit here?

// System.out.println("Booster.inplace_predict rawPredicts[0].length = " +
// rawPredicts[0].length);

int row = num_rows;
int col = rawPredicts[0].length / row;
float[][] predicts = new float[row][col];
int r, c;
for (int i = 0; i < rawPredicts[0].length; i++) {
r = i / col;
c = i % col;
predicts[r][c] = rawPredicts[0][i];
}
return predicts;
}
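The tail of the method above turns the flat native result into a `row` by `col` matrix, deriving the column count from the result length. A minimal standalone sketch of that reshaping step is given here, with the added assumption that a zero row count or a length that does not divide evenly should be rejected rather than cause a division by zero or a truncated last row. The class and method names are hypothetical.

```java
import java.util.Arrays;

// Hypothetical helper mirroring the flat-to-matrix step at the end of inplace_predict.
final class ReshapeSketch {
    // Split a flat, row-major result into numRows rows of equal width.
    static float[][] reshape(float[] flat, int numRows) {
        if (numRows <= 0 || flat.length % numRows != 0) {
            throw new IllegalArgumentException(
                "cannot split " + flat.length + " values into " + numRows + " rows");
        }
        int cols = flat.length / numRows;
        float[][] out = new float[numRows][];
        for (int r = 0; r < numRows; r++) {
            out[r] = Arrays.copyOfRange(flat, r * cols, (r + 1) * cols);
        }
        return out;
    }

    public static void main(String[] args) {
        float[][] m = reshape(new float[] {1f, 2f, 3f, 4f, 5f, 6f}, 2);
        System.out.println(Arrays.deepToString(m)); // prints [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    }
}
```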

/**
* Predict leaf indices given the data
*
@@ -119,6 +119,10 @@ public final static native int XGBoosterEvalOneIter(long handle, int iter, long[
public final static native int XGBoosterPredict(long handle, long dmat, int option_mask,
int ntree_limit, float[][] predicts);

public final static native int XGBoosterInplacePredict(long handle, float[] data, int num_rows, int num_features, long d_matrix_handle,
float missing, int option_mask, int ntree_limit,
float[][] predicts);

public final static native int XGBoosterLoadModel(long handle, String fname);

public final static native int XGBoosterSaveModel(long handle, String fname);
26 changes: 26 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.cpp
@@ -661,6 +661,32 @@ JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict
return ret;
}

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterInplacePredict
* Signature: (J[FIIJFII[[F)I
*/
JNIEXPORT jint JNICALL Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterInplacePredict
(JNIEnv *jenv, jclass jcls, jlong jhandle, jfloatArray jdata, jint num_rows, jint num_features, jlong d_matrix_handle,
jfloat missing, jint option_mask, jint treeLimit, jobjectArray jout) {
BoosterHandle handle = (BoosterHandle) jhandle;
DMatrixHandle dmat = (DMatrixHandle) d_matrix_handle;
jfloat* data = jenv->GetFloatArrayElements(jdata, 0);
const bst_ulong *len;
float *result;
int ret = XGBoosterInplacePredict(handle, data, num_rows, num_features, dmat, missing, option_mask, treeLimit,
&len, (const float **) &result);
// Release the array before the error check so it is not leaked if the check returns early.
jenv->ReleaseFloatArrayElements(jdata, data, 0);
JVM_CHECK_CALL(ret);
if (*len) {
jsize jlen = (jsize) *len;
jfloatArray jarray = jenv->NewFloatArray(jlen);
jenv->SetFloatArrayRegion(jarray, 0, jlen, (jfloat *) result);
jenv->SetObjectArrayElement(jout, 0, jarray);
}
return ret;
}
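The `Signature:` line in a JNI header comment is the JVM method descriptor of the corresponding native declaration. As a sketch of how to derive it rather than write it by hand, the following builds the descriptor from the Java parameter types of the `XGBoosterInplacePredict` declaration shown in `XGBoostJNI` (long, float[], int, int, long, float, int, int, float[][], returning int) using the standard `java.lang.invoke.MethodType` class. The class name `JniSignatureSketch` is hypothetical.

```java
import java.lang.invoke.MethodType;

// Hypothetical helper; derives a JVM method descriptor from Java types.
final class JniSignatureSketch {
    // Descriptor for:
    // int XGBoosterInplacePredict(long, float[], int, int, long, float, int, int, float[][])
    static String inplacePredictDescriptor() {
        return MethodType.methodType(
                int.class,        // return type
                long.class,       // handle
                float[].class,    // data
                int.class,        // num_rows
                int.class,        // num_features
                long.class,       // d_matrix_handle
                float.class,      // missing
                int.class,        // option_mask
                int.class,        // ntree_limit
                float[][].class)  // predicts
            .toMethodDescriptorString();
    }

    public static void main(String[] args) {
        System.out.println(inplacePredictDescriptor()); // prints (J[FIIJFII[[F)I
    }
}
```

Running `javap -s` on the compiled class is another way to obtain the same descriptor.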

/*
* Class: ml_dmlc_xgboost4j_java_XGBoostJNI
* Method: XGBoosterLoadModel
7 changes: 7 additions & 0 deletions jvm-packages/xgboost4j/src/native/xgboost4j.h