Skip to content

Commit

Permalink
fix failed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chenwhql committed Dec 30, 2021
2 parents b515271 + de26b88 commit 5ab13ab
Show file tree
Hide file tree
Showing 88 changed files with 3,512 additions and 741 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ void ScaleAPI(const egr::EagerTensor& x, float scale, float bias,
SizeOf(dense_tensor->dtype());
auto dense_out = std::make_shared<pten::DenseTensor>(
pten::make_intrusive<paddle::experimental::SharedStorage>(
paddle::memory::Alloc(place, bytes_size), 0),
paddle::memory::Alloc(place, bytes_size)),
std::move(tensor_meta));
// Handle Device Context
const paddle::platform::Place& expected_kernel_place =
Expand Down
25 changes: 25 additions & 0 deletions paddle/fluid/eager/auto_code_generator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@ execute_process(
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generate_file_structures.py" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/"
)

set(tmp_dygraph_forward_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.tmp.h")
set(tmp_dygraph_forward_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/forwards/dygraph_forward_functions.tmp.cc")
set(tmp_dygraph_node_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.tmp.h")
set(tmp_dygraph_node_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.tmp.cc")
set(dygraph_forward_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h")
set(dygraph_forward_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/forwards/dygraph_forward_functions.cc")
set(dygraph_node_h_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.h")
set(dygraph_node_cc_path "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated/nodes/nodes.cc")

if(WIN32)
set(EAGER_CODEGEN_DEPS eager_generator)
if("${CMAKE_GENERATOR}" STREQUAL "Ninja")
Expand Down Expand Up @@ -48,13 +57,29 @@ if(WIN32)

add_custom_target(eager_codegen
COMMAND "${eager_generator_path}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_forward_h_path} ${dygraph_forward_h_path}
COMMENT "copy_if_different ${tmp_dygraph_forward_h_path} to ${dygraph_forward_h_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_forward_cc_path} ${dygraph_forward_cc_path}
COMMENT "copy_if_different ${tmp_dygraph_forward_cc_path} to ${dygraph_forward_cc_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_node_h_path} ${dygraph_node_h_path}
COMMENT "copy_if_different ${tmp_dygraph_node_h_path} to ${dygraph_node_h_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_node_cc_path} ${dygraph_node_cc_path}
COMMENT "copy_if_different ${tmp_dygraph_node_cc_path} to ${dygraph_node_cc_path}"
DEPENDS ${EAGER_CODEGEN_DEPS}
VERBATIM)
else()
add_custom_target(eager_codegen
COMMAND ${CMAKE_COMMAND} -E env "LD_LIBRARY_PATH=$ENV{LD_LIBRARY_PATH}:${CMAKE_CURRENT_BINARY_DIR}/../../pybind"
"${CMAKE_CURRENT_BINARY_DIR}/eager_generator"
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_forward_h_path} ${dygraph_forward_h_path}
COMMENT "copy_if_different ${tmp_dygraph_forward_h_path} to ${dygraph_forward_h_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_forward_cc_path} ${dygraph_forward_cc_path}
COMMENT "copy_if_different ${tmp_dygraph_forward_cc_path} to ${dygraph_forward_cc_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_node_h_path} ${dygraph_node_h_path}
COMMENT "copy_if_different ${tmp_dygraph_node_h_path} to ${dygraph_node_h_path}"
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_dygraph_node_cc_path} ${dygraph_node_cc_path}
COMMENT "copy_if_different ${tmp_dygraph_node_cc_path} to ${dygraph_node_cc_path}"
DEPENDS eager_generator
VERBATIM)
endif()
32 changes: 14 additions & 18 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1847,19 +1847,15 @@ static std::string GenerateDygraphHFileIncludes() {
return dygraph_forward_api_includes_str;
}

static void GenerateForwardHFile(const std::string& output_dir,
static void GenerateForwardHFile(const std::string& dygraph_forward_api_path,
const std::string& dygraph_forward_api_str) {
std::string dygraph_forward_api_path = output_dir + "/dygraph_forward_api.h";
std::ofstream forward_header_stream(dygraph_forward_api_path, std::ios::out);
forward_header_stream << dygraph_forward_api_str;
forward_header_stream.close();
}

static void GenerateForwardDygraphFile(const std::string& output_dir,
static void GenerateForwardDygraphFile(const std::string& forward_cc_path,
const std::string& fwd_function_str) {
std::string forwards_dir = output_dir + "/forwards/";
std::string forward_cc_filename = "dygraph_forward_functions.cc";
std::string forward_cc_path = forwards_dir + forward_cc_filename;
const char* FORWARD_INCLUDE_TEMPLATE =
"#include "
"\"paddle/fluid/eager/api/generated/fluid_generated/"
Expand All @@ -1876,11 +1872,8 @@ static void GenerateForwardDygraphFile(const std::string& output_dir,
forward_cc_stream.close();
}

static void GenerateNodeHFile(const std::string& output_dir,
static void GenerateNodeHFile(const std::string& node_h_path,
const std::string& grad_node_str) {
std::string nodes_dir = output_dir + "/nodes/";
std::string node_h_filename = "nodes.h";
std::string node_h_path = nodes_dir + node_h_filename;
std::string node_h_include_str =
"#pragma once\n"
"#include \"paddle/fluid/eager/tensor_wrapper.h\"\n"
Expand All @@ -1892,11 +1885,8 @@ static void GenerateNodeHFile(const std::string& output_dir,
node_h_stream.close();
}

static void GenerateNodeCCFile(const std::string& output_dir,
static void GenerateNodeCCFile(const std::string& node_cc_path,
const std::string& grad_function_str) {
std::string nodes_dir = output_dir + "/nodes/";
std::string node_cc_filename = "nodes.cc";
std::string node_cc_path = nodes_dir + node_cc_filename;
const char* NODE_CC_INCLUDE_TEMPLATE =
"#include \"glog/logging.h\"\n"
"#include \"paddle/pten/api/all.h\"\n"
Expand Down Expand Up @@ -2026,18 +2016,24 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
}

VLOG(6) << "-------- GenerateDygraphForwardCCFile -------";
std::string forward_cc_path =
output_dir + "/forwards/dygraph_forward_functions.tmp.cc";
fwd_function_str += "\n";
fwd_function_str += GenerateCoreOpsReturnsInfo();
GenerateForwardDygraphFile(output_dir, fwd_function_str);
GenerateForwardDygraphFile(forward_cc_path, fwd_function_str);

VLOG(6) << "-------- GenerateForwardHFile -------";
GenerateForwardHFile(output_dir, dygraph_forward_api_str);
std::string dygraph_forward_api_path =
output_dir + "/dygraph_forward_api.tmp.h";
GenerateForwardHFile(dygraph_forward_api_path, dygraph_forward_api_str);

VLOG(6) << "-------- GenerateNodeHFile -------";
GenerateNodeHFile(output_dir, grad_node_h_str);
std::string node_h_path = output_dir + "/nodes/nodes.tmp.h";
GenerateNodeHFile(node_h_path, grad_node_h_str);

VLOG(6) << "-------- GenerateNodeCCFile -------";
GenerateNodeCCFile(output_dir, grad_node_cc_str);
std::string node_cc_path = output_dir + "/nodes/nodes.tmp.cc";
GenerateNodeCCFile(node_cc_path, grad_node_cc_str);
}

static void PrepareAttrMapForOps() {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/eager/tests/task_tests/fwd_bwd_joint_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ egr::EagerTensor hook_function(const egr::EagerTensor& t) {
paddle::framework::product(t_dense->dims()) * SizeOf(t_dense->dtype());
auto ret_dense = std::make_shared<pten::DenseTensor>(
pten::make_intrusive<paddle::experimental::SharedStorage>(
paddle::memory::Alloc(place, bytes_size), 0),
paddle::memory::Alloc(place, bytes_size)),
std::move(ret_meta));

float* t_ptr = t_dense->mutable_data<float>();
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/eager/tests/task_tests/hook_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ egr::EagerTensor hook_function(const egr::EagerTensor& t) {
paddle::framework::product(t_dense->dims()) * SizeOf(t_dense->dtype());
auto ret_dense = std::make_shared<pten::DenseTensor>(
pten::make_intrusive<paddle::experimental::SharedStorage>(
paddle::memory::Alloc(place, bytes_size), 0),
paddle::memory::Alloc(place, bytes_size)),
std::move(ret_meta));

float* t_ptr = t_dense->mutable_data<float>();
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/custom_operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -207,14 +207,14 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
"Tensors.",
vec_true_outs.size(), outs.size()));
for (size_t j = 0; j < vec_true_outs.size(); ++j) {
experimental::MovesSharedStorage(
experimental::SharesStorage(
std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(j).impl())
.get(),
vec_true_outs.at(j));
}
} else {
auto* true_out = ctx.Output<Tensor>(out_name);
experimental::MovesSharedStorage(
experimental::SharesStorage(
std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(i).impl())
.get(),
true_out);
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/framework/distributed_strategy.proto
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ message ShardingConfig {
optional bool optimize_cast = 12 [ default = false ];
// Optimizer sharding. Temporary plans and may be deprecated
optional bool _dp_as_optimizer_sharding = 13 [ default = false ];
optional int32 stage = 14 [ default = 1 ];
}

message HybridConfig {
Expand Down
20 changes: 20 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/hashtable.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ limitations under the License. */
#include "thrust/pair.h"
// #include "cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/cudf/concurrent_unordered_map.cuh.h"
#include "paddle/fluid/framework/fleet/heter_ps/feature_value.h"
#include "paddle/fluid/framework/fleet/heter_ps/mem_pool.h"
#ifdef PADDLE_WITH_HETERPS
#include "paddle/fluid/platform/device/gpu/gpu_types.h"

Expand All @@ -53,24 +55,42 @@ class HashTable {
HashTable& operator=(const HashTable&) = delete;
void insert(const KeyType* d_keys, const ValType* d_vals, size_t len,
gpuStream_t stream);
void insert(const KeyType* d_keys, size_t len, char* pool, size_t start_index,
gpuStream_t stream);
void get(const KeyType* d_keys, ValType* d_vals, size_t len,
gpuStream_t stream);
void get(const KeyType* d_keys, char* d_vals, size_t len, gpuStream_t stream);
void show();
void dump_to_cpu(int devid, cudaStream_t stream);

template <typename GradType, typename Sgd>
void update(const KeyType* d_keys, const GradType* d_grads, size_t len,
Sgd sgd, gpuStream_t stream);

template <typename Sgd>
void update(const KeyType* d_keys, const char* d_grads, size_t len, Sgd sgd,
gpuStream_t stream);

int size() { return container_->size(); }

void set_feature_value_size(size_t pull_feature_value_size,
size_t push_grad_value_size) {
pull_feature_value_size_ = pull_feature_value_size;
push_grad_value_size_ = push_grad_value_size;
VLOG(3) << "hashtable set pull value size: " << pull_feature_value_size_
<< " push value size: " << push_grad_value_size_;
}

std::unique_ptr<RWLock> rwlock_{nullptr};

private:
TableContainer<KeyType, ValType>* container_;
int BLOCK_SIZE_{256};
float LOAD_FACTOR{0.75f};
size_t capacity_;
size_t max_mf_dim_ = 8;
size_t pull_feature_value_size_;
size_t push_grad_value_size_;
};
} // end namespace framework
} // end namespace paddle
Expand Down
88 changes: 88 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/hashtable_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,23 @@ __global__ void insert_kernel(Table* table,
}
}

template <typename Table>
__global__ void insert_kernel(Table* table,
const typename Table::key_type* const keys,
size_t len, char* pool, int start_index) {
ReplaceOp<typename Table::mapped_type> op;
thrust::pair<typename Table::key_type, typename Table::mapped_type> kv;

const size_t i = blockIdx.x * blockDim.x + threadIdx.x;

if (i < len) {
kv.first = keys[i];
kv.second = (Table::mapped_type)(pool + (start_index + i) * 80);
auto it = table->insert(kv, op);
assert(it != table->end() && "error: insert fails: table is full");
}
}

template <typename Table>
__global__ void search_kernel(Table* table,
const typename Table::key_type* const keys,
Expand All @@ -56,6 +73,20 @@ __global__ void search_kernel(Table* table,
}
}

template <typename Table>
__global__ void dy_mf_search_kernel(Table* table,
const typename Table::key_type* const keys,
char* const vals, size_t len,
size_t pull_feature_value_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
auto it = table->find(keys[i]);

if (it != table->end()) {
*(FeatureValue*)(vals + i * pull_feature_value_size) = *(it->second);
}
}
}
template <typename Table, typename GradType, typename Sgd>
__global__ void update_kernel(Table* table,
const typename Table::key_type* const keys,
Expand All @@ -70,6 +101,23 @@ __global__ void update_kernel(Table* table,
}
}

template <typename Table, typename Sgd>
__global__ void dy_mf_update_kernel(Table* table,
const typename Table::key_type* const keys,
const char* const grads, size_t len,
Sgd sgd, size_t grad_value_size) {
const size_t i = blockIdx.x * blockDim.x + threadIdx.x;
if (i < len) {
auto it = table->find(keys[i]);
if (it != table->end()) {
FeaturePushValue* cur = (FeaturePushValue*)(grads + i * grad_value_size);
sgd.dy_mf_update_value((it.getter())->second, *cur);
} else {
printf("yxf::push miss key: %d", keys[i]);
}
}
}

template <typename KeyType, typename ValType>
HashTable<KeyType, ValType>::HashTable(size_t capacity) {
container_ = new TableContainer<KeyType, ValType>(capacity);
Expand Down Expand Up @@ -97,6 +145,17 @@ void HashTable<KeyType, ValType>::get(const KeyType* d_keys, ValType* d_vals,
d_vals, len);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::get(const KeyType* d_keys, char* d_vals,
size_t len, gpuStream_t stream) {
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
dy_mf_search_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
container_, d_keys, d_vals, len, pull_feature_value_size_);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
const ValType* d_vals, size_t len,
Expand All @@ -109,6 +168,21 @@ void HashTable<KeyType, ValType>::insert(const KeyType* d_keys,
d_vals, len);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::insert(const KeyType* d_keys, size_t len,
char* pool, size_t start_index,
gpuStream_t stream) {
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;
if (pool == NULL) {
return;
}
insert_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(container_, d_keys, len,
pool, start_index);
}

template <typename KeyType, typename ValType>
void HashTable<KeyType, ValType>::dump_to_cpu(int devid, cudaStream_t stream) {
container_->prefetch(cudaCpuDeviceId, stream);
Expand Down Expand Up @@ -166,6 +240,20 @@ void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
d_grads, len, sgd);
}

template <typename KeyType, typename ValType>
template <typename Sgd>
void HashTable<KeyType, ValType>::update(const KeyType* d_keys,
const char* d_grads, size_t len,
Sgd sgd, gpuStream_t stream) {
if (len == 0) {
return;
}
const int grid_size = (len - 1) / BLOCK_SIZE_ + 1;

dy_mf_update_kernel<<<grid_size, BLOCK_SIZE_, 0, stream>>>(
container_, d_keys, d_grads, len, sgd, push_grad_value_size_);
}

} // end namespace framework
} // end namespace paddle
#endif
7 changes: 7 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_resource.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,13 @@ int HeterPsResource::get_index_by_devid(int devid) {

int HeterPsResource::total_gpu() { return dev_ids_.size(); }

void HeterPsResource::set_multi_mf(int multi_mf_dim, int max_mf_dim) {
multi_mf_dim_ = multi_mf_dim;
max_mf_dim_ = max_mf_dim;
VLOG(3) << "heter resource set mf dim: " << multi_mf_dim_
<< " max_mf_dim_: " << max_mf_dim_;
}

} // end namespace framework
} // end namespace paddle
#endif
3 changes: 3 additions & 0 deletions paddle/fluid/framework/fleet/heter_ps/heter_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,16 @@ class HeterPsResource {
int total_gpu();
int get_index_by_devid(int devid);
int dev_id(int num);
void set_multi_mf(int multi_mf_dim, int max_mf_dim);
gpuStream_t local_stream(int gpu_num, int stream_num);
gpuStream_t remote_stream(int gpu_num, int stream_num);
gpuStream_t comm_stream(int gpu_num, int stream_num);

std::vector<std::shared_ptr<GPUResource>> resources_;
std::vector<int> dev_ids_;
std::map<int, int> devid_2_index_;
int multi_mf_dim_{0};
int max_mf_dim_{0};
};

} // end namespace framework
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/framework/fleet/heter_ps/mem_pool.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class HBMMemoryPool : public managed {
out << "show: " << x->show << " clk: " << x->clk << " slot: " << x->slot
<< " lr: " << x->lr << " mf_dim: " << x->mf_size
<< " mf_size: " << x->mf_size << " mf:";
for (int i = 0; i < x->mf_dim + 1; ++i) {
for (int i = 0; i < x->mf_size + 1; ++i) {
out << " " << x->mf[i];
}
out << "\n";
Expand Down
Loading

1 comment on commit 5ab13ab

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.