Skip to content

Commit

Permalink
Purge device_helpers.cuh (#5534)
Browse files Browse the repository at this point in the history
* Simplifications with caching_device_vector

* Purge device helpers
  • Loading branch information
RAMitchell authored Apr 15, 2020
1 parent a2f5496 commit ca4e056
Show file tree
Hide file tree
Showing 9 changed files with 185 additions and 736 deletions.
557 changes: 44 additions & 513 deletions src/common/device_helpers.cuh

Large diffs are not rendered by default.

24 changes: 11 additions & 13 deletions src/linear/updater_gpu_coordinate.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,14 +86,13 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
std::make_pair(column_begin - col.cbegin(), column_end - col.cbegin()));
row_ptr_.push_back(row_ptr_.back() + (column_end - column_begin));
}
ba_.Allocate(learner_param_->gpu_id, &data_, row_ptr_.back(), &gpair_,
num_row_ * model_param.num_output_group);

data_.resize(row_ptr_.back());
gpair_.resize(num_row_ * model_param.num_output_group);
for (size_t fidx = 0; fidx < batch.Size(); fidx++) {
auto col = batch[fidx];
auto seg = column_segments[fidx];
dh::safe_cuda(cudaMemcpy(
data_.subspan(row_ptr_[fidx]).data(),
data_.data().get() + row_ptr_[fidx],
col.data() + seg.first,
sizeof(Entry) * (seg.second - seg.first), cudaMemcpyHostToDevice));
}
Expand Down Expand Up @@ -192,7 +191,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
// This needs to be public because of the __device__ lambda.
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
if (dbias == 0.0f) return;
auto d_gpair = gpair_;
auto d_gpair = dh::ToSpan(gpair_);
dh::LaunchN(learner_param_->gpu_id, num_row_, [=] __device__(size_t idx) {
auto &g = d_gpair[idx * num_groups + group_idx];
g += GradientPair(g.GetHess() * dbias, 0);
Expand All @@ -202,9 +201,9 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
// This needs to be public because of the __device__ lambda.
GradientPair GetGradient(int group_idx, int num_group, int fidx) {
dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
common::Span<xgboost::Entry> d_col = data_.subspan(row_ptr_[fidx]);
common::Span<xgboost::Entry> d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
common::Span<GradientPair> d_gpair = gpair_;
common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
auto counting = thrust::make_counting_iterator(0ull);
auto f = [=] __device__(size_t idx) {
auto entry = d_col[idx];
Expand All @@ -219,8 +218,8 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT

// This needs to be public because of the __device__ lambda.
void UpdateResidual(float dw, int group_idx, int num_groups, int fidx) {
common::Span<GradientPair> d_gpair = gpair_;
common::Span<Entry> d_col = data_.subspan(row_ptr_[fidx]);
common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
common::Span<Entry> d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
dh::LaunchN(learner_param_->gpu_id, col_size, [=] __device__(size_t idx) {
auto entry = d_col[idx];
Expand All @@ -236,7 +235,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT

void UpdateGpair(const std::vector<GradientPair> &host_gpair) {
dh::safe_cuda(cudaMemcpyAsync(
gpair_.data(),
gpair_.data().get(),
host_gpair.data(),
gpair_.size() * sizeof(GradientPair), cudaMemcpyHostToDevice));
}
Expand All @@ -247,10 +246,9 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
std::unique_ptr<FeatureSelector> selector_;
common::Monitor monitor_;

dh::BulkAllocator ba_;
std::vector<size_t> row_ptr_;
common::Span<xgboost::Entry> data_;
common::Span<GradientPair> gpair_;
dh::device_vector<xgboost::Entry> data_;
dh::caching_device_vector<GradientPair> gpair_;
dh::CubMemory temp_;
size_t num_row_;
};
Expand Down
84 changes: 43 additions & 41 deletions src/tree/gpu_hist/gradient_based_sampler.cu
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,10 @@ ExternalMemoryUniformSampling::ExternalMemoryUniformSampling(EllpackPageImpl* pa
size_t n_rows,
const BatchParam& batch_param,
float subsample)
: original_page_(page), batch_param_(batch_param), subsample_(subsample) {
ba_.Allocate(batch_param_.gpu_id, &sample_row_index_, n_rows);
}
: original_page_(page),
batch_param_(batch_param),
subsample_(subsample),
sample_row_index_(n_rows) {}

GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
Expand All @@ -207,12 +208,12 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());

// Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
dh::tbegin(sample_row_index_));
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero());
thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(),
sample_row_index_.begin());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(sample_row_index_),
dh::tbegin(sample_row_index_),
sample_row_index_.begin(),
sample_row_index_.begin(),
ClearEmptyRows());

// Create a new ELLPACK page with empty rows.
Expand All @@ -224,7 +225,7 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
}

