Revert "Base remote/lxch pre stable (PaddlePaddle#30)" (PaddlePaddle#33)

This reverts commit e603334.
zmxdream authored Jun 30, 2022
1 parent e603334 commit 8b92330
Showing 30 changed files with 3,200 additions and 2,111 deletions.
1 change: 0 additions & 1 deletion .gitignore
@@ -66,4 +66,3 @@ paddle/infrt/tests/lit.cfg.py
paddle/infrt/kernel/phi/infershaped/infershaped_kernel_launchers.cc
paddle/fluid/pybind/eager_final_state_op_function_impl.h
paddle/fluid/pybind/tmp_eager_final_state_op_function_impl.h
builder
3 changes: 2 additions & 1 deletion cmake/external/pslib.cmake
@@ -49,7 +49,8 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS}
PREFIX ${PSLIB_PREFIX_DIR}
DOWNLOAD_DIR ${PSLIB_DOWNLOAD_DIR}
DOWNLOAD_COMMAND cp /root/paddlejob/new1_code/ps/baidu/paddlepaddle/pslib/pslib.tar.gz ./ && tar zxvf ${PSLIB_NAME}.tar.gz
DOWNLOAD_COMMAND wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_NAME}.tar.gz
&& tar zxvf ${PSLIB_NAME}.tar.gz
DOWNLOAD_NO_PROGRESS 1
UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${PSLIB_INSTALL_ROOT}
11 changes: 1 addition & 10 deletions paddle/fluid/framework/data_set.h
@@ -160,8 +160,7 @@ class Dataset {
virtual void SetFleetSendSleepSeconds(int seconds) = 0;

virtual std::vector<std::string> GetSlots() = 0;
virtual void SetPassId(uint32_t pass_id) = 0;
virtual uint32_t GetPassID() = 0;

protected:
virtual int ReceiveFromClient(int msg_type, int client_id,
const std::string& msg) = 0;
@@ -250,13 +249,6 @@ class DatasetImpl : public Dataset {
virtual void DynamicAdjustReadersNum(int thread_num);
virtual void SetFleetSendSleepSeconds(int seconds);
virtual std::vector<std::string> GetSlots();
virtual void SetPassId(uint32_t pass_id) {
pass_id_ = pass_id;
}
virtual uint32_t GetPassID() {
return pass_id_;
}

/* for enable_heterps_
virtual void EnableHeterps(bool enable_heterps) {
enable_heterps_ = enable_heterps;
@@ -283,7 +275,6 @@ class DatasetImpl : public Dataset {
// TODO(yaoxuefeng) for SlotRecordDataset
return -1;
}
uint32_t pass_id_ = 0;
std::vector<std::shared_ptr<paddle::framework::DataFeed>> readers_;
std::vector<std::shared_ptr<paddle::framework::DataFeed>> preload_readers_;
paddle::framework::Channel<T> input_channel_;
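For context (not part of the changed files themselves): the lines removed above are the pass-id accessors that PaddlePaddle#30 added to Dataset/DatasetImpl. A hypothetical caller, shown only to illustrate what stops compiling once this revert lands; the helper name and usage are assumptions, not code from the repository.

#include "paddle/fluid/framework/data_set.h"

// Hypothetical pre-revert usage; SetPassId/GetPassID are removed by this
// revert, so this helper no longer compiles against the restored header.
void TagPass(paddle::framework::Dataset* dataset, uint32_t pass_id) {
  dataset->SetPassId(pass_id);
  uint32_t current_pass = dataset->GetPassID();
  (void)current_pass;  // silence unused-variable warnings in this sketch
}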
4 changes: 2 additions & 2 deletions paddle/fluid/framework/fleet/fleet_wrapper.cc
@@ -1388,9 +1388,9 @@ void FleetWrapper::SetDate(const uint64_t table_id, const std::string& date) {
#endif
}

void FleetWrapper::PrintTableStat(const uint64_t table_id, uint32_t pass_id, size_t threshold) {
void FleetWrapper::PrintTableStat(const uint64_t table_id) {
#ifdef PADDLE_WITH_PSLIB
auto ret = pslib_ptr_->_worker_ptr->print_table_stat(table_id, pass_id, threshold);
auto ret = pslib_ptr_->_worker_ptr->print_table_stat(table_id);
ret.wait();
int32_t err_code = ret.get();
if (err_code == -1) {
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/fleet_wrapper.h
@@ -265,7 +265,7 @@ class FleetWrapper {
std::vector<std::string> table_var_list,
bool load_combine);

void PrintTableStat(const uint64_t table_id, uint32_t pass_id, uint64_t threshold);
void PrintTableStat(const uint64_t table_id);
void SetFileNumOneShard(const uint64_t table_id, int file_num);
// mode = 0, load all feature
// mode = 1, load delta feature, which means load diff
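For orientation (not part of the changed files): with this revert, PrintTableStat goes back to taking only the table id. A minimal sketch of a post-revert call site, assuming the FleetWrapper singleton accessor and an illustrative table id of 0:

#include "paddle/fluid/framework/fleet/fleet_wrapper.h"

// Hypothetical call site after the revert; the pass_id/threshold arguments
// introduced by PaddlePaddle#30 are gone again.
void DumpTableStat() {
  auto fleet = paddle::framework::FleetWrapper::GetInstance();
  const uint64_t table_id = 0;  // illustrative table id
  fleet->PrintTableStat(table_id);
}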
225 changes: 188 additions & 37 deletions paddle/fluid/framework/fleet/heter_context.h
@@ -39,76 +39,227 @@ namespace framework {

class HeterContext {
public:
// Deduplicated keys to be looked up in the table: level 1 maps to table shards, level 2 to the different embedding dimensions, level 3 holds the keys themselves
std::vector<std::vector<std::vector<FeatureKey>>>feature_keys_;
// Values fetched for those keys; same layout as feature_keys_
virtual ~HeterContext() {
if (!multi_mf_dim_) {
for (size_t i = 0; i < mutex_.size(); ++i) {
delete mutex_[i];
}
mutex_.clear();
} else {
for (size_t i = 0; i < dim_mutex_.size(); ++i) {
for (size_t j = 0; j < dim_mutex_[i].size(); j++) {
delete dim_mutex_[i][j];
}
dim_mutex_[i].clear();
}
}
}
Scope* scope_{nullptr};
std::vector<std::vector<FeatureKey>> feature_keys_;
std::vector<std::vector<std::vector<FeatureKey>>> feature_dim_keys_;
std::vector<std::vector<std::vector<FeatureKey>>> device_task_keys_;

#ifdef PADDLE_WITH_PSLIB
std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>> value_ptr_;
std::vector<std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>>>
device_task_ptr_;
std::vector<std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>>>
value_dim_ptr_;
std::vector<std::vector<std::vector<paddle::ps::DownpourFixedFeatureValue*>>>
value_ptr_;
device_dim_ptr_;
#endif
#ifdef PADDLE_WITH_PSCORE
std::vector<std::vector<paddle::distributed::FixedFeatureValue*>> value_ptr_;
std::vector<std::vector<std::vector<paddle::distributed::FixedFeatureValue*>>>
value_ptr_;
value_dim_ptr_;
std::vector<std::vector<std::vector<paddle::distributed::FixedFeatureValue*>>>
device_task_ptr_;
std::vector<std::vector<std::vector<paddle::distributed::FixedFeatureValue*>>>
device_dim_ptr_;
#endif
// Deduplicated keys held in the GPU table: level 1 is the device, level 2 the dimension, level 3 the concrete keys
std::vector<std::vector<std::vector<FeatureKey>>> device_keys_;
std::vector<std::vector<FeatureValue>> device_values_;
std::vector<std::vector<FeatureKey>> device_keys_;
std::vector<std::vector<std::vector<FeatureKey>>> device_dim_keys_;
std::vector<std::vector<std::vector<FeatureValue>>> device_dim_values_;
std::vector<std::mutex*> mutex_;
std::vector<std::vector<std::mutex*>> dim_mutex_;
int multi_mf_dim_ = 0;

uint32_t shard_num_ = 37;
uint64_t size() {
uint64_t total_size = 0;
for (auto& keys : feature_keys_) {
total_size += keys.size();
}
return total_size;
}
void SetShardNum(uint32_t shard_num) { shard_num_ = shard_num; }
uint32_t ShardNum() { return shard_num_; }
void init(int shard_num, int device_num) {
shard_num_ = shard_num;
feature_keys_.resize(shard_num_);
value_ptr_.resize(shard_num_);
device_task_ptr_.resize(shard_num_);
device_task_keys_.resize(shard_num_);
for (size_t i = 0; i < device_task_ptr_.size(); i++) {
device_task_ptr_[i].resize(device_num);
device_task_keys_[i].resize(device_num);
}

device_values_.resize(device_num);
device_keys_.resize(device_num);
mutex_.resize(device_num);
for (size_t i = 0; i < mutex_.size(); ++i) {
mutex_[i] = new std::mutex();
}
}

// Initialization
void init(int shard_num, int device_num, int dim_num) {
feature_keys_.resize(shard_num);
for (auto& iter : feature_keys_) {
iter.resize(dim_num);
for (auto& iter1: iter) {
iter1.clear();
}
shard_num_ = shard_num;
feature_keys_.resize(shard_num_);
feature_dim_keys_.resize(shard_num_);
value_ptr_.resize(shard_num_);
value_dim_ptr_.resize(shard_num_);
device_task_ptr_.resize(shard_num_);
device_task_keys_.resize(shard_num_);
for (size_t i = 0; i < device_task_ptr_.size(); i++) {
device_task_ptr_[i].resize(device_num);
device_task_keys_[i].resize(device_num);
}
value_ptr_.resize(shard_num);
for (auto& iter : value_ptr_) {
iter.resize(dim_num);
for (auto& iter1: iter) {
iter1.clear();
}
for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
feature_dim_keys_[i].resize(dim_num);
value_dim_ptr_[i].resize(dim_num);
}
device_values_.resize(device_num);
device_dim_values_.resize(device_num);
device_keys_.resize(device_num);
for (auto& iter : device_keys_) {
iter.resize(dim_num);
for (auto& iter1: iter) {
iter1.clear();

device_dim_keys_.resize(device_num);
device_dim_ptr_.resize(device_num);
mutex_.resize(device_num);
dim_mutex_.resize(device_num);
for (size_t i = 0; i < mutex_.size(); ++i) {
mutex_[i] = new std::mutex();
}
for (size_t i = 0; i < dim_mutex_.size(); ++i) {
dim_mutex_[i].resize(dim_num);
for (int j = 0; j < dim_num; j++) {
dim_mutex_[i][j] = new std::mutex();
}
}
multi_mf_dim_ = dim_num;
}

void Reset() {
if (!multi_mf_dim_) {
for (size_t i = 0; i < feature_keys_.size(); ++i) {
feature_keys_[i].clear();
}
for (size_t i = 0; i < value_ptr_.size(); ++i) {
value_ptr_[i].clear();
}
for (size_t i = 0; i < device_values_.size(); ++i) {
device_values_[i].clear();
}
for (size_t i = 0; i < device_keys_.size(); ++i) {
device_keys_[i].clear();
}
for (size_t i = 0; i < device_task_ptr_.size(); ++i) {
for (size_t j = 0; j < device_task_ptr_[i].size(); ++j) {
device_task_ptr_[i][j].clear();
device_task_keys_[i][j].clear();
}
}
} else {
VLOG(3) << "Reset gpu task with dynamic mf dimention";
for (size_t i = 0; i < feature_dim_keys_.size(); i++) {
for (size_t j = 0; j < feature_dim_keys_[i].size(); j++) {
feature_dim_keys_[i][j].clear();
}
}
for (size_t i = 0; i < value_dim_ptr_.size(); i++) {
for (size_t j = 0; j < value_dim_ptr_[i].size(); j++) {
value_dim_ptr_[i][j].clear();
}
}

for (size_t i = 0; i < device_dim_keys_.size(); i++) {
for (size_t j = 0; j < device_dim_keys_[i].size(); j++) {
device_dim_keys_[i][j].clear();
}
}
for (size_t i = 0; i < device_dim_ptr_.size(); i++) {
for (size_t j = 0; j < device_dim_ptr_[i].size(); j++) {
device_dim_ptr_[i][j].clear();
}
}
}
}
void batch_add_keys(
const std::vector<std::unordered_set<uint64_t>>& thread_keys) {
assert(thread_keys.size() == feature_keys_.size());

for (uint32_t i = 0; i < shard_num_; i++) {
int idx = 0;
idx = feature_keys_[i].size();
feature_keys_[i].resize(feature_keys_[i].size() + thread_keys[i].size());
std::copy(thread_keys[i].begin(), thread_keys[i].end(),
feature_keys_[i].begin() + idx);
}
}

void batch_add_keys(int shard_num,
const robin_hood::unordered_set<uint64_t>& shard_keys) {
int idx = feature_keys_[shard_num].size();
feature_keys_[shard_num].resize(feature_keys_[shard_num].size() +
shard_keys.size());
std::copy(shard_keys.begin(), shard_keys.end(),
feature_keys_[shard_num].begin() + idx);
}
// Add coarsely deduplicated keys here; fine-grained deduplication happens later

void batch_add_keys(int shard_num, int dim_id,
const robin_hood::unordered_set<uint64_t>& shard_keys) {
int idx = feature_keys_[shard_num][dim_id].size();
feature_keys_[shard_num][dim_id].resize(
feature_keys_[shard_num][dim_id].size() + shard_keys.size());
int idx = feature_dim_keys_[shard_num][dim_id].size();
feature_dim_keys_[shard_num][dim_id].resize(
feature_dim_keys_[shard_num][dim_id].size() + shard_keys.size());
std::copy(shard_keys.begin(), shard_keys.end(),
feature_keys_[shard_num][dim_id].begin() + idx);
feature_dim_keys_[shard_num][dim_id].begin() + idx);
}
void unique_keys() {

void UniqueKeys() {
std::vector<std::thread> threads;
auto unique_func = [this](int i, int j) {
auto& cur_keys = feature_keys_[i][j];
auto unique_func = [this](int i) {
auto& cur_keys = feature_keys_[i];
std::sort(cur_keys.begin(), cur_keys.end());
std::vector<FeatureKey>::iterator it;
it = std::unique(cur_keys.begin(), cur_keys.end());
cur_keys.resize(std::distance(cur_keys.begin(), it));
};
auto unique_dynamic_mf_func = [this](int i, int j) {
auto& cur_keys = feature_dim_keys_[i][j];
std::sort(cur_keys.begin(), cur_keys.end());
std::vector<FeatureKey>::iterator it;
it = std::unique(cur_keys.begin(), cur_keys.end());
cur_keys.resize(std::distance(cur_keys.begin(), it));
};
for (size_t i = 0; i < feature_keys_.size(); i++) {
for (size_t j = 0; j < feature_keys_[i].size(); j++) {
threads.push_back(std::thread(unique_func, i, j));
if (!multi_mf_dim_) {
for (uint32_t i = 0; i < shard_num_; i++) {
threads.push_back(std::thread(unique_func, i));
}
} else {
for (uint32_t i = 0; i < shard_num_; i++) {
for (int j = 0; j < multi_mf_dim_; j++) {
threads.push_back(std::thread(unique_dynamic_mf_func, i, j));
}
}
VLOG(3) << "heter_context unique keys with dynamic mf dimention";
}
for (std::thread& t : threads) {
t.join();
}
}
uint16_t pass_id_;
};


} // end namespace framework
} // end namespace paddle
#endif
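For readers skimming this revert, a minimal usage sketch (not part of the changed files) of the restored HeterContext API: the shard/device/dimension counts and key values below are illustrative assumptions, and the snippet presumes the Paddle build environment where this header and robin_hood are available.

#include "paddle/fluid/framework/fleet/heter_context.h"

// Illustrative only: drives the restored multi-mf-dim path of HeterContext.
void BuildPassKeys() {
  paddle::framework::HeterContext ctx;
  const int shard_num = 37, device_num = 8, dim_num = 2;  // assumed sizes
  ctx.init(shard_num, device_num, dim_num);  // also sets multi_mf_dim_ = dim_num

  // Coarsely deduplicated keys for one shard/dim, gathered elsewhere.
  robin_hood::unordered_set<uint64_t> shard_keys = {1, 2, 3};
  ctx.batch_add_keys(/*shard_num=*/0, /*dim_id=*/0, shard_keys);

  ctx.UniqueKeys();  // per-shard, per-dim sort + unique
  ctx.Reset();       // clear key/value buffers before the next pass
}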
8 changes: 4 additions & 4 deletions paddle/fluid/framework/fleet/heter_ps/CMakeLists.txt
@@ -7,9 +7,9 @@ IF(WITH_GPU)
get_property(RPC_DEPS GLOBAL PROPERTY RPC_DEPS)
SET(HETERPS_DEPS ${HETERPS_DEPS} ${RPC_DEPS})
endif()
nv_library(heter_comm SRCS heter_comm.h feature_value.h dy_gpu_value_inl.h feature_value_inl.h gpu_value_inl.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS})
nv_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h mem_pool.h DEPS ${HETERPS_DEPS})
nv_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu feature_value.cu DEPS heter_comm)
nv_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
if(WITH_PSCORE)
nv_library(graph_gpu_ps SRCS graph_gpu_ps_table.h DEPS heter_comm table)
nv_test(test_graph_comm SRCS test_graph.cu DEPS graph_gpu_ps)
@@ -20,7 +20,7 @@ IF(WITH_GPU)
endif()
ENDIF()
IF(WITH_ROCM)
hip_library(heter_comm SRCS heter_comm.h feature_value.h dy_gpu_value_inl.h feature_value_inl.h gpu_value_inl.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
hip_library(heter_comm SRCS heter_comm.h feature_value.h heter_resource.cc heter_resource.h hashtable.h DEPS cub device_context)
hip_test(test_heter_comm SRCS feature_value.h DEPS heter_comm)
hip_library(heter_ps SRCS heter_ps.cu feature_value.cu DEPS heter_comm)
hip_library(heter_ps SRCS heter_ps.cu DEPS heter_comm)
ENDIF()
