Using DistConfig in inference

TeslaZhao committed Mar 30, 2022
1 parent 4df7224 commit 4246afc

Showing 5 changed files with 107 additions and 124 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/distributed/ps/table/CMakeLists.txt
@@ -34,10 +34,9 @@ ${RPC_DEPS} graph_edge graph_node device_context string_helper
simple_threadpool xxhash generator ${EXTERN_DEP})

set_source_files_properties(tensor_accessor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
-set_source_files_properties(tensor_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

cc_library(tensor_accessor SRCS tensor_accessor.cc DEPS ${TABLE_DEPS} eigen3 ps_framework_proto device_context)
-cc_library(tensor_table SRCS tensor_table.cc DEPS eigen3 ps_framework_proto executor scope device_context tensor ${TABLE_DEPS})
+cc_library(tensor_table SRCS DEPS eigen3 ps_framework_proto executor scope device_context tensor ${TABLE_DEPS})
set_source_files_properties(table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})

set_source_files_properties(sparse_sgd_rule.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
@@ -54,6 +53,7 @@ cc_library(memory_sparse_table SRCS memory_sparse_table.cc DEPS ps_framework_pro

set_source_files_properties(memory_sparse_geo_table.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_library(memory_sparse_geo_table SRCS memory_sparse_geo_table.cc DEPS ps_framework_proto ${TABLE_DEPS} common_table)
+
cc_library(table SRCS table.cc DEPS memory_sparse_table memory_sparse_geo_table common_table tensor_accessor tensor_table ps_framework_proto string_helper device_context gflags glog boost)

target_link_libraries(table -fopenmp)
107 changes: 1 addition & 106 deletions paddle/fluid/distributed/ps/table/tensor_table.cc
@@ -16,110 +16,5 @@

DECLARE_double(eager_delete_tensor_gb);
namespace paddle {
-namespace distributed {
-
-int32_t TensorTable::set_program_env(
-    framework::Scope *scope, platform::Place place,
-    const std::vector<framework::ProgramDesc> *sub_program) {
-  scope_ = scope;
-  place_ = place;
-  executor_ = new framework::Executor(place_);
-  sub_program_ = sub_program;
-  return 0;
-}
-
-int32_t GlobalStepTable::initialize() {
-  auto _program_config = _config.tensor();
-  auto trainers_ = _config.common().trainer_num();
-  FLAGS_eager_delete_tensor_gb = -1;
-  // Get Config
-  if (_program_config.has_startup_program_id()) {
-    startup_program_id_ = _program_config.startup_program_id();
-  }
-  if (_program_config.has_main_program_id()) {
-    main_program_id_ = _program_config.main_program_id();
-  }
-  if (_program_config.has_feed_var_name()) {
-    feed_var_name_ = _program_config.feed_var_name();
-  }
-  if (_program_config.has_fetch_var_name()) {
-    fetch_var_name_ = _program_config.fetch_var_name();
-  }
-
-  // Run startup program
-  if (startup_program_id_ != -1) {
-    std::map<std::string, const framework::LoDTensor *> fake_feed;
-    std::map<std::string, framework::FetchType *> fake_fetch;
-    auto startup_program_desc = sub_program_->at(startup_program_id_);
-    auto ctx = executor_->Prepare(startup_program_desc, 0);
-    executor_->RunPreparedContext(ctx.get(), scope_, false);
-  }
-
-  if (main_program_id_ != -1) {
-    // Run main program, if program is used for learning decay
-    auto main_program_desc = sub_program_->at(main_program_id_);
-    auto main_ctx = executor_->Prepare(main_program_desc, 0);
-    exec_context_ = std::move(main_ctx);
-    executor_->RunPreparedContext(exec_context_.get(), scope_, false);
-    // init decay_counters
-    decay_counters_.reserve(trainers_);
-    for (int32_t i = 0; i < trainers_; ++i) {
-      decay_counters_[i] = 0;
-    }
-  }
-
-  return 0;
-}
-
-int32_t GlobalStepTable::set_table_map(
-    std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) {
-  auto *lr_var = scope_->FindVar(fetch_var_name_);
-  auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
-  auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
-  VLOG(3) << "GlobalStepTable::set_table_map set global lr: " << *lr_value;
-
-  for (auto iter = table_map->begin(); iter != table_map->end(); iter++) {
-    auto table_id = iter->first;
-    if (table_id == _config.table_id()) {
-      continue;
-    }
-    iter->second->set_global_lr(lr_value);
-  }
-  return 0;
-}
-
-int32_t GlobalStepTable::push_dense(const int64_t *values,
-                                    const int32_t trainer_id) {
-  return _run_program(values, trainer_id);
-}
-
-int32_t GlobalStepTable::_run_program(const int64_t *values,
-                                      const uint32_t trainer_id) {
-  FLAGS_eager_delete_tensor_gb = -1;
-  auto counter = decay_counters_.at(trainer_id);
-  counter += int(values[0]);
-  decay_counters_.at(trainer_id) = counter;
-
-  auto *global_step_var = scope_->FindVar(feed_var_name_);
-  auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
-  auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());
-
-  auto global_counter = 0;
-  for (auto &trainer_counter : decay_counters_) {
-    global_counter += trainer_counter.second;
-  }
-
-  // TODO: hard code for increment op
-  value[0] = global_counter - 1;
-  VLOG(3) << "GlobalStepTable::_run_program global_counter " << value[0];
-
-  executor_->RunPreparedContext(exec_context_.get(), scope_, false, false);
-  auto *lr_var = scope_->FindVar(fetch_var_name_);
-  auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
-  auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
-  VLOG(3) << "GlobalStepTable::LR value: " << lr_value[0];
-  return 0;
-}
-
-}  // namespace distributed
+namespace distributed {}  // namespace distributed
} // namespace paddle
117 changes: 102 additions & 15 deletions paddle/fluid/distributed/ps/table/tensor_table.h
@@ -37,6 +37,8 @@ struct ExecutorPrepareContext;
} // namespace framework
} // namespace paddle

+DECLARE_double(eager_delete_tensor_gb);
+
namespace paddle {
namespace distributed {

@@ -66,9 +68,9 @@ class TensorTable : public Table {

  virtual void *get_shard(size_t shard_idx) { return 0; }

-  virtual int32_t initialize_shard() { return 0; };
+  virtual int32_t initialize_shard() { return 0; }

-  virtual int32_t flush() { return 0; };
+  virtual int32_t flush() { return 0; }

  virtual int32_t load(const std::string &path, const std::string &param) {
    return 0;
@@ -77,18 +79,23 @@
    return 0;
  }

-  virtual void clear(){};
+  virtual void clear() {}

-  virtual int32_t initialize() override { return 0; };
+  int32_t initialize() override { return 0; }

-  virtual int32_t push_dense(const int64_t *values,
-                             const int32_t trainer_id) override {
+  int32_t push_dense(const int64_t *values, const int32_t trainer_id) override {
    return 0;
-  };
+  }

-  virtual int32_t set_program_env(
+  int32_t set_program_env(
      framework::Scope *scope, platform::Place place,
-      const std::vector<framework::ProgramDesc> *sub_program) override;
+      const std::vector<framework::ProgramDesc> *sub_program) override {
+    scope_ = scope;
+    place_ = place;
+    executor_ = new framework::Executor(place_);
+    sub_program_ = sub_program;
+    return 0;
+  }

 protected:
  framework::Executor *executor_;
@@ -135,7 +142,7 @@ class DenseTensorTable : public TensorTable {

/*----------------------------------------------------------------------*/

-  virtual int32_t initialize() override { return 0; }
+  int32_t initialize() override { return 0; }

  int32_t push_dense(const float *values, size_t num) override { return 0; }

@@ -189,18 +196,98 @@ class GlobalStepTable : public DenseTensorTable {

/*----------------------------------------------------------------------*/

-  int32_t initialize() override;
+  int32_t initialize() override {
+    auto _program_config = _config.tensor();
+    auto trainers_ = _config.common().trainer_num();
+    FLAGS_eager_delete_tensor_gb = -1;
+    // Get Config
+    if (_program_config.has_startup_program_id()) {
+      startup_program_id_ = _program_config.startup_program_id();
+    }
+    if (_program_config.has_main_program_id()) {
+      main_program_id_ = _program_config.main_program_id();
+    }
+    if (_program_config.has_feed_var_name()) {
+      feed_var_name_ = _program_config.feed_var_name();
+    }
+    if (_program_config.has_fetch_var_name()) {
+      fetch_var_name_ = _program_config.fetch_var_name();
+    }
+
+    // Run startup program
+    if (startup_program_id_ != -1) {
+      std::map<std::string, const framework::LoDTensor *> fake_feed;
+      std::map<std::string, framework::FetchType *> fake_fetch;
+      auto startup_program_desc = sub_program_->at(startup_program_id_);
+      auto ctx = executor_->Prepare(startup_program_desc, 0);
+      executor_->RunPreparedContext(ctx.get(), scope_, false);
+    }
+
+    if (main_program_id_ != -1) {
+      // Run main program, if program is used for learning decay
+      auto main_program_desc = sub_program_->at(main_program_id_);
+      auto main_ctx = executor_->Prepare(main_program_desc, 0);
+      exec_context_ = std::move(main_ctx);
+      executor_->RunPreparedContext(exec_context_.get(), scope_, false);
+      // init decay_counters
+      decay_counters_.reserve(trainers_);
+      for (int32_t i = 0; i < trainers_; ++i) {
+        decay_counters_[i] = 0;
+      }
+    }
+    return 0;
+  }

  int32_t push_dense(const float *values, size_t num) override { return 0; }

-  int32_t push_dense(const int64_t *values, const int32_t trainer_id);
+  int32_t push_dense(const int64_t *values, const int32_t trainer_id) {
+    return _run_program(values, trainer_id);
+  }

-  int32_t set_table_map(
-      std::unordered_map<uint32_t, std::shared_ptr<Table>> *table_map) override;
+  int32_t set_table_map(std::unordered_map<uint32_t, std::shared_ptr<Table>>
+                            *table_map) override {
+    auto *lr_var = scope_->FindVar(fetch_var_name_);
+    auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
+    auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
+    VLOG(3) << "GlobalStepTable::set_table_map set global lr: " << *lr_value;
+
+    for (auto iter = table_map->begin(); iter != table_map->end(); iter++) {
+      auto table_id = iter->first;
+      if (table_id == _config.table_id()) {
+        continue;
+      }
+      iter->second->set_global_lr(lr_value);
+    }
+    return 0;
+  }

 private:
  virtual int32_t _run_program(const int64_t *values,
-                               const uint32_t trainer_id);
+                               const uint32_t trainer_id) {
+    FLAGS_eager_delete_tensor_gb = -1;
+    auto counter = decay_counters_.at(trainer_id);
+    counter += int(values[0]);
+    decay_counters_.at(trainer_id) = counter;
+
+    auto *global_step_var = scope_->FindVar(feed_var_name_);
+    auto *tensor = global_step_var->GetMutable<framework::LoDTensor>();
+    auto *value = tensor->mutable_data<int64_t>(platform::CPUPlace());
+
+    auto global_counter = 0;
+    for (auto &trainer_counter : decay_counters_) {
+      global_counter += trainer_counter.second;
+    }
+
+    // TODO: hard code for increment op
+    value[0] = global_counter - 1;
+    VLOG(3) << "GlobalStepTable::_run_program global_counter " << value[0];
+
+    executor_->RunPreparedContext(exec_context_.get(), scope_, false, false);
+    auto *lr_var = scope_->FindVar(fetch_var_name_);
+    auto *lr_tensor = lr_var->GetMutable<framework::LoDTensor>();
+    auto *lr_value = lr_tensor->mutable_data<float>(platform::CPUPlace());
+    VLOG(3) << "GlobalStepTable::LR value: " << lr_value[0];
+    return 0;
+  }

 private:
  std::unordered_map<int, int64_t> decay_counters_;
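For orientation, the global-step bookkeeping that GlobalStepTable::_run_program now inlines above boils down to: each trainer pushes its step delta, the table sums every per-trainer counter, and the sum minus one is written to the feed variable so that the increment op in the prepared program adds the one back. A minimal standalone sketch of that arithmetic (the function name AggregateGlobalStep is ours for illustration, not part of this commit):

#include <cstdint>
#include <unordered_map>

// Sketch of the decay-counter arithmetic in GlobalStepTable::_run_program.
// decay_counters maps trainer_id -> accumulated steps for that trainer.
int64_t AggregateGlobalStep(std::unordered_map<int, int64_t> *decay_counters,
                            int trainer_id, int64_t pushed_steps) {
  // Accumulate the steps newly pushed by this trainer.
  (*decay_counters)[trainer_id] += pushed_steps;

  // The global step is the sum of all trainers' counters.
  int64_t global_counter = 0;
  for (const auto &kv : *decay_counters) {
    global_counter += kv.second;
  }

  // _run_program stores global_counter - 1 because the prepared main
  // program then runs an increment op that adds 1 back.
  return global_counter - 1;
}

For example, with two trainers that have pushed 3 and 4 steps in total, the value fed to the increment op is 6, and the op leaves the global step at 7.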
2 changes: 1 addition & 1 deletion paddle/fluid/inference/CMakeLists.txt
@@ -93,7 +93,7 @@ if (WITH_CRYPTO)
endif (WITH_CRYPTO)

if (WITH_PSCORE)
-  set(SHARED_INFERENCE_DEPS ${SHARED_INFERENCE_DEPS} fleet ps_service)
+  set(SHARED_INFERENCE_DEPS ${SHARED_INFERENCE_DEPS} fleet ps_service tensor_table)
endif ()

if (WITH_ONNXRUNTIME)
1 change: 1 addition & 0 deletions paddle/fluid/inference/api/paddle_inference_api.h
@@ -45,6 +45,7 @@ namespace paddle_infer {

using PrecisionType = paddle::AnalysisConfig::Precision;
using Config = paddle::AnalysisConfig;
+using DistConfig = paddle::DistConfig;

///
/// \class Predictor
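The paddle_infer::DistConfig alias added above is the piece that makes distributed settings reachable from the public inference API. Below is a sketch of how it might be wired into a predictor; the setter names (EnableDistModel, SetRanks, SetEndpoints, SetCommInfo) and Config::SetDistConfig follow the DistConfig declared in paddle_analysis_config.h of this period, and the model paths, endpoints, and comm-config file are placeholders, so verify everything against the headers you actually build with:

#include "paddle_inference_api.h"

#include <memory>
#include <string>
#include <vector>

int main() {
  // Distributed settings, now spelled paddle_infer::DistConfig.
  paddle_infer::DistConfig dist_config;
  dist_config.EnableDistModel(true);
  dist_config.SetRanks(/*nranks=*/2, /*rank=*/0);  // two ranks; this process is rank 0
  dist_config.SetEndpoints({"127.0.0.1:8200", "127.0.0.1:8201"},
                           /*current_endpoint=*/"127.0.0.1:8200");
  dist_config.SetCommInfo("./comm_init_config.csv");  // placeholder path

  // Ordinary single-process config, plus the distributed block.
  paddle_infer::Config config("./model/__model__", "./model/__params__");
  config.SetDistConfig(dist_config);

  std::shared_ptr<paddle_infer::Predictor> predictor =
      paddle_infer::CreatePredictor(config);
  // ... create input tensors, predictor->Run(), fetch outputs ...
  return 0;
}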

1 comment on commit 4246afc

@paddle-bot-old
Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