From a58525567b82970113a75239043075ad2c22073b Mon Sep 17 00:00:00 2001 From: Hyunsu Cho Date: Tue, 22 Sep 2020 02:04:58 +0000 Subject: [PATCH] Revert #5755 --- src/c_api/c_api.cc | 6 +- src/c_api/c_api.cu | 4 +- src/data/device_adapter.cuh | 52 ++++++++++++---- src/data/ellpack_page.cu | 53 +++++++++++++++- src/data/simple_dmatrix.cu | 54 ++++++++++++++-- src/predictor/cpu_predictor.cc | 14 ++--- src/predictor/gpu_predictor.cu | 76 ++++++++++++++++------- tests/cpp/data/test_device_adapter.cu | 5 +- tests/cpp/predictor/test_cpu_predictor.cc | 8 +-- tests/cpp/predictor/test_gpu_predictor.cu | 6 +- 10 files changed, 212 insertions(+), 66 deletions(-) diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 397f83e69bf8..00a4434f34da 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -539,8 +539,7 @@ XGB_DLL int XGBoosterPredictFromDense(BoosterHandle handle, float *values, CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; auto *learner = static_cast(handle); - std::shared_ptr x{ - new xgboost::data::DenseAdapter(values, n_rows, n_cols)}; + auto x = xgboost::data::DenseAdapter(values, n_rows, n_cols); HostDeviceVector* p_predt { nullptr }; std::string type { c_type }; learner->InplacePredict(x, type, missing, &p_predt); @@ -571,8 +570,7 @@ XGB_DLL int XGBoosterPredictFromCSR(BoosterHandle handle, CHECK_EQ(cache_id, 0) << "Cache ID is not supported yet"; auto *learner = static_cast(handle); - std::shared_ptr x{ - new xgboost::data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col)}; + auto x = data::CSRAdapter(indptr, indices, data, nindptr - 1, nelem, num_col); HostDeviceVector* p_predt { nullptr }; std::string type { c_type }; learner->InplacePredict(x, type, missing, &p_predt); diff --git a/src/c_api/c_api.cu b/src/c_api/c_api.cu index 5af04894d3f9..b56239a7d547 100644 --- a/src/c_api/c_api.cu +++ b/src/c_api/c_api.cu @@ -46,7 +46,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterfaceColumns(BoosterHandle handle, auto *learner = static_cast(handle); std::string json_str{c_json_strs}; - auto x = std::make_shared(json_str); + auto x = data::CudfAdapter(json_str); HostDeviceVector* p_predt { nullptr }; std::string type { c_type }; learner->InplacePredict(x, type, missing, &p_predt); @@ -74,7 +74,7 @@ XGB_DLL int XGBoosterPredictFromArrayInterface(BoosterHandle handle, auto *learner = static_cast(handle); std::string json_str{c_json_strs}; - auto x = std::make_shared(json_str); + auto x = data::CupyAdapter(json_str); HostDeviceVector* p_predt { nullptr }; std::string type { c_type }; learner->InplacePredict(x, type, missing, &p_predt); diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 709368f5c756..db5208b8e279 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -34,29 +34,44 @@ struct IsValidFunctor : public thrust::unary_function { }; class CudfAdapterBatch : public detail::NoMetaInfo { - friend class CudfAdapter; - public: CudfAdapterBatch() = default; - CudfAdapterBatch(common::Span columns, size_t num_rows) + CudfAdapterBatch(common::Span columns, + common::Span column_ptr, size_t num_elements, size_t num_rows) : columns_(columns), + column_ptr_(column_ptr), + num_elements_(num_elements), num_rows_(num_rows) {} - size_t Size() const { return num_rows_ * columns_.size(); } + size_t Size() const { return num_elements_; } __device__ COOTuple GetElement(size_t idx) const { - size_t column_idx = idx % columns_.size(); - size_t row_idx = idx / columns_.size(); - auto const& column = columns_[column_idx]; + size_t column_idx = + thrust::upper_bound(thrust::seq, column_ptr_.begin(), column_ptr_.end(), idx) + - column_ptr_.begin() - 1; + auto& column = columns_[column_idx]; + size_t row_idx = idx - column_ptr_[column_idx]; float value = column.valid.Data() == nullptr || column.valid.Check(row_idx) ? column.GetElement(row_idx) : std::numeric_limits::quiet_NaN(); return {row_idx, column_idx, value}; } + __device__ float GetValue(size_t ridx, bst_feature_t fidx) const { + auto const& column = columns_[fidx]; + float value = column.valid.Data() == nullptr || column.valid.Check(ridx) + ? column.GetElement(ridx) + : std::numeric_limits::quiet_NaN(); + return value; + } - XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; } + XGBOOST_DEVICE bst_row_t NumRows() const { return num_elements_ / columns_.size(); } XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); } + // Cudf is column major + bool IsRowMajor() { return false; } + private: common::Span columns_; + common::Span column_ptr_; + size_t num_elements_; size_t num_rows_; }; @@ -121,6 +136,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { CHECK_EQ(typestr.size(), 3) << ArrayInterfaceErrors::TypestrFormat(); CHECK_NE(typestr.front(), '>') << ArrayInterfaceErrors::BigEndian(); std::vector columns; + std::vector column_ptr({0}); auto first_column = ArrayInterface(get(json_columns[0])); num_rows_ = first_column.num_rows; if (num_rows_ == 0) { @@ -134,6 +150,7 @@ class CudfAdapter : public detail::SingleBatchDataIter { auto column = ArrayInterface(get(json_col)); columns.push_back(column); CHECK_EQ(column.num_cols, 1); + column_ptr.emplace_back(column_ptr.back() + column.num_rows); num_rows_ = std::max(num_rows_, size_t(column.num_rows)); CHECK_EQ(device_idx_, dh::CudaGetPointerDevice(column.data)) << "All columns should use the same device."; @@ -141,20 +158,23 @@ class CudfAdapter : public detail::SingleBatchDataIter { << "All columns should have same number of rows."; } columns_ = columns; - batch_ = CudfAdapterBatch(dh::ToSpan(columns_), num_rows_); - } - const CudfAdapterBatch& Value() const override { - CHECK_EQ(batch_.columns_.data(), columns_.data().get()); - return batch_; + column_ptr_ = column_ptr; + batch_ = CudfAdapterBatch(dh::ToSpan(columns_), dh::ToSpan(column_ptr_), + column_ptr.back(), num_rows_); } + const CudfAdapterBatch& Value() const override { return batch_; } size_t NumRows() const { return num_rows_; } size_t NumColumns() const { return columns_.size(); } size_t DeviceIdx() const { return device_idx_; } + // Cudf is column major + bool IsRowMajor() { return false; } + private: CudfAdapterBatch batch_; dh::device_vector columns_; + dh::device_vector column_ptr_; // Exclusive scan of column sizes size_t num_rows_{0}; int device_idx_; }; @@ -177,6 +197,9 @@ class CupyAdapterBatch : public detail::NoMetaInfo { XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.num_rows; } XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.num_cols; } + // Cupy is row major + bool IsRowMajor() { return true; } + private: ArrayInterface array_interface_; }; @@ -200,6 +223,9 @@ class CupyAdapter : public detail::SingleBatchDataIter { size_t NumColumns() const { return array_interface_.num_cols; } size_t DeviceIdx() const { return device_idx_; } + // Cupy is row major + bool IsRowMajor() { return true; } + private: ArrayInterface array_interface_; CupyAdapterBatch batch_; diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 39e845f2d765..482b520d9aae 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -154,7 +154,7 @@ struct WriteCompressedEllpackFunctor { // Here the data is already correctly ordered and simply needs to be compacted // to remove missing data template -void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst, +void CopyDataRowMajor(const AdapterBatchT& batch, EllpackPageImpl* dst, int device_idx, float missing) { // Some witchcraft happens here // The goal is to copy valid elements out of the input to an ellpack matrix @@ -209,6 +209,51 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst, }); } +template +void CopyDataColumnMajor(const AdapterBatchT& batch, EllpackPageImpl* dst, + int device_idx, float missing) { + // Step 1: Get the sizes of the input columns + dh::caching_device_vector column_sizes(batch.NumCols(), 0); + auto d_column_sizes = column_sizes.data().get(); + // Populate column sizes + dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) { + const auto& e = batch.GetElement(idx); + atomicAdd(reinterpret_cast( // NOLINT + &d_column_sizes[e.column_idx]), + static_cast(1)); // NOLINT + }); + + thrust::host_vector host_column_sizes = column_sizes; + + // Step 2: Iterate over columns, place elements in correct row, increment + // temporary row pointers + dh::caching_device_vector temp_row_ptr(batch.NumRows(), 0); + auto d_temp_row_ptr = temp_row_ptr.data().get(); + auto row_stride = dst->row_stride; + size_t begin = 0; + auto device_accessor = dst->GetDeviceAccessor(device_idx); + common::CompressedBufferWriter writer(device_accessor.NumSymbols()); + auto d_compressed_buffer = dst->gidx_buffer.DevicePointer(); + data::IsValidFunctor is_valid(missing); + for (auto size : host_column_sizes) { + size_t end = begin + size; + dh::LaunchN(device_idx, end - begin, [=] __device__(size_t idx) { + auto writer_non_const = + writer; // For some reason this variable gets captured as const + const auto& e = batch.GetElement(idx + begin); + if (!is_valid(e)) return; + size_t output_position = + e.row_idx * row_stride + d_temp_row_ptr[e.row_idx]; + auto bin_idx = device_accessor.SearchBin(e.value, e.column_idx); + writer_non_const.AtomicWriteSymbol(d_compressed_buffer, bin_idx, + output_position); + d_temp_row_ptr[e.row_idx] += 1; + }); + + begin = end; + } +} + void WriteNullValues(EllpackPageImpl* dst, int device_idx, common::Span row_counts) { // Write the null values @@ -237,7 +282,11 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device, dh::safe_cuda(cudaSetDevice(device)); *this = EllpackPageImpl(device, cuts, is_dense, row_stride, n_rows); - CopyDataToEllpack(batch, this, device, missing); + if (batch.IsRowMajor()) { + CopyDataRowMajor(batch, this, device, missing); + } else { + CopyDataColumnMajor(batch, this, device, missing); + } WriteNullValues(this, device, row_counts_span); } diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index f8b775c7a06d..bcea01d24bbb 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -35,12 +35,51 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, thrust::device_pointer_cast(offset.data())); } +template +void CopyDataColumnMajor(AdapterT* adapter, common::Span data, + int device_idx, float missing, + common::Span row_ptr) { + // Step 1: Get the sizes of the input columns + dh::device_vector column_sizes(adapter->NumColumns()); + auto d_column_sizes = column_sizes.data().get(); + auto& batch = adapter->Value(); + // Populate column sizes + dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) { + const auto& e = batch.GetElement(idx); + atomicAdd(reinterpret_cast( // NOLINT + &d_column_sizes[e.column_idx]), + static_cast(1)); // NOLINT + }); + + thrust::host_vector host_column_sizes = column_sizes; + + // Step 2: Iterate over columns, place elements in correct row, increment + // temporary row pointers + dh::device_vector temp_row_ptr( + thrust::device_pointer_cast(row_ptr.data()), + thrust::device_pointer_cast(row_ptr.data() + row_ptr.size())); + auto d_temp_row_ptr = temp_row_ptr.data().get(); + size_t begin = 0; + IsValidFunctor is_valid(missing); + for (auto size : host_column_sizes) { + size_t end = begin + size; + dh::LaunchN(device_idx, end - begin, [=] __device__(size_t idx) { + const auto& e = batch.GetElement(idx + begin); + if (!is_valid(e)) return; + data[d_temp_row_ptr[e.row_idx]] = Entry(e.column_idx, e.value); + d_temp_row_ptr[e.row_idx] += 1; + }); + + begin = end; + } +} + // Here the data is already correctly ordered and simply needs to be compacted // to remove missing data template -void CopyDataToDMatrix(AdapterT* adapter, common::Span data, - int device_idx, float missing, - common::Span row_ptr) { +void CopyDataRowMajor(AdapterT* adapter, common::Span data, + int device_idx, float missing, + common::Span row_ptr) { auto& batch = adapter->Value(); auto transform_f = [=] __device__(size_t idx) { const auto& e = batch.GetElement(idx); @@ -77,8 +116,13 @@ SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing); info_.num_nonzero_ = sparse_page_.offset.HostVector().back(); sparse_page_.data.Resize(info_.num_nonzero_); - CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(), - adapter->DeviceIdx(), missing, s_offset); + if (adapter->IsRowMajor()) { + CopyDataRowMajor(adapter, sparse_page_.data.DeviceSpan(), + adapter->DeviceIdx(), missing, s_offset); + } else { + CopyDataColumnMajor(adapter, sparse_page_.data.DeviceSpan(), + adapter->DeviceIdx(), missing, s_offset); + } info_.num_col_ = adapter->NumColumns(); info_.num_row_ = adapter->NumRows(); diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 8d116999e62c..538112f43306 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -269,12 +269,12 @@ class CPUPredictor : public Predictor { PredictionCacheEntry *out_preds, uint32_t tree_begin, uint32_t tree_end) const { auto threads = omp_get_max_threads(); - auto m = dmlc::get>(x); - CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) + auto m = dmlc::get(x); + CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature) << "Number of columns in data must equal to trained model."; MetaInfo info; - info.num_col_ = m->NumColumns(); - info.num_row_ = m->NumRows(); + info.num_col_ = m.NumColumns(); + info.num_row_ = m.NumRows(); this->InitOutPredictions(info, &(out_preds->predictions), model); std::vector workspace(info.num_col_ * 8 * threads); auto &predictions = out_preds->predictions.HostVector(); @@ -282,17 +282,17 @@ class CPUPredictor : public Predictor { InitThreadTemp(threads, model.learner_model_param->num_feature, &thread_temp); size_t constexpr kUnroll = 8; PredictBatchKernel(AdapterView( - m.get(), missing, common::Span{workspace}), + &m, missing, common::Span{workspace}), &predictions, model, tree_begin, tree_end, &thread_temp); } void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, float missing, PredictionCacheEntry *out_preds, uint32_t tree_begin, unsigned tree_end) const override { - if (x.type() == typeid(std::shared_ptr)) { + if (x.type() == typeid(data::DenseAdapter)) { this->DispatchedInplacePredict( x, model, missing, out_preds, tree_begin, tree_end); - } else if (x.type() == typeid(std::shared_ptr)) { + } else if (x.type() == typeid(data::CSRAdapter)) { this->DispatchedInplacePredict( x, model, missing, out_preds, tree_begin, tree_end); } else { diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index fe9664a55254..d4c9d34a138c 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -126,18 +126,15 @@ struct EllpackLoader { } }; -template -struct DeviceAdapterLoader { - Batch batch; +struct CuPyAdapterLoader { + data::CupyAdapterBatch batch; bst_feature_t columns; float* smem; bool use_shared; - using BatchT = Batch; - - XGBOOST_DEV_INLINE DeviceAdapterLoader(Batch const batch, bool use_shared, - bst_feature_t num_features, bst_row_t num_rows, - size_t entry_start) : + XGBOOST_DEV_INLINE CuPyAdapterLoader(data::CupyAdapterBatch const batch, bool use_shared, + bst_feature_t num_features, bst_row_t num_rows, + size_t entry_start) : batch{batch}, columns{num_features}, use_shared{use_shared} { @@ -167,6 +164,39 @@ struct DeviceAdapterLoader { } }; +struct CuDFAdapterLoader { + data::CudfAdapterBatch batch; + bst_feature_t columns; + float* smem; + bool use_shared; + + XGBOOST_DEV_INLINE CuDFAdapterLoader(data::CudfAdapterBatch const batch, bool use_shared, + bst_feature_t num_features, + bst_row_t num_rows, size_t entry_start) + : batch{batch}, columns{num_features}, use_shared{use_shared} { + extern __shared__ float _smem[]; + smem = _smem; + if (use_shared) { + uint32_t global_idx = blockDim.x * blockIdx.x + threadIdx.x; + size_t shared_elements = blockDim.x * num_features; + dh::BlockFill(smem, shared_elements, nanf("")); + __syncthreads(); + if (global_idx < num_rows) { + for (size_t i = 0; i < columns; ++i) { + smem[threadIdx.x * columns + i] = batch.GetValue(global_idx, i); + } + } + } + __syncthreads(); + } + XGBOOST_DEV_INLINE float GetElement(size_t ridx, size_t fidx) const { + if (use_shared) { + return smem[threadIdx.x * columns + fidx]; + } + return batch.GetValue(ridx, fidx); + } +}; + template __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree, Loader* loader) { @@ -494,7 +524,7 @@ class GPUPredictor : public xgboost::Predictor { out_preds->Size() == dmat->Info().num_row_); } - template + template void DispatchedInplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, float missing, PredictionCacheEntry *out_preds, @@ -504,22 +534,22 @@ class GPUPredictor : public xgboost::Predictor { DeviceModel d_model; d_model.Init(model, tree_begin, tree_end, this->generic_param_->gpu_id); - auto m = dmlc::get>(x); - CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature) + auto m = dmlc::get(x); + CHECK_EQ(m.NumColumns(), model.learner_model_param->num_feature) << "Number of columns in data must equal to trained model."; - CHECK_EQ(this->generic_param_->gpu_id, m->DeviceIdx()) + CHECK_EQ(this->generic_param_->gpu_id, m.DeviceIdx()) << "XGBoost is running on device: " << this->generic_param_->gpu_id << ", " - << "but data is on: " << m->DeviceIdx(); + << "but data is on: " << m.DeviceIdx(); MetaInfo info; - info.num_col_ = m->NumColumns(); - info.num_row_ = m->NumRows(); + info.num_col_ = m.NumColumns(); + info.num_row_ = m.NumRows(); this->InitOutPredictions(info, &(out_preds->predictions), model); const uint32_t BLOCK_THREADS = 128; auto GRID_SIZE = static_cast(common::DivRoundUp(info.num_row_, BLOCK_THREADS)); auto shared_memory_bytes = - static_cast(sizeof(float) * m->NumColumns() * BLOCK_THREADS); + static_cast(sizeof(float) * m.NumColumns() * BLOCK_THREADS); bool use_shared = true; if (shared_memory_bytes > max_shared_memory_bytes) { shared_memory_bytes = 0; @@ -528,24 +558,24 @@ class GPUPredictor : public xgboost::Predictor { size_t entry_start = 0; dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( - PredictKernel, - m->Value(), + PredictKernel, + m.Value(), d_model.nodes.DeviceSpan(), out_preds->predictions.DeviceSpan(), d_model.tree_segments.DeviceSpan(), d_model.tree_group.DeviceSpan(), - tree_begin, tree_end, m->NumColumns(), info.num_row_, + tree_begin, tree_end, m.NumColumns(), info.num_row_, entry_start, use_shared, output_groups); } void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, float missing, PredictionCacheEntry *out_preds, uint32_t tree_begin, unsigned tree_end) const override { - if (x.type() == typeid(std::shared_ptr)) { + if (x.type() == typeid(data::CupyAdapter)) { this->DispatchedInplacePredict< - data::CupyAdapter, DeviceAdapterLoader>( + data::CupyAdapter, CuPyAdapterLoader, data::CupyAdapterBatch>( x, model, missing, out_preds, tree_begin, tree_end); - } else if (x.type() == typeid(std::shared_ptr)) { + } else if (x.type() == typeid(data::CudfAdapter)) { this->DispatchedInplacePredict< - data::CudfAdapter, DeviceAdapterLoader>( + data::CudfAdapter, CuDFAdapterLoader, data::CudfAdapterBatch>( x, model, missing, out_preds, tree_begin, tree_end); } else { LOG(FATAL) << "Only CuPy and CuDF are supported by GPU Predictor."; diff --git a/tests/cpp/data/test_device_adapter.cu b/tests/cpp/data/test_device_adapter.cu index 34c8e93b7822..2e4faa9521cb 100644 --- a/tests/cpp/data/test_device_adapter.cu +++ b/tests/cpp/data/test_device_adapter.cu @@ -35,12 +35,13 @@ void TestCudfAdapter() EXPECT_NO_THROW({ dh::LaunchN(0, batch.Size(), [=] __device__(size_t idx) { auto element = batch.GetElement(idx); - KERNEL_CHECK(element.row_idx == idx / 2); - if (idx % 2 == 0) { + if (idx < kRowsA) { KERNEL_CHECK(element.column_idx == 0); + KERNEL_CHECK(element.row_idx == idx); KERNEL_CHECK(element.value == element.row_idx * 2.0f); } else { KERNEL_CHECK(element.column_idx == 1); + KERNEL_CHECK(element.row_idx == idx - kRowsA); KERNEL_CHECK(element.value == element.row_idx * 2.0f); } }); diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index ea8114579494..fad9fadfac2f 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -149,8 +149,7 @@ TEST(CpuPredictor, InplacePredict) { HostDeviceVector data; gen.GenerateDense(&data); ASSERT_EQ(data.Size(), kRows * kCols); - std::shared_ptr x{ - new data::DenseAdapter(data.HostPointer(), kRows, kCols)}; + data::DenseAdapter x{data.HostPointer(), kRows, kCols}; TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); } @@ -159,9 +158,8 @@ TEST(CpuPredictor, InplacePredict) { HostDeviceVector rptrs; HostDeviceVector columns; gen.GenerateCSR(&data, &rptrs, &columns); - std::shared_ptr x{new data::CSRAdapter( - rptrs.HostPointer(), columns.HostPointer(), data.HostPointer(), kRows, - data.Size(), kCols)}; + data::CSRAdapter x(rptrs.HostPointer(), columns.HostPointer(), + data.HostPointer(), kRows, data.Size(), kCols); TestInplacePrediction(x, "cpu_predictor", kRows, kCols, -1); } } diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 585acf1790b6..455f55efa3ce 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -129,7 +129,7 @@ TEST(GPUPredictor, InplacePredictCupy) { gen.Device(0); HostDeviceVector data; std::string interface_str = gen.GenerateArrayInterface(&data); - auto x = std::make_shared(interface_str); + data::CupyAdapter x{interface_str}; TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0); } @@ -139,7 +139,7 @@ TEST(GPUPredictor, InplacePredictCuDF) { gen.Device(0); std::vector> storage(kCols); auto interface_str = gen.GenerateColumnarArrayInterface(&storage); - auto x = std::make_shared(interface_str); + data::CudfAdapter x {interface_str}; TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0); } @@ -154,7 +154,7 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT gen.Device(1); HostDeviceVector data; std::string interface_str = gen.GenerateArrayInterface(&data); - auto x = std::make_shared(interface_str); + data::CupyAdapter x{interface_str}; TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 1); EXPECT_THROW(TestInplacePrediction(x, "gpu_predictor", kRows, kCols, 0), dmlc::Error);