diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index f79eceda9ed7..cc50c956fd3c 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -463,7 +463,8 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values,
   CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
   auto *learner = static_cast<Learner*>(handle);
 
-  auto x = xgboost::data::DenseAdapter(values, n_rows, n_cols);
+  std::shared_ptr<xgboost::data::DenseAdapter> x{
+      new xgboost::data::DenseAdapter(values, n_rows, n_cols)};
   HostDeviceVector<bst_float>* p_predt { nullptr };
   std::string type { c_type };
   learner->InplacePredict(x, type, missing, &p_predt);
@@ -494,7 +495,8 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle,
   CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet";
   auto *learner = static_cast<Learner*>(handle);
 
-  auto x = data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col);
+  std::shared_ptr<xgboost::data::CSRAdapter> x{
+      new xgboost::data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col)};
   HostDeviceVector<bst_float>* p_predt { nullptr };
   std::string type { c_type };
   learner->InplacePredict(x, type, missing, &p_predt);
diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu
index 7fc49b43f74d..f1a486d8c307 100644
--- a/src/c_api/c_api.cu
+++ b/src/c_api/c_api.cu
@@ -69,7 +69,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle,
   auto *learner = static_cast<Learner*>(handle);
 
   std::string json_str{c_json_strs};
-  auto x = data::CudfAdapter(json_str);
+  auto x = std::make_shared<data::CudfAdapter>(json_str);
   HostDeviceVector<bst_float>* p_predt { nullptr };
   std::string type { c_type };
   learner->InplacePredict(x, type, missing, &p_predt);
@@ -97,7 +97,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle,
   auto *learner = static_cast<Learner*>(handle);
 
   std::string json_str{c_json_strs};
-  auto x = data::CupyAdapter(json_str);
+  auto x = std::make_shared<data::CupyAdapter>(json_str);
   HostDeviceVector<bst_float>* p_predt { nullptr };
   std::string type { c_type };
   learner->InplacePredict(x, type, missing, &p_predt);
diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh
index d2e039e48d20..ff3f3f8f51e3 100644
--- a/src/data/device_adapter.cuh
+++ b/src/data/device_adapter.cuh
@@ -34,36 +34,27 @@ struct IsValidFunctor : public thrust::unary_function<Entry, bool> {
 };
 
 class CudfAdapterBatch : public detail::NoMetaInfo {
+  friend class CudfAdapter;
+
  public:
   CudfAdapterBatch() = default;
-  CudfAdapterBatch(common::Span<ArrayInterface> columns,
-                   common::Span<size_t> column_ptr, size_t num_elements)
+  CudfAdapterBatch(common::Span<ArrayInterface> columns, size_t num_rows)
       : columns_(columns),
-        column_ptr_(column_ptr),
-        num_elements_(num_elements) {}
-  size_t Size() const { return num_elements_; }
+        num_rows_(num_rows) {}
+  size_t Size() const { return num_rows_ * columns_.size(); }
   __device__ COOTuple GetElement(size_t idx) const {
-    size_t column_idx =
-        thrust::upper_bound(thrust::seq, column_ptr_.begin(),
-                            column_ptr_.end(), idx) -
-        column_ptr_.begin() - 1;
-    auto& column = columns_[column_idx];
-    size_t row_idx = idx - column_ptr_[column_idx];
+    size_t column_idx = idx % columns_.size();
+    size_t row_idx = idx / columns_.size();
+    auto const& column = columns_[column_idx];
     float value = column.valid.Data() == nullptr || column.valid.Check(row_idx)
                       ? column.GetElement(row_idx)
                       : std::numeric_limits<float>::quiet_NaN();
     return {row_idx, column_idx, value};
   }
-  __device__ float GetValue(size_t ridx, bst_feature_t fidx) const {
-    auto const& column = columns_[fidx];
-    float value = column.valid.Data() == nullptr || column.valid.Check(ridx)
-                      ? column.GetElement(ridx)
-                      : std::numeric_limits<float>::quiet_NaN();
-    return value;
-  }
 
  private:
   common::Span<ArrayInterface> columns_;
-  common::Span<size_t> column_ptr_;
-  size_t num_elements_;
+  size_t num_rows_;
 };
 
 /*!
@@ -127,7 +118,6 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
     CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat();
     CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian();
     std::vector<ArrayInterface> columns;
-    std::vector<size_t> column_ptr({0});
     auto first_column = ArrayInterface(get<Object const>(json_columns[0]));
     device_idx_ = dh::CudaGetPointerDevice(first_column.data);
     CHECK_NE(device_idx_, -1);
@@ -137,7 +127,6 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
       auto column = ArrayInterface(get<Object const>(json_col));
       columns.push_back(column);
       CHECK_EQ(column.num_cols, 1);
-      column_ptr.emplace_back(column_ptr.back() + column.num_rows);
       num_rows_ = std::max(num_rows_, size_t(column.num_rows));
       CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data))
           << "All columns should use the same device.";
@@ -145,23 +134,20 @@ class CudfAdapter : public detail::SingleBatchDataIter<CudfAdapterBatch> {
           << "All columns should have same number of rows.";
     }
     columns_ = columns;
-    column_ptr_ = column_ptr;
-    batch_ = CudfAdapterBatch(dh::ToSpan(columns_), dh::ToSpan(column_ptr_),
-                              column_ptr.back());
+    batch_ = CudfAdapterBatch(dh::ToSpan(columns_), num_rows_);
+  }
+  const CudfAdapterBatch& Value() const override {
+    CHECK_EQ(batch_.columns_.data(), columns_.data().get());
+    return batch_;
   }
-  const CudfAdapterBatch& Value() const override { return batch_; }
 
   size_t NumRows() const { return num_rows_; }
   size_t NumColumns() const { return columns_.size(); }
   size_t DeviceIdx() const { return device_idx_; }
 
-  // Cudf is column major
-  bool IsRowMajor() { return false; }
-
  private:
   CudfAdapterBatch batch_;
   dh::device_vector<ArrayInterface> columns_;
-  dh::device_vector<size_t> column_ptr_;  // Exclusive scan of column sizes
   size_t num_rows_{0};
   int device_idx_;
 };
@@ -201,8 +187,6 @@ class CupyAdapter : public detail::SingleBatchDataIter<CupyAdapterBatch> {
   size_t NumColumns() const { return array_interface_.num_cols; }
   size_t DeviceIdx() const { return device_idx_; }
 
-  bool IsRowMajor() { return true; }
-
  private:
   ArrayInterface array_interface_;
   CupyAdapterBatch batch_;
diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu
index b752e3a3e6aa..f2a0a2ea913d 100644
--- a/src/data/ellpack_page.cu
+++ b/src/data/ellpack_page.cu
@@ -154,8 +154,8 @@ struct WriteCompressedEllpackFunctor {
 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
 template <typename AdapterBatchT>
-void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
-                      int device_idx, float missing) {
+void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
+                       int device_idx, float missing) {
   // Some witchcraft happens here
   // The goal is to copy valid elements out of the input to an ellpack matrix
   // with a given row stride, using no extra working memory Standard stream
@@ -209,51 +209,6 @@ void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst,
   });
 }
 
-template <typename AdapterT, typename AdapterBatchT>
-void CopyDataColumnMajor(AdapterT* adapter, const AdapterBatchT& batch,
-                         EllpackPageImpl* dst, float missing) {
-  // Step 1: Get the sizes of the input columns
-  dh::caching_device_vector<size_t> column_sizes(adapter->NumColumns(), 0);
-  auto d_column_sizes = column_sizes.data().get();
-  // Populate column sizes
-  dh::LaunchN(adapter->DeviceIdx(), batch.Size(), [=] __device__(size_t idx) {
-    const auto& e = batch.GetElement(idx);
-    atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
-                  &d_column_sizes[e.column_idx]),
-              static_cast<unsigned long long>(1));  // NOLINT
-  });
-
-  thrust::host_vector<size_t> host_column_sizes = column_sizes;
-
-  // Step 2: Iterate over columns, place elements in correct row, increment
-  // temporary row pointers
-  dh::caching_device_vector<size_t> temp_row_ptr(adapter->NumRows(), 0);
-  auto d_temp_row_ptr = temp_row_ptr.data().get();
-  auto row_stride = dst->row_stride;
-  size_t begin = 0;
-  auto device_accessor = dst->GetDeviceAccessor(adapter->DeviceIdx());
-  common::CompressedBufferWriter writer(device_accessor.NumSymbols());
-  auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
-  data::IsValidFunctor is_valid(missing);
-  for (auto size : host_column_sizes) {
-    size_t end = begin + size;
-    dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) {
-      auto writer_non_const =
-          writer;  // For some reason this variable gets captured as const
-      const auto& e = batch.GetElement(idx + begin);
-      if (!is_valid(e)) return;
-      size_t output_position =
-          e.row_idx * row_stride + d_temp_row_ptr[e.row_idx];
-      auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx);
-      writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx,
-                                         output_position);
-      d_temp_row_ptr[e.row_idx] += 1;
-    });
-
-    begin = end;
-  }
-}
-
 void WriteNullValues(EllpackPageImpl* dst, int device_idx,
                      common::Span<size_t> row_counts) {
   // Write the null values
@@ -284,12 +239,7 @@ EllpackPageImpl::EllpackPageImpl(AdapterT* adapter, float missing, bool is_dense
   *this = EllpackPageImpl(adapter->DeviceIdx(), cuts, is_dense, row_stride,
                           adapter->NumRows());
 
-  if (adapter->IsRowMajor()) {
-    CopyDataRowMajor(batch, this, adapter->DeviceIdx(), missing);
-  } else {
-    CopyDataColumnMajor(adapter, batch, this, missing);
-  }
-
+  CopyDataToEllpack(batch, this, adapter->DeviceIdx(), missing);
   WriteNullValues(this, adapter->DeviceIdx(), row_counts_span);
 }
 
diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu
index f7faeca78ecc..f8b775c7a06d 100644
--- a/src/data/simple_dmatrix.cu
+++ b/src/data/simple_dmatrix.cu
@@ -35,51 +35,12 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
                      thrust::device_pointer_cast(offset.data()));
 }
 
-template <typename AdapterT>
-void CopyDataColumnMajor(AdapterT* adapter, common::Span<Entry> data,
-                         int device_idx, float missing,
-                         common::Span<size_t> row_ptr) {
-  // Step 1: Get the sizes of the input columns
-  dh::device_vector<size_t> column_sizes(adapter->NumColumns());
-  auto d_column_sizes = column_sizes.data().get();
-  auto& batch = adapter->Value();
-  // Populate column sizes
-  dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
-    const auto& e = batch.GetElement(idx);
-    atomicAdd(reinterpret_cast<unsigned long long*>(  // NOLINT
-                  &d_column_sizes[e.column_idx]),
-              static_cast<unsigned long long>(1));  // NOLINT
-  });
-
-  thrust::host_vector<size_t> host_column_sizes = column_sizes;
-
-  // Step 2: Iterate over columns, place elements in correct row, increment
-  // temporary row pointers
-  dh::device_vector<size_t> temp_row_ptr(
-      thrust::device_pointer_cast(row_ptr.data()),
-      thrust::device_pointer_cast(row_ptr.data() + row_ptr.size()));
-  auto d_temp_row_ptr = temp_row_ptr.data().get();
-  size_t begin = 0;
-  IsValidFunctor is_valid(missing);
-  for (auto size : host_column_sizes) {
-    size_t end = begin + size;
-    dh::LaunchN(device_idx, end - begin, [=] __device__(size_t idx) {
-      const auto& e = batch.GetElement(idx + begin);
-      if (!is_valid(e)) return;
-      data[d_temp_row_ptr[e.row_idx]] = Entry(e.column_idx, e.value);
-      d_temp_row_ptr[e.row_idx] += 1;
-    });
-
-    begin = end;
-  }
-}
-
 // Here the data is already correctly ordered and simply needs to be compacted
 // to remove missing data
 template <typename AdapterT>
-void CopyDataRowMajor(AdapterT* adapter, common::Span<Entry> data,
-                      int device_idx, float missing,
-                      common::Span<size_t> row_ptr) {
+void CopyDataToDMatrix(AdapterT* adapter, common::Span<Entry> data,
+                       int device_idx, float missing,
+                       common::Span<size_t> row_ptr) {
   auto& batch = adapter->Value();
   auto transform_f = [=] __device__(size_t idx) {
     const auto& e = batch.GetElement(idx);
@@ -116,13 +77,8 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
   CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
   info_.num_nonzero_ = sparse_page_.offset.HostVector().back();
   sparse_page_.data.Resize(info_.num_nonzero_);
-  if (adapter->IsRowMajor()) {
-    CopyDataRowMajor(adapter, sparse_page_.data.DeviceSpan(),
-                     adapter->DeviceIdx(), missing, s_offset);
-  } else {
-    CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(),
-                        adapter->DeviceIdx(), missing, s_offset);
-  }
+  CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(),
+                    adapter->DeviceIdx(), missing, s_offset);
 
   info_.num_col_ = adapter->NumColumns();
   info_.num_row_ = adapter->NumRows();
diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index f1c5332414b2..2a96d82ad958 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -271,12 +271,12 @@ class CPUPredictor : public Predictor {
                                 PredictionCacheEntry *out_preds,
                                 uint32_t tree_begin, uint32_t tree_end) const {
     auto threads = omp_get_max_threads();
-    auto m = dmlc::get<Adapter>(x);
-    CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature)
+    auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
+    CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
         << "Number of columns in data must equal to trained model.";
     MetaInfo info;
-    info.num_col_ = m.NumColumns();
-    info.num_row_ = m.NumRows();
+    info.num_col_ = m->NumColumns();
+    info.num_row_ = m->NumRows();
     this->InitOutPredictions(info, &(out_preds->predictions), model);
     std::vector<Entry> workspace(info.num_col_ * 8 * threads);
     auto &predictions = out_preds->predictions.HostVector();
@@ -284,17 +284,17 @@ class CPUPredictor : public Predictor {
     InitThreadTemp(threads, model.learner_model_param->num_feature, &thread_temp);
     size_t constexpr kUnroll = 8;
     PredictBatchKernel(AdapterView<Adapter, kUnroll>(
-                           &m, missing, common::Span<Entry>{workspace}),
+                           m.get(), missing, common::Span<Entry>{workspace}),
                        &predictions, model, tree_begin, tree_end, &thread_temp);
   }
 
   void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
                       float missing, PredictionCacheEntry *out_preds,
                       uint32_t tree_begin, unsigned tree_end) const override {
-    if (x.type() == typeid(data::DenseAdapter)) {
+    if (x.type() == typeid(std::shared_ptr<data::DenseAdapter>)) {
       this->DispatchedInplacePredict<data::DenseAdapter>(
           x, model, missing, out_preds, tree_begin, tree_end);
-    } else if (x.type() == typeid(data::CSRAdapter)) {
+    } else if (x.type() == typeid(std::shared_ptr<data::CSRAdapter>)) {
       this->DispatchedInplacePredict<data::CSRAdapter>(
           x, model, missing, out_preds, tree_begin, tree_end);
     } else {
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 084a01fd4fde..0e57276d692f 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -118,14 +118,18 @@ struct EllpackLoader {
   }
 };
 
-struct CuPyAdapterLoader {
-  data::CupyAdapterBatch batch;
+template <typename Batch>
+struct DeviceAdapterLoader {
+  Batch batch;
   bst_feature_t columns;
   float* smem;
   bool use_shared;
 
-  DEV_INLINE CuPyAdapterLoader(data::CupyAdapterBatch const batch, bool use_shared,
-                               bst_feature_t num_features, bst_row_t num_rows, size_t entry_start) :
+  using BatchT = Batch;
+
+  DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared,
+                                 bst_feature_t num_features, bst_row_t num_rows,
+                                 size_t entry_start) :
       batch{batch},
       columns{num_features},
       use_shared{use_shared} {
@@ -155,39 +159,6 @@ struct CuPyAdapterLoader {
   }
 };
 
-struct CuDFAdapterLoader {
-  data::CudfAdapterBatch batch;
-  bst_feature_t columns;
-  float* smem;
-  bool use_shared;
-
-  DEV_INLINE CuDFAdapterLoader(data::CudfAdapterBatch const batch, bool use_shared,
-                               bst_feature_t num_features,
-                               bst_row_t num_rows, size_t entry_start)
-      : batch{batch}, columns{num_features}, use_shared{use_shared} {
-    extern __shared__ float _smem[];
-    smem = _smem;
-    if (use_shared) {
-      uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x;
-      size_t shared_elements = blockDim.x * num_features;
-      dh::BlockFill(smem, shared_elements, nanf(""));
-      __syncthreads();
-      if (global_idx < num_rows) {
-        for (size_t i = 0; i < columns; ++i) {
-          smem[threadIdx.x * columns + i] = batch.GetValue(global_idx, i);
-        }
-      }
-    }
-    __syncthreads();
-  }
-  DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
-    if (use_shared) {
-      return smem[threadIdx.x * columns + fidx];
-    }
-    return batch.GetValue(ridx, fidx);
-  }
-};
-
 template <typename Loader>
 __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
                                Loader* loader) {
@@ -429,7 +400,7 @@ class GPUPredictor : public xgboost::Predictor {
               out_preds->Size() == dmat->Info().num_row_);
   }
 
-  template <typename Adapter, typename Loader, typename Batch>
+  template <typename Adapter, typename Loader>
   void DispatchedInplacePredict(dmlc::any const &x,
                                 const gbm::GBTreeModel &model, float missing,
                                 PredictionCacheEntry *out_preds,
@@ -439,22 +410,22 @@
     DeviceModel d_model;
     d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id);
 
-    auto m = dmlc::get<Adapter>(x);
-    CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature)
+    auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
+    CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
         << "Number of columns in data must equal to trained model.";
-    CHECK_EQ(this->generic_param_->gpu_id, m.DeviceIdx())
+    CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx())
         << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
-        << "but data is on: " << m.DeviceIdx();
+        << "but data is on: " << m->DeviceIdx();
     MetaInfo info;
-    info.num_col_ = m.NumColumns();
-    info.num_row_ = m.NumRows();
+    info.num_col_ = m->NumColumns();
+    info.num_row_ = m->NumRows();
     this->InitOutPredictions(info, &(out_preds->predictions), model);
 
     const uint32_t BLOCK_THREADS = 128;
     auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(info.num_row_, BLOCK_THREADS));
 
     auto shared_memory_bytes =
-        static_cast<size_t>(sizeof(float) * m.NumColumns() * BLOCK_THREADS);
+        static_cast<size_t>(sizeof(float) * m->NumColumns() * BLOCK_THREADS);
     bool use_shared = true;
     if (shared_memory_bytes > max_shared_memory_bytes) {
       shared_memory_bytes = 0;
@@ -463,22 +434,24 @@
     size_t entry_start = 0;
 
     dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
-        PredictKernel<Loader, Batch>,
-        m.Value(),
+        PredictKernel<Loader, typename Loader::BatchT>,
+        m->Value(),
         dh::ToSpan(d_model.nodes), out_preds->predictions.DeviceSpan(),
         dh::ToSpan(d_model.tree_segments), dh::ToSpan(d_model.tree_group),
-        tree_begin, tree_end, m.NumColumns(), info.num_row_,
+        tree_begin, tree_end, m->NumColumns(), info.num_row_,
         entry_start, use_shared, output_groups);
   }
 
   void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
                       float missing, PredictionCacheEntry *out_preds,
                      uint32_t tree_begin, unsigned tree_end) const override {
-    if (x.type() == typeid(data::CupyAdapter)) {
-      this->DispatchedInplacePredict<data::CupyAdapter, CuPyAdapterLoader,
-                                     data::CupyAdapterBatch>(
+    if (x.type() == typeid(std::shared_ptr<data::CupyAdapter>)) {
+      this->DispatchedInplacePredict<
+          data::CupyAdapter, DeviceAdapterLoader<data::CupyAdapterBatch>>(
           x, model, missing, out_preds, tree_begin, tree_end);
-    } else if (x.type() == typeid(data::CudfAdapter)) {
-      this->DispatchedInplacePredict<data::CudfAdapter, CuDFAdapterLoader,
-                                     data::CudfAdapterBatch>(
+    } else if (x.type() == typeid(std::shared_ptr<data::CudfAdapter>)) {
+      this->DispatchedInplacePredict<
+          data::CudfAdapter, DeviceAdapterLoader<data::CudfAdapterBatch>>(
           x, model, missing, out_preds, tree_begin, tree_end);
     } else {
       LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor.";
diff --git a/tests/cpp/data/test_device_adapter.cu b/tests/cpp/data/test_device_adapter.cu
index e652db377c5c..181304b3aaea 100644
--- a/tests/cpp/data/test_device_adapter.cu
+++ b/tests/cpp/data/test_device_adapter.cu
@@ -36,13 +36,12 @@ void TestCudfAdapter()
   EXPECT_NO_THROW({
     dh::LaunchN(0, batch.Size(), [=] __device__(size_t idx) {
       auto element = batch.GetElement(idx);
-      if (idx < kRowsA) {
+      KERNEL_CHECK(element.row_idx == idx / 2);
+      if (idx % 2 == 0) {
         KERNEL_CHECK(element.column_idx == 0);
-        KERNEL_CHECK(element.row_idx == idx);
         KERNEL_CHECK(element.value == element.row_idx * 2.0f);
       } else {
         KERNEL_CHECK(element.column_idx == 1);
-        KERNEL_CHECK(element.row_idx == idx - kRowsA);
         KERNEL_CHECK(element.value == element.row_idx * 2.0f);
       }
     });
diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc
index 41d95ae05a8f..cc34b1399ab9 100644
--- a/tests/cpp/predictor/test_cpu_predictor.cc
+++ b/tests/cpp/predictor/test_cpu_predictor.cc
@@ -149,7 +149,8 @@ TEST(CpuPredictor, InplacePredict) {
    HostDeviceVector<float> data;
    gen.GenerateDense(&data);
    ASSERT_EQ(data.Size(), kRows * kCols);
-    data::DenseAdapter x{data.HostPointer(), kRows, kCols};
+    std::shared_ptr<data::DenseAdapter> x{
+        new data::DenseAdapter(data.HostPointer(), kRows, kCols)};
    TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
   }
 
@@ -158,8 +159,9 @@ TEST(CpuPredictor, InplacePredict) {
    HostDeviceVector<float> data;
    HostDeviceVector<bst_row_t> rptrs;
    HostDeviceVector<bst_feature_t> columns;
    gen.GenerateCSR(&data, &rptrs, &columns);
-    data::CSRAdapter x(rptrs.HostPointer(), columns.HostPointer(),
-                       data.HostPointer(), kRows, data.Size(), kCols);
+    std::shared_ptr<data::CSRAdapter> x{new data::CSRAdapter(
+        rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), kRows,
+        data.Size(), kCols)};
    TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1);
   }
 }
diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu
index e612bec95737..aee33ef76eec 100644
--- a/tests/cpp/predictor/test_gpu_predictor.cu
+++ b/tests/cpp/predictor/test_gpu_predictor.cu
@@ -129,7 +129,7 @@ TEST(GPUPredictor, InplacePredictCupy) {
   gen.Device(0);
   HostDeviceVector<float> data;
   std::string interface_str = gen.GenerateArrayInterface(&data);
-  data::CupyAdapter x{interface_str};
+  auto x = std::make_shared<data::CupyAdapter>(interface_str);
   TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
 }
 
@@ -139,7 +139,7 @@ TEST(GPUPredictor, InplacePredictCuDF) {
   gen.Device(0);
   std::vector<HostDeviceVector<float>> storage(kCols);
   auto interface_str = gen.GenerateColumnarArrayInterface(&storage);
-  data::CudfAdapter x {interface_str};
+  auto x = std::make_shared<data::CudfAdapter>(interface_str);
   TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0);
 }
 
@@ -154,7 +154,7 @@ TEST(GPUPredictor, MGPU_InplacePredict) {  // NOLINT
   gen.Device(1);
   HostDeviceVector<float> data;
   std::string interface_str = gen.GenerateArrayInterface(&data);
-  data::CupyAdapter x{interface_str};
+  auto x = std::make_shared<data::CupyAdapter>(interface_str);
   TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1);
   EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0),
                dmlc::Error);
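
The pivotal behavioural change above is in CudfAdapterBatch::GetElement: instead of keeping an exclusive scan of column sizes (column_ptr_) and binary-searching it with thrust::upper_bound, the batch now maps a flat element index onto cuDF's columnar storage in row-major order. Every adapter therefore yields elements in the same order, which is what lets the single CopyDataToEllpack / CopyDataToDMatrix path replace the CopyDataColumnMajor variants and drop IsRowMajor entirely. A minimal host-only sketch of that index mapping, with an illustrative struct that is not part of the patch:

#include <cassert>
#include <cstddef>

// Hypothetical stand-in for CudfAdapterBatch: columnar storage presented
// in row-major element order, mirroring the new GetElement() arithmetic.
struct RowMajorView {
  size_t num_rows;
  size_t num_cols;
  // Flat element index -> (row, column), matching the patch:
  //   column_idx = idx % columns_.size();
  //   row_idx    = idx / columns_.size();
  void Decode(size_t idx, size_t* row, size_t* col) const {
    *col = idx % num_cols;
    *row = idx / num_cols;
  }
  size_t Size() const { return num_rows * num_cols; }
};

int main() {
  RowMajorView view{3, 2};
  size_t row, col;
  // Consecutive indices walk across all columns of one row before moving
  // to the next row, exactly what the updated test_device_adapter.cu checks.
  view.Decode(0, &row, &col); assert(row == 0 && col == 0);
  view.Decode(1, &row, &col); assert(row == 0 && col == 1);
  view.Decode(2, &row, &col); assert(row == 1 && col == 0);
  assert(view.Size() == 6);
  return 0;
}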
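Both predictors now receive the adapter wrapped in std::shared_ptr inside dmlc::any, and recover the concrete type by comparing typeid against each supported pointer type; the shared_ptr keeps the borrowed data alive for the duration of the call. A self-contained sketch of the same dispatch pattern, using std::any in place of dmlc::any and stand-in adapter types (none of these names are from the patch):

#include <any>
#include <cassert>
#include <memory>
#include <typeinfo>

// Illustrative adapters standing in for data::DenseAdapter and data::CSRAdapter.
struct DenseAdapter { size_t NumColumns() const { return 4; } };
struct CSRAdapter   { size_t NumColumns() const { return 8; } };

// Recover the concrete adapter from the type-erased argument, as
// DispatchedInplacePredict does with dmlc::get<std::shared_ptr<Adapter>>.
template <typename Adapter>
size_t Dispatched(const std::any& x) {
  auto m = std::any_cast<std::shared_ptr<Adapter>>(x);
  return m->NumColumns();
}

size_t InplaceDispatch(const std::any& x) {
  // Mirrors the typeid chain in InplacePredict.
  if (x.type() == typeid(std::shared_ptr<DenseAdapter>)) {
    return Dispatched<DenseAdapter>(x);
  } else if (x.type() == typeid(std::shared_ptr<CSRAdapter>)) {
    return Dispatched<CSRAdapter>(x);
  }
  return 0;  // Unsupported adapter type.
}

int main() {
  std::any x = std::make_shared<CSRAdapter>();
  assert(InplaceDispatch(x) == 8);
  return 0;
}

Holding a value type in the any would copy the adapter (and, for device adapters, risk dangling spans into device memory); the shared_ptr avoids the copy and makes ownership explicit across the C API boundary.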
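On the GPU side, DeviceAdapterLoader<Batch> folds the duplicated CuPyAdapterLoader / CuDFAdapterLoader into one template, and the nested BatchT typedef is what lets the launch site name the batch type as PredictKernel<Loader, typename Loader::BatchT> instead of threading an extra template parameter through DispatchedInplacePredict. A simplified sketch of that pattern, with illustrative batch types:

#include <cstddef>

// Hypothetical batches standing in for data::CupyAdapterBatch and
// data::CudfAdapterBatch; both expose the same element-access surface.
struct CupyBatch { float Get(size_t i) const { return float(i); } };
struct CudfBatch { float Get(size_t i) const { return float(i) * 2; } };

// One loader template replaces two near-identical structs; the typedef
// exposes the batch type so callers can write typename Loader::BatchT.
template <typename Batch>
struct Loader {
  using BatchT = Batch;
  Batch batch;
  float GetFvalue(size_t idx) const { return batch.Get(idx); }
};

int main() {
  Loader<CupyBatch> a{CupyBatch{}};
  Loader<CudfBatch> b{CudfBatch{}};
  return (a.GetFvalue(1) == 1.0f && b.GetFvalue(1) == 2.0f) ? 0 : 1;
}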