From db485323ecf049b615f8d09fc05cee96101713cf Mon Sep 17 00:00:00 2001 From: fis Date: Fri, 5 Jun 2020 15:25:39 +0800 Subject: [PATCH 1/8] Implement weighted sketching for adapter. --- src/common/hist_util.cc | 4 + src/common/hist_util.cu | 50 ++---- src/common/hist_util.cuh | 276 ++++++++++++++++++++++++----- src/common/hist_util.h | 1 + src/data/device_adapter.cuh | 6 + tests/cpp/common/test_hist_util.cu | 98 +++++++++- tests/cpp/common/test_hist_util.h | 4 +- 7 files changed, 362 insertions(+), 77 deletions(-) diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index e3ca953d2e4b..d44a705586f1 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -140,6 +140,10 @@ void HistogramCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) { bool CutsBuilder::UseGroup(DMatrix* dmat) { auto& info = dmat->Info(); + return CutsBuilder::UseGroup(info); +} + +bool CutsBuilder::UseGroup(MetaInfo const& info) { size_t const num_groups = info.group_ptr_.size() == 0 ? 0 : info.group_ptr_.size() - 1; // Use group index for weights? diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 716d49f9ac37..c3a5e5f60bf5 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -1,5 +1,5 @@ /*! - * Copyright 2018 XGBoost contributors + * Copyright 2018~2020 XGBoost contributors */ #include @@ -29,23 +29,6 @@ namespace xgboost { namespace common { // Count the entries in each column and exclusive scan -void GetColumnSizesScan(int device, - dh::caching_device_vector* column_sizes_scan, - Span entries, size_t num_columns) { - column_sizes_scan->resize(num_columns + 1, 0); - auto d_column_sizes_scan = column_sizes_scan->data().get(); - auto d_entries = entries.data(); - dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) { - auto& e = d_entries[idx]; - atomicAdd(reinterpret_cast( // NOLINT - &d_column_sizes_scan[e.index]), - static_cast(1)); // NOLINT - }); - dh::XGBCachingDeviceAllocator alloc; - thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), - column_sizes_scan->end(), column_sizes_scan->begin()); -} - void ExtractCuts(int device, size_t num_cuts_per_feature, Span sorted_data, @@ -158,6 +141,23 @@ void ProcessBatch(int device, const SparsePage& page, size_t begin, size_t end, sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); } +void SortByWeight(dh::XGBCachingDeviceAllocator* alloc, + dh::caching_device_vector* weights, + dh::caching_device_vector* sorted_entries) { + // Sort both entries and wegihts. + thrust::sort_by_key(thrust::cuda::par(*alloc), sorted_entries->begin(), + sorted_entries->end(), weights->begin(), + EntryCompareOp()); + + // Scan weights + thrust::inclusive_scan_by_key(thrust::cuda::par(*alloc), + sorted_entries->begin(), sorted_entries->end(), + weights->begin(), weights->begin(), + [=] __device__(const Entry& a, const Entry& b) { + return a.index == b.index; + }); +} + void ProcessWeightedBatch(int device, const SparsePage& page, Span weights, size_t begin, size_t end, SketchContainer* sketch_container, int num_cuts_per_feature, @@ -201,19 +201,7 @@ void ProcessWeightedBatch(int device, const SparsePage& page, d_temp_weights[idx] = weights[ridx + base_rowid]; }); } - - // Sort both entries and wegihts. - thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries.begin(), - sorted_entries.end(), temp_weights.begin(), - EntryCompareOp()); - - // Scan weights - thrust::inclusive_scan_by_key(thrust::cuda::par(alloc), - sorted_entries.begin(), sorted_entries.end(), - temp_weights.begin(), temp_weights.begin(), - [=] __device__(const Entry& a, const Entry& b) { - return a.index == b.index; - }); + SortByWeight(&alloc, &temp_weights, &sorted_entries); dh::caching_device_vector column_sizes_scan; GetColumnSizesScan(device, &column_sizes_scan, diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index b8057e322f02..39c66ffaa8b8 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -1,9 +1,13 @@ +/*! + * Copyright 2020 XGBoost contributors + */ #ifndef COMMON_HIST_UTIL_CUH_ #define COMMON_HIST_UTIL_CUH_ #include #include "hist_util.h" +#include "threading_utils.h" #include "device_helpers.cuh" #include "../data/device_adapter.cuh" @@ -93,6 +97,62 @@ void ExtractCuts(int device, Span column_sizes_scan, Span out_cuts); +// Count the entries in each column and exclusive scan +inline void GetColumnSizesScan(int device, + dh::caching_device_vector* column_sizes_scan, + Span entries, size_t num_columns) { + column_sizes_scan->resize(num_columns + 1, 0); + auto d_column_sizes_scan = column_sizes_scan->data().get(); + auto d_entries = entries.data(); + dh::LaunchN(device, entries.size(), [=] __device__(size_t idx) { + auto& e = d_entries[idx]; + atomicAdd(reinterpret_cast( // NOLINT + &d_column_sizes_scan[e.index]), + static_cast(1)); // NOLINT + }); + dh::XGBCachingDeviceAllocator alloc; + thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), + column_sizes_scan->end(), column_sizes_scan->begin()); +} + +// For adapter. +template +void GetColumnSizesScan(int device, size_t num_columns, + Iter batch_iter, data::IsValidFunctor is_valid, + size_t begin, size_t end, + dh::caching_device_vector* column_sizes_scan) { + dh::XGBCachingDeviceAllocator alloc; + column_sizes_scan->resize(num_columns + 1, 0); + auto d_column_sizes_scan = column_sizes_scan->data().get(); + dh::LaunchN(device, end - begin, [=] __device__(size_t idx) { + auto e = batch_iter[begin + idx]; + if (is_valid(e)) { + atomicAdd(reinterpret_cast( // NOLINT + &d_column_sizes_scan[e.column_idx]), + static_cast(1)); // NOLINT + } + }); + thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), + column_sizes_scan->end(), column_sizes_scan->begin()); +} + +template +size_t SketchBatchNumElements(AdapterBatch batch, size_t sketch_batch_num_elements, + size_t columns, int device, + size_t num_cuts) { + if (sketch_batch_num_elements == 0) { + int bytes_per_element = 16; + size_t bytes_cuts = num_cuts * columns * sizeof(SketchEntry); + size_t bytes_num_columns = (columns + 1) * sizeof(size_t); + // use up to 80% of available space + sketch_batch_num_elements = (dh::AvailableMemory(device) - + bytes_cuts - bytes_num_columns) * + 0.8 / bytes_per_element; + } + return sketch_batch_num_elements; +} + + // Compute number of sample cuts needed on local node to maintain accuracy // We take more cuts than needed and then reduce them later inline size_t RequiredSampleCuts(int max_bins, size_t num_rows) { @@ -109,52 +169,60 @@ inline size_t RequiredSampleCuts(int max_bins, size_t num_rows) { HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, size_t sketch_batch_num_elements = 0); -template -void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing, - SketchContainer* sketch_container, int num_cuts) { - dh::XGBCachingDeviceAllocator alloc; - adapter->BeforeFirst(); - adapter->Next(); - auto &batch = adapter->Value(); - // Enforce single batch - CHECK(!adapter->Next()); - auto batch_iter = dh::MakeTransformIterator( - thrust::make_counting_iterator(0llu), - [=] __device__(size_t idx) { return batch.GetElement(idx); }); + +template +void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, + Range1d range, float missing, + size_t columns, int device, + thrust::host_vector* host_column_sizes_scan, + dh::caching_device_vector* column_sizes_scan, + dh::caching_device_vector* sorted_entries) { auto entry_iter = dh::MakeTransformIterator( thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { return Entry(batch.GetElement(idx).column_idx, batch.GetElement(idx).value); }); + data::IsValidFunctor is_valid(missing); // Work out how many valid entries we have in each column - dh::caching_device_vector column_sizes_scan(adapter->NumColumns() + 1, - 0); + GetColumnSizesScan(device, columns, + batch_iter, is_valid, + range.begin(), range.end(), + column_sizes_scan); + host_column_sizes_scan->resize(column_sizes_scan->size()); + thrust::copy(column_sizes_scan->begin(), column_sizes_scan->end(), + host_column_sizes_scan->begin()); - auto d_column_sizes_scan = column_sizes_scan.data().get(); - data::IsValidFunctor is_valid(missing); - dh::LaunchN(adapter->DeviceIdx(), end - begin, [=] __device__(size_t idx) { - auto e = batch_iter[begin + idx]; - if (is_valid(e)) { - atomicAdd(reinterpret_cast( // NOLINT - &d_column_sizes_scan[e.column_idx]), - static_cast(1)); // NOLINT - } - }); - thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan.begin(), - column_sizes_scan.end(), column_sizes_scan.begin()); - thrust::host_vector host_column_sizes_scan(column_sizes_scan); - size_t num_valid = host_column_sizes_scan.back(); + size_t num_valid = host_column_sizes_scan->back(); + + // Copy current subset of valid elements into temporary storage and sort + sorted_entries->resize(num_valid); + dh::XGBCachingDeviceAllocator alloc; + thrust::copy_if(thrust::cuda::par(alloc), entry_iter + range.begin(), + entry_iter + range.end(), sorted_entries->begin(), is_valid); +} +template +void ProcessBatchSlidingWindow(AdapterBatch const& batch, int device, size_t columns, + size_t begin, size_t end, float missing, + SketchContainer* sketch_container, int num_cuts) { // Copy current subset of valid elements into temporary storage and sort - dh::caching_device_vector sorted_entries(num_valid); - thrust::copy_if(thrust::cuda::par(alloc), entry_iter + begin, - entry_iter + end, sorted_entries.begin(), is_valid); + dh::caching_device_vector sorted_entries; + dh::caching_device_vector column_sizes_scan; + thrust::host_vector host_column_sizes_scan; + auto batch_iter = dh::MakeTransformIterator( + thrust::make_counting_iterator(0llu), + [=] __device__(size_t idx) { return batch.GetElement(idx); }); + MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing, columns, device, + &host_column_sizes_scan, + &column_sizes_scan, + &sorted_entries); + dh::XGBCachingDeviceAllocator alloc; thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), sorted_entries.end(), EntryCompareOp()); // Extract the cuts from all columns concurrently - dh::caching_device_vector cuts(adapter->NumColumns() * num_cuts); - ExtractCuts(adapter->DeviceIdx(), num_cuts, + dh::caching_device_vector cuts(columns * num_cuts); + ExtractCuts(device, num_cuts, dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), dh::ToSpan(cuts)); @@ -164,27 +232,105 @@ void ProcessBatch(AdapterT* adapter, size_t begin, size_t end, float missing, sketch_container->Push(num_cuts, host_cuts, host_column_sizes_scan); } +void ExtractWeightedCuts(int device, + size_t num_cuts_per_feature, + Span sorted_data, + Span weights_scan, + Span column_sizes_scan, + Span cuts); + +void SortByWeight(dh::XGBCachingDeviceAllocator* alloc, + dh::caching_device_vector* weights, + dh::caching_device_vector* sorted_entries); + +template +void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, + int num_cuts_per_feature, + bool is_ranking, float missing, int device, + size_t columns, size_t begin, size_t end, + SketchContainer *sketch_container) { + dh::XGBCachingDeviceAllocator alloc; + dh::safe_cuda(cudaSetDevice(device)); + info.weights_.SetDevice(device); + auto weights = info.weights_.ConstDeviceSpan(); + dh::caching_device_vector group_ptr(info.group_ptr_); + auto d_group_ptr = dh::ToSpan(group_ptr); + + auto batch_iter = dh::MakeTransformIterator( + thrust::make_counting_iterator(0llu), + [=] __device__(size_t idx) { return batch.GetElement(idx); }); + dh::caching_device_vector sorted_entries; + dh::caching_device_vector column_sizes_scan; + thrust::host_vector host_column_sizes_scan; + MakeEntriesFromAdapter(batch, batch_iter, + {begin, end}, missing, columns, device, + &host_column_sizes_scan, + &column_sizes_scan, + &sorted_entries); + data::IsValidFunctor is_valid(missing); + + dh::caching_device_vector temp_weights(sorted_entries.size()); + auto d_temp_weights = dh::ToSpan(temp_weights); + + if (is_ranking) { + auto const weight_iter = dh::MakeTransformIterator( + thrust::make_constant_iterator(0lu), + [=]__device__(size_t idx) -> float { + auto ridx = batch.GetElement(idx).row_idx; + auto it = thrust::upper_bound(thrust::seq, + d_group_ptr.cbegin(), d_group_ptr.cend(), + ridx) - 1; + bst_group_t group = thrust::distance(d_group_ptr.cbegin(), it); + return weights[group]; + }); + auto retit = thrust::copy_if(thrust::cuda::par(alloc), + weight_iter + begin, weight_iter + end, + batch_iter + begin, + d_temp_weights.data(), // output + is_valid); + CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); + } else { + auto const weight_iter = dh::MakeTransformIterator( + thrust::make_counting_iterator(0lu), + [=]__device__(size_t idx) -> float { + return weights[batch.GetElement(idx).row_idx]; + }); + auto retit = thrust::copy_if(thrust::cuda::par(alloc), + weight_iter + begin, weight_iter + end, + batch_iter + begin, + d_temp_weights.data(), // output + is_valid); + CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); + } + + SortByWeight(&alloc, &temp_weights, &sorted_entries); + // Extract cuts + dh::caching_device_vector cuts(columns * num_cuts_per_feature); + ExtractWeightedCuts(device, num_cuts_per_feature, + dh::ToSpan(sorted_entries), + dh::ToSpan(temp_weights), + dh::ToSpan(column_sizes_scan), + dh::ToSpan(cuts)); + + // add cuts into sketches + thrust::host_vector host_cuts(cuts); + sketch_container->Push(num_cuts_per_feature, host_cuts, host_column_sizes_scan); +} + template HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, float missing, size_t sketch_batch_num_elements = 0) { size_t num_cuts = RequiredSampleCuts(num_bins, adapter->NumRows()); - if (sketch_batch_num_elements == 0) { - int bytes_per_element = 16; - size_t bytes_cuts = num_cuts * adapter->NumColumns() * sizeof(SketchEntry); - size_t bytes_num_columns = (adapter->NumColumns() + 1) * sizeof(size_t); - // use up to 80% of available space - sketch_batch_num_elements = (dh::AvailableMemory(adapter->DeviceIdx()) - - bytes_cuts - bytes_num_columns) * - 0.8 / bytes_per_element; - } - CHECK(adapter->NumRows() != data::kAdapterUnknownSize); CHECK(adapter->NumColumns() != data::kAdapterUnknownSize); adapter->BeforeFirst(); adapter->Next(); auto& batch = adapter->Value(); + sketch_batch_num_elements = SketchBatchNumElements( + batch, sketch_batch_num_elements, + adapter->NumColumns(), adapter->DeviceIdx(), num_cuts); // Enforce single batch CHECK(!adapter->Next()); @@ -197,12 +343,54 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); - ProcessBatch(adapter, begin, end, missing, &sketch_container, num_cuts); + auto const& batch = adapter->Value(); + ProcessBatchSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(), + begin, end, missing, &sketch_container, num_cuts); } dense_cuts.Init(&sketch_container.sketches_, num_bins, adapter->NumRows()); return cuts; } + +template +void AdapterDeviceSketch(Batch batch, int num_bins, + float missing, int device, + SketchContainer* sketch_container, + size_t sketch_batch_num_elements = 0) { + size_t num_rows = batch.NumRows(); + size_t num_cols = batch.NumCols(); + size_t num_cuts = RequiredSampleCuts(num_bins, num_rows); + sketch_batch_num_elements = SketchBatchNumElements( + batch, sketch_batch_num_elements, + num_cols, device, num_cuts); + for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { + size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); + ProcessBatchSlidingWindow(batch, device, num_cols, + begin, end, missing, sketch_container, num_cuts); + } +} + +template +void AdapterDeviceSketchWeighted(Batch batch, int num_bins, + MetaInfo const& info, + float missing, + int device, + SketchContainer* sketch_container, + size_t sketch_batch_num_elements = 0) { + size_t num_rows = batch.NumRows(); + size_t num_cols = batch.NumCols(); + size_t num_cuts = RequiredSampleCuts(num_bins, num_rows); + sketch_batch_num_elements = SketchBatchNumElements( + batch, sketch_batch_num_elements, + num_cols, device, num_cuts); + for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { + size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); + ProcessWeightedSlidingWindow(batch, info, + num_cuts, + CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end, + sketch_container); + } +} } // namespace common } // namespace xgboost diff --git a/src/common/hist_util.h b/src/common/hist_util.h index d6096030182f..c48eafad84dc 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -129,6 +129,7 @@ class CutsBuilder { using WQSketch = common::WQuantileSketch; /* \brief return whether group for ranking is used. */ static bool UseGroup(DMatrix* dmat); + static bool UseGroup(MetaInfo const& info); protected: HistogramCuts* p_cuts_; diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index ff3f3f8f51e3..513c42db40f1 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -52,6 +52,9 @@ class CudfAdapterBatch : public detail::NoMetaInfo { return {row_idx, column_idx, value}; } + XGBOOST_DEVICE bst_row_t NumRows() const { return num_elements_ / columns_.size(); } + XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); } + private: common::Span columns_; size_t num_rows_; @@ -167,6 +170,9 @@ class CupyAdapterBatch : public detail::NoMetaInfo { return {row_idx, column_idx, value}; } + XGBOOST_DEVICE bst_row_t NumRows() const { return array_interface_.num_rows; } + XGBOOST_DEVICE bst_row_t NumCols() const { return array_interface_.num_cols; } + private: ArrayInterface array_interface_; }; diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index d4c6a155191e..36c18f00a9b5 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -192,6 +192,20 @@ TEST(HistUtil, DeviceSketchBatches) { auto cuts = DeviceSketch(0, dmat.get(), num_bins, batch_size); ValidateCuts(cuts, dmat.get(), num_bins); } + + num_rows = 1000; + size_t batches = 16; + auto x = GenerateRandom(num_rows * batches, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows * batches, num_columns); + auto cuts_with_batches = DeviceSketch(0, dmat.get(), num_bins, num_rows); + auto cuts = DeviceSketch(0, dmat.get(), num_bins, 0); + + auto const& cut_values_batched = cuts_with_batches.Values(); + auto const& cut_values = cuts.Values(); + CHECK_EQ(cut_values.size(), cut_values_batched.size()); + for (size_t i = 0; i < cut_values.size(); ++i) { + ASSERT_NEAR(cut_values_batched[i], cut_values[i], 1e5); + } } TEST(HistUtil, DeviceSketchMultipleColumnsExternal) { @@ -210,6 +224,19 @@ TEST(HistUtil, DeviceSketchMultipleColumnsExternal) { } } +template +void ValidateBatchedCuts(Adapter adapter, int num_bins, int num_columns, int num_rows, + DMatrix* dmat) { + common::HistogramCuts batched_cuts; + SketchContainer sketch_container(num_bins, num_columns, num_rows); + AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits::quiet_NaN(), + 0, &sketch_container); + common::DenseCuts dense_cuts(&batched_cuts); + dense_cuts.Init(&sketch_container.sketches_, num_bins, num_rows); + ValidateCuts(batched_cuts, dmat, num_bins); +} + + TEST(HistUtil, AdapterDeviceSketch) { int rows = 5; int cols = 1; @@ -284,6 +311,7 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) { auto cuts = AdapterDeviceSketch(&adapter, num_bins, std::numeric_limits::quiet_NaN()); ValidateCuts(cuts, dmat.get(), num_bins); + ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get()); } } } @@ -302,6 +330,7 @@ TEST(HistUtil, AdapterDeviceSketchBatches) { std::numeric_limits::quiet_NaN(), batch_size); ValidateCuts(cuts, dmat.get(), num_bins); + ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get()); } } @@ -323,6 +352,8 @@ TEST(HistUtil, SketchingEquivalent) { EXPECT_EQ(dmat_cuts.Values(), adapter_cuts.Values()); EXPECT_EQ(dmat_cuts.Ptrs(), adapter_cuts.Ptrs()); EXPECT_EQ(dmat_cuts.MinValues(), adapter_cuts.MinValues()); + + ValidateBatchedCuts(adapter, num_bins, num_columns, num_rows, dmat.get()); } } } @@ -330,7 +361,7 @@ TEST(HistUtil, SketchingEquivalent) { TEST(HistUtil, DeviceSketchFromGroupWeights) { size_t constexpr kRows = 3000, kCols = 200, kBins = 256; size_t constexpr kGroups = 10; - auto m = RandomDataGenerator {kRows, kCols, 0}.GenerateDMatrix(); + auto m = RandomDataGenerator{kRows, kCols, 0}.GenerateDMatrix(); auto& h_weights = m->Info().weights_.HostVector(); h_weights.resize(kRows); std::fill(h_weights.begin(), h_weights.end(), 1.0f); @@ -357,6 +388,71 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) { for (size_t i = 0; i < cuts.Ptrs().size(); ++i) { ASSERT_EQ(cuts.Ptrs().at(i), weighted_cuts.Ptrs().at(i)); } + ValidateCuts(weighted_cuts, m.get(), kBins); +} + +void TestAdapterSketchFromWeights(bool with_group) { + size_t constexpr kRows = 300, kCols = 20, kBins = 256; + size_t constexpr kGroups = 10; + HostDeviceVector storage; + std::string m = + RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface( + &storage); + MetaInfo info; + auto& h_weights = info.weights_.HostVector(); + h_weights.resize(kRows); + std::fill(h_weights.begin(), h_weights.end(), 1.0f); + + std::vector groups(kGroups); + if (with_group) { + for (size_t i = 0; i < kGroups; ++i) { + groups[i] = kRows / kGroups; + } + info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + } + + info.weights_.SetDevice(0); + info.num_row_ = kRows; + info.num_col_ = kCols; + + data::CupyAdapter adapter(m); + auto const& batch = adapter.Value(); + SketchContainer sketch_container(kBins, kCols, kRows); + AdapterDeviceSketchWeighted(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), + 0, + &sketch_container); + common::HistogramCuts cuts; + common::DenseCuts dense_cuts(&cuts); + dense_cuts.Init(&sketch_container.sketches_, kBins, kRows); + + auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols); + if (with_group) { + dmat->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + } + + dmat->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size()); + dmat->Info().num_col_ = kCols; + dmat->Info().num_row_ = kRows; + ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); + ValidateCuts(cuts, dmat.get(), kBins); + + if (with_group) { + HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0); + for (size_t i = 0; i < cuts.Values().size(); ++i) { + EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]); + } + for (size_t i = 0; i < cuts.MinValues().size(); ++i) { + ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]); + } + for (size_t i = 0; i < cuts.Ptrs().size(); ++i) { + ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i)); + } + } +} + +TEST(HistUtil, AdapterSketchFromWeights) { + TestAdapterSketchFromWeights(false); + TestAdapterSketchFromWeights(true); } } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index ec55f89d7a76..08405e9f5a40 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -151,7 +151,8 @@ inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, size_t num_bins) { // Check the endpoints are correct - EXPECT_LT(cuts.MinValues()[column_idx], sorted_column.front()); + CHECK_GT(sorted_column.size(), 0); + EXPECT_LT(cuts.MinValues().at(column_idx), sorted_column.front()); EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); EXPECT_GE(cuts.Values()[cuts.Ptrs()[column_idx+1]-1], sorted_column.back()); @@ -189,6 +190,7 @@ inline void ValidateCuts(const HistogramCuts& cuts, DMatrix* dmat, // Collect data into columns std::vector> columns(dmat->Info().num_col_); for (auto& batch : dmat->GetBatches()) { + CHECK_GT(batch.Size(), 0); for (auto i = 0ull; i < batch.Size(); i++) { for (auto e : batch[i]) { columns[e.index].push_back(e.fvalue); From c1ee3980003c064e7c2b8f39a67d18447a2fee3f Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 9 Jun 2020 09:41:54 +0800 Subject: [PATCH 2/8] Fix rebase. --- src/data/device_adapter.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 513c42db40f1..5f6a3b6cc732 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -52,7 +52,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { return {row_idx, column_idx, value}; } - XGBOOST_DEVICE bst_row_t NumRows() const { return num_elements_ / columns_.size(); } + XGBOOST_DEVICE bst_row_t NumRows() const { return num_rows_; } XGBOOST_DEVICE bst_row_t NumCols() const { return columns_.size(); } private: From d327f60d190e5e17c15c3ded22660cfe3aeb0bc0 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 9 Jun 2020 09:43:31 +0800 Subject: [PATCH 3/8] Rename. --- src/common/hist_util.cuh | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 39c66ffaa8b8..e18cfae41cf7 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -202,9 +202,9 @@ void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, } template -void ProcessBatchSlidingWindow(AdapterBatch const& batch, int device, size_t columns, - size_t begin, size_t end, float missing, - SketchContainer* sketch_container, int num_cuts) { +void ProcessSlidingWindow(AdapterBatch const& batch, int device, size_t columns, + size_t begin, size_t end, float missing, + SketchContainer* sketch_container, int num_cuts) { // Copy current subset of valid elements into temporary storage and sort dh::caching_device_vector sorted_entries; dh::caching_device_vector column_sizes_scan; @@ -344,8 +344,8 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); auto const& batch = adapter->Value(); - ProcessBatchSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(), - begin, end, missing, &sketch_container, num_cuts); + ProcessSlidingWindow(batch, adapter->DeviceIdx(), adapter->NumColumns(), + begin, end, missing, &sketch_container, num_cuts); } dense_cuts.Init(&sketch_container.sketches_, num_bins, adapter->NumRows()); @@ -365,8 +365,8 @@ void AdapterDeviceSketch(Batch batch, int num_bins, num_cols, device, num_cuts); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); - ProcessBatchSlidingWindow(batch, device, num_cols, - begin, end, missing, sketch_container, num_cuts); + ProcessSlidingWindow(batch, device, num_cols, + begin, end, missing, sketch_container, num_cuts); } } From fbacbc1bacb2cbe4a020b8a286eba7cee17ff66d Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 9 Jun 2020 19:02:47 +0800 Subject: [PATCH 4/8] Add tests. --- tests/cpp/common/test_hist_util.cu | 53 ++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 36c18f00a9b5..5bec5b2db5c8 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -281,6 +281,59 @@ TEST(HistUtil, AdapterDeviceSketchMemory) { bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); } +TEST(HistUtil, AdapterSketchBatchMemory) { + int num_columns = 100; + int num_rows = 1000; + int num_bins = 256; + auto x = GenerateRandom(num_rows, num_columns); + auto x_device = thrust::device_vector(x); + auto adapter = AdapterFromData(x_device, num_rows, num_columns); + + dh::GlobalMemoryLogger().Clear(); + ConsoleLogger::Configure({{"verbosity", "3"}}); + common::HistogramCuts batched_cuts; + SketchContainer sketch_container(num_bins, num_columns, num_rows); + AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits::quiet_NaN(), + 0, &sketch_container); + ConsoleLogger::Configure({{"verbosity", "0"}}); + size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry); + size_t bytes_num_columns = (num_columns + 1) * sizeof(size_t); + size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * + sizeof(DenseCuts::WQSketch::Entry); + size_t bytes_constant = 1000; + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), + bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); +} + +TEST(HistUtil, AdapterSketchBatchWeightedMemory) { + int num_columns = 100; + int num_rows = 1000; + int num_bins = 256; + auto x = GenerateRandom(num_rows, num_columns); + auto x_device = thrust::device_vector(x); + auto adapter = AdapterFromData(x_device, num_rows, num_columns); + MetaInfo info; + auto& h_weights = info.weights_.HostVector(); + h_weights.resize(num_rows); + std::fill(h_weights.begin(), h_weights.end(), 1.0f); + + dh::GlobalMemoryLogger().Clear(); + ConsoleLogger::Configure({{"verbosity", "3"}}); + common::HistogramCuts batched_cuts; + SketchContainer sketch_container(num_bins, num_columns, num_rows); + AdapterDeviceSketchWeighted(adapter.Value(), num_bins, info, + std::numeric_limits::quiet_NaN(), 0, + &sketch_container); + ConsoleLogger::Configure({{"verbosity", "0"}}); + + size_t bytes_num_elements = + num_rows * num_columns * (sizeof(Entry) + sizeof(float)); + size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * + sizeof(DenseCuts::WQSketch::Entry); + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), + size_t((bytes_num_elements + bytes_cuts) * 1.05)); +} + TEST(HistUtil, AdapterDeviceSketchCategorical) { int categorical_sizes[] = {2, 6, 8, 12}; int num_bins = 256; From 427f96404281bb5d5a7c54ac1f6ac80d94a88d95 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 11 Jun 2020 13:15:39 +0800 Subject: [PATCH 5/8] Weighted calculation. --- src/common/hist_util.cuh | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index e18cfae41cf7..719257752a1b 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -136,12 +136,11 @@ void GetColumnSizesScan(int device, size_t num_columns, column_sizes_scan->end(), column_sizes_scan->begin()); } -template -size_t SketchBatchNumElements(AdapterBatch batch, size_t sketch_batch_num_elements, - size_t columns, int device, - size_t num_cuts) { +inline size_t SketchBatchNumElements(size_t sketch_batch_num_elements, + size_t columns, int device, + size_t num_cuts, bool has_weight) { if (sketch_batch_num_elements == 0) { - int bytes_per_element = 16; + size_t bytes_per_element = has_weight ? 24 : 16; size_t bytes_cuts = num_cuts * columns * sizeof(SketchEntry); size_t bytes_num_columns = (columns + 1) * sizeof(size_t); // use up to 80% of available space @@ -329,8 +328,8 @@ HistogramCuts AdapterDeviceSketch(AdapterT* adapter, int num_bins, adapter->Next(); auto& batch = adapter->Value(); sketch_batch_num_elements = SketchBatchNumElements( - batch, sketch_batch_num_elements, - adapter->NumColumns(), adapter->DeviceIdx(), num_cuts); + sketch_batch_num_elements, + adapter->NumColumns(), adapter->DeviceIdx(), num_cuts, false); // Enforce single batch CHECK(!adapter->Next()); @@ -361,8 +360,8 @@ void AdapterDeviceSketch(Batch batch, int num_bins, size_t num_cols = batch.NumCols(); size_t num_cuts = RequiredSampleCuts(num_bins, num_rows); sketch_batch_num_elements = SketchBatchNumElements( - batch, sketch_batch_num_elements, - num_cols, device, num_cuts); + sketch_batch_num_elements, + num_cols, device, num_cuts, false); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); ProcessSlidingWindow(batch, device, num_cols, @@ -381,8 +380,8 @@ void AdapterDeviceSketchWeighted(Batch batch, int num_bins, size_t num_cols = batch.NumCols(); size_t num_cuts = RequiredSampleCuts(num_bins, num_rows); sketch_batch_num_elements = SketchBatchNumElements( - batch, sketch_batch_num_elements, - num_cols, device, num_cuts); + sketch_batch_num_elements, + num_cols, device, num_cuts, true); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); ProcessWeightedSlidingWindow(batch, info, From 1f96d6b419164dc57519a795037428e9ac2f4361 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 11 Jun 2020 14:13:56 +0800 Subject: [PATCH 6/8] Memory bound tests for sketching. --- src/common/hist_util.cu | 10 +++------- src/common/hist_util.cuh | 3 ++- tests/cpp/common/test_hist_util.cu | 15 ++++++++++++--- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index c3a5e5f60bf5..3da728f677da 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -227,13 +227,9 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, // Configure batch size based on available memory bool has_weights = dmat->Info().weights_.Size() > 0; size_t num_cuts_per_feature = RequiredSampleCuts(max_bins, dmat->Info().num_row_); - if (sketch_batch_num_elements == 0) { - int bytes_per_element = has_weights ? 24 : 16; - size_t bytes_cuts = num_cuts_per_feature * dmat->Info().num_col_ * sizeof(SketchEntry); - // use up to 80% of available space - sketch_batch_num_elements = - (dh::AvailableMemory(device) - bytes_cuts) * 0.8 / bytes_per_element; - } + sketch_batch_num_elements = SketchBatchNumElements( + sketch_batch_num_elements, + dmat->Info().num_col_, device, num_cuts_per_feature, has_weights); HistogramCuts cuts; DenseCuts dense_cuts(&cuts); diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 719257752a1b..202fdfb2fd0a 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -140,7 +140,8 @@ inline size_t SketchBatchNumElements(size_t sketch_batch_num_elements, size_t columns, int device, size_t num_cuts, bool has_weight) { if (sketch_batch_num_elements == 0) { - size_t bytes_per_element = has_weight ? 24 : 16; + // Double the memory usage for sorting. + size_t bytes_per_element = (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2; size_t bytes_cuts = num_cuts * columns * sizeof(SketchEntry); size_t bytes_num_columns = (columns + 1) * sizeof(size_t); // use up to 80% of available space diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 5bec5b2db5c8..7cb596f4edaf 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -71,12 +71,14 @@ TEST(HistUtil, DeviceSketchMemory) { auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_num_elements = num_rows * num_columns*sizeof(Entry); + size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry); size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * sizeof(DenseCuts::WQSketch::Entry); size_t bytes_constant = 1000; EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_num_elements + bytes_cuts + bytes_constant); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), + bytes_num_elements + bytes_cuts); } TEST(HistUtil, DeviceSketchMemoryWeights) { @@ -98,6 +100,8 @@ TEST(HistUtil, DeviceSketchMemoryWeights) { sizeof(DenseCuts::WQSketch::Entry); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), size_t((bytes_num_elements + bytes_cuts) * 1.05)); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), + size_t((bytes_num_elements + bytes_cuts))); } TEST(HistUtil, DeviceSketchDeterminism) { @@ -278,7 +282,9 @@ TEST(HistUtil, AdapterDeviceSketchMemory) { sizeof(DenseCuts::WQSketch::Entry); size_t bytes_constant = 1000; EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), - bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); + bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), + bytes_num_elements + bytes_cuts + bytes_num_columns); } TEST(HistUtil, AdapterSketchBatchMemory) { @@ -302,7 +308,9 @@ TEST(HistUtil, AdapterSketchBatchMemory) { sizeof(DenseCuts::WQSketch::Entry); size_t bytes_constant = 1000; EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), - bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); + bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), + bytes_num_elements + bytes_cuts + bytes_num_columns); } TEST(HistUtil, AdapterSketchBatchWeightedMemory) { @@ -332,6 +340,7 @@ TEST(HistUtil, AdapterSketchBatchWeightedMemory) { sizeof(DenseCuts::WQSketch::Entry); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), size_t((bytes_num_elements + bytes_cuts) * 1.05)); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), (bytes_num_elements + bytes_cuts)); } TEST(HistUtil, AdapterDeviceSketchCategorical) { From f6ec7bcb94823e57e84369dfac683dc17f79d0b5 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 11 Jun 2020 14:18:17 +0800 Subject: [PATCH 7/8] Remove duplicated kfactor. --- src/common/hist_util.cu | 3 +++ src/common/hist_util.cuh | 4 ++-- tests/cpp/common/test_hist_util.cu | 3 +-- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 3da728f677da..8c8e92aee721 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -28,6 +28,9 @@ namespace xgboost { namespace common { + +constexpr float SketchContainer::kFactor; + // Count the entries in each column and exclusive scan void ExtractCuts(int device, size_t num_cuts_per_feature, diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 202fdfb2fd0a..1a0338cb02b3 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -27,6 +27,7 @@ using SketchEntry = WQSketch::Entry; struct SketchContainer { std::vector sketches_; // NOLINT static constexpr int kOmpNumColsParallelizeLimit = 1000; + static constexpr float kFactor = 8; SketchContainer(int max_bin, size_t num_columns, size_t num_rows) { // Initialize Sketches for this dmatrix @@ -156,8 +157,7 @@ inline size_t SketchBatchNumElements(size_t sketch_batch_num_elements, // Compute number of sample cuts needed on local node to maintain accuracy // We take more cuts than needed and then reduce them later inline size_t RequiredSampleCuts(int max_bins, size_t num_rows) { - constexpr int kFactor = 8; - double eps = 1.0 / (kFactor * max_bins); + double eps = 1.0 / (SketchContainer::kFactor * max_bins); size_t dummy_nlevel; size_t num_cuts; WQuantileSketch::LimitSizeLevel( diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 7cb596f4edaf..93a5c51601bf 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -50,8 +50,7 @@ TEST(HistUtil, DeviceSketch) { // Duplicate this function from hist_util.cu so we don't have to expose it in // header size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) { - constexpr int kFactor = 8; - double eps = 1.0 / (kFactor * max_bins); + double eps = 1.0 / (SketchContainer::kFactor * max_bins); size_t dummy_nlevel; size_t num_cuts; WQuantileSketch::LimitSizeLevel( From 361327c1a216401bbdf9e498ef3b16b1c2db88dd Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 11 Jun 2020 15:10:19 +0800 Subject: [PATCH 8/8] Correct tests. --- src/common/hist_util.cu | 4 +- src/common/hist_util.cuh | 9 ++++- tests/cpp/common/test_hist_util.cu | 64 +++++++++++------------------- 3 files changed, 33 insertions(+), 44 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 8c8e92aee721..f7744c7884cb 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -243,12 +243,12 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, for (const auto& batch : dmat->GetBatches()) { size_t batch_nnz = batch.data.Size(); auto const& info = dmat->Info(); - dh::caching_device_vector groups(info.group_ptr_.cbegin(), - info.group_ptr_.cend()); for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) { size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements)); if (has_weights) { bool is_ranking = CutsBuilder::UseGroup(dmat); + dh::caching_device_vector groups(info.group_ptr_.cbegin(), + info.group_ptr_.cend()); ProcessWeightedBatch( device, batch, dmat->Info().weights_.ConstDeviceSpan(), begin, end, &sketch_container, diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 1a0338cb02b3..6f8d1e52206d 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -137,12 +137,17 @@ void GetColumnSizesScan(int device, size_t num_columns, column_sizes_scan->end(), column_sizes_scan->begin()); } +inline size_t BytesPerElement(bool has_weight) { + // Double the memory usage for sorting. We need to assign weight for each element, so + // sizeof(float) is added to all elements. + return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2; +} + inline size_t SketchBatchNumElements(size_t sketch_batch_num_elements, size_t columns, int device, size_t num_cuts, bool has_weight) { if (sketch_batch_num_elements == 0) { - // Double the memory usage for sorting. - size_t bytes_per_element = (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2; + size_t bytes_per_element = BytesPerElement(has_weight); size_t bytes_cuts = num_cuts * columns * sizeof(SketchEntry); size_t bytes_num_columns = (columns + 1) * sizeof(size_t); // use up to 80% of available space diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 93a5c51601bf..14ed78ddaea3 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -58,6 +58,15 @@ size_t RequiredSampleCutsTest(int max_bins, size_t num_rows) { return std::min(num_cuts, num_rows); } +size_t BytesRequiredForTest(size_t num_rows, size_t num_columns, size_t num_bins, + bool with_weights) { + size_t bytes_num_elements = BytesPerElement(with_weights) * num_rows * num_columns; + size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * + sizeof(DenseCuts::WQSketch::Entry); + // divide by 2 is because the memory quota used in sorting is reused for storing cuts. + return bytes_num_elements / 2 + bytes_cuts; +} + TEST(HistUtil, DeviceSketchMemory) { int num_columns = 100; int num_rows = 1000; @@ -70,14 +79,10 @@ TEST(HistUtil, DeviceSketchMemory) { auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry); - size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * - sizeof(DenseCuts::WQSketch::Entry); + size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false); size_t bytes_constant = 1000; - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), - bytes_num_elements + bytes_cuts + bytes_constant); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), - bytes_num_elements + bytes_cuts); + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); } TEST(HistUtil, DeviceSketchMemoryWeights) { @@ -93,14 +98,9 @@ TEST(HistUtil, DeviceSketchMemoryWeights) { auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_num_elements = - num_rows * num_columns * (sizeof(Entry) + sizeof(float)); - size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * - sizeof(DenseCuts::WQSketch::Entry); - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), - size_t((bytes_num_elements + bytes_cuts) * 1.05)); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), - size_t((bytes_num_elements + bytes_cuts))); + size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true); + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); } TEST(HistUtil, DeviceSketchDeterminism) { @@ -274,16 +274,10 @@ TEST(HistUtil, AdapterDeviceSketchMemory) { auto cuts = AdapterDeviceSketch(&adapter, num_bins, std::numeric_limits::quiet_NaN()); ConsoleLogger::Configure({{"verbosity", "0"}}); - - size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry); - size_t bytes_num_columns = (num_columns + 1) * sizeof(size_t); - size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * - sizeof(DenseCuts::WQSketch::Entry); size_t bytes_constant = 1000; - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), - bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), - bytes_num_elements + bytes_cuts + bytes_num_columns); + size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false); + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); } TEST(HistUtil, AdapterSketchBatchMemory) { @@ -301,15 +295,10 @@ TEST(HistUtil, AdapterSketchBatchMemory) { AdapterDeviceSketch(adapter.Value(), num_bins, std::numeric_limits::quiet_NaN(), 0, &sketch_container); ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_num_elements = num_rows * num_columns * sizeof(Entry); - size_t bytes_num_columns = (num_columns + 1) * sizeof(size_t); - size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * - sizeof(DenseCuts::WQSketch::Entry); size_t bytes_constant = 1000; - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), - bytes_num_elements + bytes_cuts + bytes_num_columns + bytes_constant); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), - bytes_num_elements + bytes_cuts + bytes_num_columns); + size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, false); + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required + bytes_constant); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); } TEST(HistUtil, AdapterSketchBatchWeightedMemory) { @@ -332,14 +321,9 @@ TEST(HistUtil, AdapterSketchBatchWeightedMemory) { std::numeric_limits::quiet_NaN(), 0, &sketch_container); ConsoleLogger::Configure({{"verbosity", "0"}}); - - size_t bytes_num_elements = - num_rows * num_columns * (sizeof(Entry) + sizeof(float)); - size_t bytes_cuts = RequiredSampleCutsTest(num_bins, num_rows) * num_columns * - sizeof(DenseCuts::WQSketch::Entry); - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), - size_t((bytes_num_elements + bytes_cuts) * 1.05)); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), (bytes_num_elements + bytes_cuts)); + size_t bytes_required = BytesRequiredForTest(num_rows, num_columns, num_bins, true); + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); } TEST(HistUtil, AdapterDeviceSketchCategorical) {