return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
Expand All @@ -233,23 +234,23 @@ GradientBasedSample ExternalMemoryUniformSampling::Sample(common::Span<GradientP
GradientBasedSampling::GradientBasedSampling(EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample) : page_(page), subsample_(subsample) {
ba_.Allocate(batch_param.gpu_id,
&threshold_, n_rows + 1,
&grad_sum_, n_rows);
}
float subsample)
: page_(page),
subsample_(subsample),
threshold_(n_rows + 1, 0.0f),
grad_sum_(n_rows, 0.0f) {}

GradientBasedSample GradientBasedSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
size_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, threshold_, grad_sum_, n_rows * subsample_);
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);

// Perform Poisson sampling in place.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
dh::tbegin(gpair),
PoissonSampling(threshold_,
PoissonSampling(dh::ToSpan(threshold_),
threshold_index,
RandomWeight(common::GlobalRandom()())));
return {n_rows, page_, gpair};
Expand All @@ -259,24 +260,25 @@ ExternalMemoryGradientBasedSampling::ExternalMemoryGradientBasedSampling(
EllpackPageImpl* page,
size_t n_rows,
const BatchParam& batch_param,
float subsample) : original_page_(page), batch_param_(batch_param), subsample_(subsample) {
ba_.Allocate(batch_param.gpu_id,
&threshold_, n_rows + 1,
&grad_sum_, n_rows,
&sample_row_index_, n_rows);
}
float subsample)
: original_page_(page),
batch_param_(batch_param),
subsample_(subsample),
threshold_(n_rows + 1, 0.0f),
grad_sum_(n_rows, 0.0f),
sample_row_index_(n_rows) {}

GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<GradientPair> gpair,
DMatrix* dmat) {
size_t n_rows = dmat->Info().num_row_;
size_t threshold_index = GradientBasedSampler::CalculateThresholdIndex(
gpair, threshold_, grad_sum_, n_rows * subsample_);
gpair, dh::ToSpan(threshold_), dh::ToSpan(grad_sum_), n_rows * subsample_);

// Perform Poisson sampling in place.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
thrust::counting_iterator<size_t>(0),
dh::tbegin(gpair),
PoissonSampling(threshold_,
PoissonSampling(dh::ToSpan(threshold_),
threshold_index,
RandomWeight(common::GlobalRandom()())));

Expand All @@ -288,12 +290,12 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
thrust::copy_if(dh::tbegin(gpair), dh::tend(gpair), gpair_.begin(), IsNonZero());

// Index the sample rows.
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(sample_row_index_), IsNonZero());
thrust::exclusive_scan(dh::tbegin(sample_row_index_), dh::tend(sample_row_index_),
dh::tbegin(sample_row_index_));
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), sample_row_index_.begin(), IsNonZero());
thrust::exclusive_scan(sample_row_index_.begin(), sample_row_index_.end(),
sample_row_index_.begin());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(sample_row_index_),
dh::tbegin(sample_row_index_),
sample_row_index_.begin(),
sample_row_index_.begin(),
ClearEmptyRows());

