Skip to content

Commit

Permalink
add set_gpu_graph_mode;test=develop (PaddlePaddle#73)
Browse files Browse the repository at this point in the history
  • Loading branch information
danleifeng authored Jul 27, 2022
1 parent 416c558 commit bba158c
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 2 deletions.
12 changes: 12 additions & 0 deletions paddle/fluid/framework/data_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,18 @@ void DatasetImpl<T>::SetFeaEval(bool fea_eval, int record_candidate_size) {
<< " with record candidate size: " << record_candidate_size;
}

template <typename T>
void DatasetImpl<T>::SetGpuGraphMode(int is_graph_mode) {
gpu_graph_mode_ = is_graph_mode;
}


template <typename T>
int DatasetImpl<T>::GetGpuGraphMode() {
return gpu_graph_mode_;
}


template <typename T>
std::vector<paddle::framework::DataFeed*> DatasetImpl<T>::GetReaders() {
std::vector<paddle::framework::DataFeed*> ret;
Expand Down
7 changes: 6 additions & 1 deletion paddle/fluid/framework/data_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,9 @@ class Dataset {

virtual std::vector<std::string> GetSlots() = 0;

virtual void SetGpuGraphMode(int is_graph_mode) = 0;
virtual int GetGpuGraphMode() = 0;

protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
Expand Down Expand Up @@ -210,6 +213,8 @@ class DatasetImpl : public Dataset {
virtual std::pair<std::string, std::string> GetHdfsConfig() {
return std::make_pair(fs_name_, fs_ugi_);
}
virtual void SetGpuGraphMode(int is_graph_mode);
virtual int GetGpuGraphMode();
virtual std::string GetDownloadCmd();
virtual const paddle::framework::DataFeedDesc& GetDataFeedDesc() {
return data_feed_desc_;
Expand Down Expand Up @@ -331,7 +336,7 @@ class DatasetImpl : public Dataset {
std::vector<T> input_records_; // only for paddleboxdatafeed
std::vector<std::string> use_slots_;
bool enable_heterps_ = false;
int gpu_graph_mode_ = 1;
int gpu_graph_mode_ = 0;
// std::vector<std::vector<int64_t>> gpu_graph_device_keys_;
std::vector<std::vector<std::vector<uint64_t>>> graph_all_type_total_keys_;
std::vector<uint64_t> gpu_graph_total_keys_;
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/fleet/ps_gpu_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,7 @@ void PSGPUWrapper::LoadIntoMemory(bool is_shuffle) {
dataset_->LocalShuffle();
}
InitSlotInfo();
gpu_graph_mode_ = dataset_->GetGpuGraphMode();
std::shared_ptr<HeterContext> gpu_task = gpu_task_pool_.Get();
gpu_task->Reset();

Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/ps_gpu_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -613,7 +613,7 @@ class PSGPUWrapper {
int multi_node_{0};
int node_size_;
uint64_t table_id_;
int gpu_graph_mode_ = 1;
int gpu_graph_mode_ = 0;
#ifdef PADDLE_WITH_CUDA
std::vector<ncclComm_t> inner_comms_;
std::vector<ncclComm_t> inter_comms_;
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/data_set_py.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,9 @@ void BindDataset(py::module *m) {
&framework::Dataset::SetFleetSendSleepSeconds,
py::call_guard<py::gil_scoped_release>())
.def("enable_pv_merge", &framework::Dataset::EnablePvMerge,
py::call_guard<py::gil_scoped_release>())
.def("set_gpu_graph_mode",
&framework::Dataset::SetGpuGraphMode,
py::call_guard<py::gil_scoped_release>());

py::class_<IterableDatasetWrapper>(*m, "IterableDatasetWrapper")
Expand Down
1 change: 1 addition & 0 deletions python/paddle/fluid/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1064,6 +1064,7 @@ def set_graph_config(self, config):
self.proto_desc.graph_config.meta_path = config.get("meta_path", "")
self.proto_desc.graph_config.gpu_graph_training = config.get(
"gpu_graph_training", True)
self.dataset.set_gpu_graph_mode(True)


class QueueDataset(DatasetBase):
Expand Down

0 comments on commit bba158c

Please sign in to comment.