// Create a new ELLPACK page with empty rows.
Expand All @@ -305,7 +307,7 @@ GradientBasedSample ExternalMemoryGradientBasedSampling::Sample(common::Span<Gra
// Compact the ELLPACK pages into the single sample page.
thrust::fill(dh::tbegin(page_->gidx_buffer), dh::tend(page_->gidx_buffer), 0);
for (auto& batch : dmat->GetBatches<EllpackPage>(batch_param_)) {
page_->Compact(batch_param_.gpu_id, batch.Impl(), sample_row_index_);
page_->Compact(batch_param_.gpu_id, batch.Impl(), dh::ToSpan(sample_row_index_));
}

return {sample_rows, page_.get(), dh::ToSpan(gpair_)};
Expand Down Expand Up @@ -358,21 +360,21 @@ GradientBasedSample GradientBasedSampler::Sample(common::Span<GradientPair> gpai
return sample;
}

size_t GradientBasedSampler::CalculateThresholdIndex(common::Span<GradientPair> gpair,
common::Span<float> threshold,
common::Span<float> grad_sum,
size_t sample_rows) {
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold), std::numeric_limits<float>::max());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair),
dh::tbegin(threshold),
size_t GradientBasedSampler::CalculateThresholdIndex(
common::Span<GradientPair> gpair, common::Span<float> threshold,
common::Span<float> grad_sum, size_t sample_rows) {
thrust::fill(dh::tend(threshold) - 1, dh::tend(threshold),
std::numeric_limits<float>::max());
thrust::transform(dh::tbegin(gpair), dh::tend(gpair), dh::tbegin(threshold),
CombineGradientPair());
thrust::sort(dh::tbegin(threshold), dh::tend(threshold) - 1);
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1, dh::tbegin(grad_sum));
thrust::inclusive_scan(dh::tbegin(threshold), dh::tend(threshold) - 1,
dh::tbegin(grad_sum));
thrust::transform(dh::tbegin(grad_sum), dh::tend(grad_sum),
thrust::counting_iterator<size_t>(0),
dh::tbegin(grad_sum),
thrust::counting_iterator<size_t>(0), dh::tbegin(grad_sum),
SampleRateDelta(threshold, gpair.size(), sample_rows));
thrust::device_ptr<float> min = thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
thrust::device_ptr<float> min =
thrust::min_element(dh::tbegin(grad_sum), dh::tend(grad_sum));
return thrust::distance(dh::tbegin(grad_sum), min) + 1;
}

Expand Down
15 changes: 6 additions & 9 deletions src/tree/gpu_hist/gradient_based_sampler.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -73,13 +73,12 @@ class ExternalMemoryUniformSampling : public SamplingStrategy {
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;

private:
dh::BulkAllocator ba_;
EllpackPageImpl* original_page_;
BatchParam batch_param_;
float subsample_;
std::unique_ptr<EllpackPageImpl> page_;
dh::device_vector<GradientPair> gpair_{};
common::Span<size_t> sample_row_index_;
dh::caching_device_vector<size_t> sample_row_index_;
};

/*! \brief Gradient-based sampling in in-memory mode. */
Expand All @@ -94,9 +93,8 @@ class GradientBasedSampling : public SamplingStrategy {
private:
EllpackPageImpl* page_;
float subsample_;
dh::BulkAllocator ba_;
common::Span<float> threshold_;
common::Span<float> grad_sum_;
dh::caching_device_vector<float> threshold_;
dh::caching_device_vector<float> grad_sum_;
};

/*! \brief Gradient-based sampling in external memory mode. */
Expand All @@ -109,15 +107,14 @@ class ExternalMemoryGradientBasedSampling : public SamplingStrategy {
GradientBasedSample Sample(common::Span<GradientPair> gpair, DMatrix* dmat) override;

private:
dh::BulkAllocator ba_;
EllpackPageImpl* original_page_;
BatchParam batch_param_;
float subsample_;
common::Span<float> threshold_;
common::Span<float> grad_sum_;
dh::caching_device_vector<float> threshold_;
dh::caching_device_vector<float> grad_sum_;
std::unique_ptr<EllpackPageImpl> page_;
dh::device_vector<GradientPair> gpair_;
common::Span<size_t> sample_row_index_;
dh::caching_device_vector<size_t> sample_row_index_;
};

/*! \brief Draw a sample of rows from a DMatrix.
Expand Down
Loading

0 comments on commit ca4e056

Please sign in to comment.