From 6e1fa48527ce1cbdd6a5a77c5816c927ce7f6275 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 25 Jan 2018 11:30:43 +0800 Subject: [PATCH 1/9] initialize batch barrier --- paddle/operators/detail/grpc_client.cc | 17 ++++++ paddle/operators/detail/grpc_client.h | 14 +++++ paddle/operators/detail/grpc_server.cc | 71 ++++++++++++++++++++++--- paddle/operators/detail/grpc_server.h | 8 ++- paddle/operators/detail/send_recv.proto | 2 + paddle/operators/recv_op.cc | 6 +-- paddle/operators/send_op.cc | 5 ++ 7 files changed, 112 insertions(+), 11 deletions(-) diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc index d699dabf2fb982..c6ce4594b38547 100644 --- a/paddle/operators/detail/grpc_client.cc +++ b/paddle/operators/detail/grpc_client.cc @@ -97,6 +97,23 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } +bool RPCClient::AsyncBatchBarrier(const std::string& ep, int64_t time_out) { + const std::string ep_val = ep; + const auto ch = GetChannel(ep_val); + + framework::Async([ep_val, time_out, ch, this] { + BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); + sendrecv::VoidMessage req; + + auto rpc = s->stub_->AsyncBatchBarrier(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); + }); + + req_count_++; + + return true; +} + bool RPCClient::Wait() { if (req_count_ <= 0) { return true; diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h index a62e70a2533ae5..8f791aedc24c15 100644 --- a/paddle/operators/detail/grpc_client.h +++ b/paddle/operators/detail/grpc_client.h @@ -117,6 +117,17 @@ class GetProcessor : public ClientBase { RequestGetCallBack response_call_back_ = ProcGetResponse; }; +class BatchBarrierProcessor : public ClientBase { + public: + explicit BatchBarrierProcessor(std::shared_ptr ch) + : ClientBase(ch) {} + + virtual ~BatchBarrierProcessor() {} + + virtual void Process() {} + sendrecv::VoidMessage reply_; +}; + class RPCClient { public: bool AsyncSendVariable(const std::string& ep, @@ -130,6 +141,9 @@ class RPCClient { const framework::Scope& scope, const std::string& var_name, int64_t time_out = 600 * 1000); + + bool AsyncBatchBarrier(const std::string& ep, int64_t time_out = 600 * 1000); + bool Wait(); private: diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index 3ddcd839bdd235..2eddcf6912158e 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -117,6 +117,31 @@ class RequestGet final : public RequestBase { SimpleBlockQueue* queue_; }; +class RequestBatchBarrier final : public RequestBase { + public: + explicit RequestBatchBarrier(sendrecv::SendRecvService::AsyncService* service, + grpc::ServerCompletionQueue* cq) + : RequestBase(service, cq), responder_(&ctx_) { + service_->RequestBatchBarrier(&ctx_, &request_, &responder_, cq_, cq_, + this); + } + + virtual ~RequestBatchBarrier() {} + + virtual std::string GetReqName() { return "Batch Barrier"; } + + virtual void Process() { + // TODO(Yancey1989): sub batch cond + responder_.Finish(reply_, grpc::Status::OK, this); + status_ = FINISH; + } + + protected: + sendrecv::VoidMessage request_; + sendrecv::VoidMessage reply_; + ServerAsyncResponseWriter responder_; +}; + void AsyncGRPCServer::WaitClientGet(int count) { for (int i = 0; i < count; ++i) { var_get_queue_.Pop(); @@ -132,6 +157,8 @@ void AsyncGRPCServer::RunSyncUpdate() { cq_send_ = builder.AddCompletionQueue(); cq_get_ = 
builder.AddCompletionQueue(); + cq_batch_barrier_ = builder.AddCompletionQueue(); + server_ = builder.BuildAndStart(); LOG(INFO) << "Server listening on " << address_ << std::endl; @@ -139,19 +166,26 @@ void AsyncGRPCServer::RunSyncUpdate() { std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); std::function get_register = std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); + std::function batch_barrier_register = + std::bind(&AsyncGRPCServer::TryToRegisterNewBatchBarrier, this); t_send_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false, + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_send_.get(), "cq_send", send_register))); t_get_.reset( - new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true, + new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", get_register))); + t_batch_barrier_.reset(new std::thread( + std::bind(&AsyncGRPCServer::HandleRequest, this, cq_batch_barrier_.get(), + "cq_batch_barrier", batch_barrier_register))); + // wait server server_->Wait(); t_send_->join(); t_get_->join(); + t_batch_barrier_->join(); } void AsyncGRPCServer::ShutdownQueue() { @@ -174,7 +208,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() { } RequestSend* send = new RequestSend(&service_, cq_send_.get(), &var_recv_queue_); - VLOG(4) << "create RequestSend status:" << send->Status(); + VLOG(4) << "Create RequestSend status:" << send->Status(); } void AsyncGRPCServer::TryToRegisterNewGetOne() { @@ -184,11 +218,21 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { } RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_, &var_get_queue_); - VLOG(4) << "create Requestget status:" << get->Status(); + VLOG(4) << "Create RequestGet status:" << get->Status(); +} + +void AsyncGRPCServer::TryToRegisterNewBatchBarrier() { + std::unique_lock lock(cq_mutex_); + if (is_shut_down_) { + return; + } + RequestBatchBarrier* r = + new RequestBatchBarrier(&service_, cq_batch_barrier_.get()); + VLOG(4) << "Create RequestBatchBarrier status:" << r->Status(); } -// FIXME(typhoonzero): remove wait argument and change cq_name to enum. -void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq, +// FIXME(typhoonzero): change cq_name to enum. 
+void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq, std::string cq_name, std::function TryToRegisterNewOne) { TryToRegisterNewOne(); @@ -250,6 +294,21 @@ void AsyncGRPCServer::SetCond(int cond) { barrier_condition_.notify_all(); } +void AsyncGRPCServer::SubCond(int arg) { + { + std::unique_lock lock(this->barrier_mutex_); + barrier_cond_step_ -= arg; + } + barrier_condition_.notify_all(); +} + +bool AsyncGRPCServer::CondEqualTo(int arg) { + { + std::unique_lock lock(this->barrier_mutex_); + return barrier_cond_step_ == arg; + } +} + } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 1ca9086c744c55..a79af9213920b1 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -45,6 +45,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void WaitCond(int cond); void SetCond(int cond); void WaitClientGet(int count); + bool CondEqualTo(int cond); + void SubCond(int arg); void SetScope(framework::Scope *scope) { scope_ = scope; } @@ -57,11 +59,11 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void ShutDown(); protected: - void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq, - std::string cq_name, + void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name, std::function TryToRegisterNewOne); void TryToRegisterNewSendOne(); void TryToRegisterNewGetOne(); + void TryToRegisterNewBatchBarrier(); void ShutdownQueue(); private: @@ -69,6 +71,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { volatile bool is_shut_down_ = false; std::unique_ptr cq_send_; std::unique_ptr cq_get_; + std::unique_ptr cq_batch_barrier_; sendrecv::SendRecvService::AsyncService service_; std::unique_ptr server_; @@ -87,6 +90,7 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { std::unique_ptr t_send_; std::unique_ptr t_get_; + std::unique_ptr t_batch_barrier_; }; }; // namespace detail diff --git a/paddle/operators/detail/send_recv.proto b/paddle/operators/detail/send_recv.proto index 8f962b4c69cc83..d553af97ef6440 100644 --- a/paddle/operators/detail/send_recv.proto +++ b/paddle/operators/detail/send_recv.proto @@ -21,6 +21,8 @@ service SendRecvService { rpc SendVariable(VariableMessage) returns (VoidMessage) {} // Argument VariableMessage for GetVariable should only contain varname. rpc GetVariable(VariableMessage) returns (VariableMessage) {} + // Control the batch barrier + rpc BatchBarrier(VoidMessage) returns (VoidMessage) {} } // VariableMessage is serialized paddle variable message. diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 593c35879ae2b3..14b9d1a5d19ce9 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -107,8 +107,8 @@ class RecvOp : public framework::OperatorBase { while (!exit_flag) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. 
- rpc_service_->SetCond(0); - for (size_t i = 0; i < barrier_size; ++i) { + rpc_service_->SetCond(fan_in); + while (rpc_service_->CondEqualTo(0)) { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { @@ -145,7 +145,7 @@ class RecvOp : public framework::OperatorBase { } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } - rpc_service_->SetCond(1); + rpc_service_->SetCond(fan_in); rpc_service_->WaitClientGet(barrier_size); grads_counter_.clear(); } // while(true) diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 5aa66c20eaf779..d9d4d9ffa748b8 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -46,6 +46,11 @@ class SendOp : public framework::OperatorBase { } PADDLE_ENFORCE(client_.Wait()); + for (auto& ep : epmap) { + client_.AsyncBatchBarrier(ep); + } + PADDLE_ENFORCE(client_.Wait()); + for (size_t i = 0; i < outs.size(); i++) { VLOG(3) << "getting " << outs[i]; client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); From b346c8c95514a69579ba2939491278aee56afb70 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 25 Jan 2018 11:32:38 +0800 Subject: [PATCH 2/9] add some comments --- paddle/operators/recv_op.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 14b9d1a5d19ce9..d77620e32939bf 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -135,6 +135,7 @@ class RecvOp : public framework::OperatorBase { } detail::DeserializeFromMessage(v.second, dev_ctx, var); } + // TODO(Yancey1989): merge SelectedRows variables here if (exit_flag) { break; } From 745ec2bd6dda8b8e3024af7715b1baf832324dfc Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Thu, 25 Jan 2018 20:11:31 +0800 Subject: [PATCH 3/9] update --- paddle/operators/detail/grpc_client.cc | 16 ++++------ paddle/operators/detail/grpc_client.h | 9 ++++++ paddle/operators/detail/grpc_server.cc | 31 ++++++++++--------- paddle/operators/detail/grpc_server.h | 12 +++++-- paddle/operators/detail/simple_block_queue.h | 5 +++ paddle/operators/recv_op.cc | 14 ++++++--- paddle/operators/send_op.cc | 6 ++-- .../paddle/v2/fluid/distribute_transpiler.py | 4 +-- 8 files changed, 62 insertions(+), 35 deletions(-) diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc index c6ce4594b38547..554dc00e60fe56 100644 --- a/paddle/operators/detail/grpc_client.cc +++ b/paddle/operators/detail/grpc_client.cc @@ -45,7 +45,6 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, SendProcessor* s = new SendProcessor(ch); s->Prepare(var_h, time_out); s->response_call_back_ = NULL; - auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, (void*)s); }); @@ -98,17 +97,14 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, } bool RPCClient::AsyncBatchBarrier(const std::string& ep, int64_t time_out) { - const std::string ep_val = ep; - const auto ch = GetChannel(ep_val); + const auto ch = GetChannel(ep); - framework::Async([ep_val, time_out, ch, this] { - BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); - sendrecv::VoidMessage req; - - auto rpc = s->stub_->AsyncBatchBarrier(s->context_.get(), req, &cq_); - rpc->Finish(&s->reply_, &s->status_, (void*)s); - }); + BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); + s->Prepare(time_out); + sendrecv::VoidMessage req; + auto rpc = 
s->stub_->AsyncBatchBarrier(s->context_.get(), req, &cq_); + rpc->Finish(&s->reply_, &s->status_, (void*)s); req_count_++; return true; diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h index 8f791aedc24c15..769529aab1e839 100644 --- a/paddle/operators/detail/grpc_client.h +++ b/paddle/operators/detail/grpc_client.h @@ -71,6 +71,15 @@ class ClientBase { context_->set_deadline(deadline); } + virtual void Prepare(int64_t time_out) { + context_.reset(new grpc::ClientContext()); + + std::chrono::system_clock::time_point deadline = + std::chrono::system_clock::now() + std::chrono::milliseconds(time_out); + + context_->set_deadline(deadline); + } + virtual void Process() = 0; std::unique_ptr stub_; diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index 2eddcf6912158e..b162152b0c1d86 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -119,9 +119,10 @@ class RequestGet final : public RequestBase { class RequestBatchBarrier final : public RequestBase { public: - explicit RequestBatchBarrier(sendrecv::SendRecvService::AsyncService* service, + explicit RequestBatchBarrier(AsyncGRPCServer* server, + sendrecv::SendRecvService::AsyncService* service, grpc::ServerCompletionQueue* cq) - : RequestBase(service, cq), responder_(&ctx_) { + : RequestBase(service, cq), responder_(&ctx_), server_(server) { service_->RequestBatchBarrier(&ctx_, &request_, &responder_, cq_, cq_, this); } @@ -131,6 +132,7 @@ class RequestBatchBarrier final : public RequestBase { virtual std::string GetReqName() { return "Batch Barrier"; } virtual void Process() { + server_->SubBatchCond(1); // TODO(Yancey1989): sub batch cond responder_.Finish(reply_, grpc::Status::OK, this); status_ = FINISH; @@ -140,6 +142,7 @@ class RequestBatchBarrier final : public RequestBase { sendrecv::VoidMessage request_; sendrecv::VoidMessage reply_; ServerAsyncResponseWriter responder_; + AsyncGRPCServer* server_; }; void AsyncGRPCServer::WaitClientGet(int count) { @@ -227,7 +230,7 @@ void AsyncGRPCServer::TryToRegisterNewBatchBarrier() { return; } RequestBatchBarrier* r = - new RequestBatchBarrier(&service_, cq_batch_barrier_.get()); + new RequestBatchBarrier(this, &service_, cq_batch_barrier_.get()); VLOG(4) << "Create RequestBatchBarrier status:" << r->Status(); } @@ -294,19 +297,19 @@ void AsyncGRPCServer::SetCond(int cond) { barrier_condition_.notify_all(); } -void AsyncGRPCServer::SubCond(int arg) { - { - std::unique_lock lock(this->barrier_mutex_); - barrier_cond_step_ -= arg; - } - barrier_condition_.notify_all(); +void AsyncGRPCServer::SetBatchCond(int cond) { + std::unique_lock lock(this->batch_barrier_mutex_); + batch_barrier_cond_ = cond; } -bool AsyncGRPCServer::CondEqualTo(int arg) { - { - std::unique_lock lock(this->barrier_mutex_); - return barrier_cond_step_ == arg; - } +void AsyncGRPCServer::SubBatchCond(int arg) { + std::unique_lock lock(this->batch_barrier_mutex_); + batch_barrier_cond_ -= arg; +} + +bool AsyncGRPCServer::BatchCondEqualTo(int arg) { + std::unique_lock lock(this->batch_barrier_mutex_); + return batch_barrier_cond_ == arg; } } // namespace detail diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index a79af9213920b1..240f8ce27a9158 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -45,8 +45,10 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void WaitCond(int cond); void 
SetCond(int cond); void WaitClientGet(int count); - bool CondEqualTo(int cond); - void SubCond(int arg); + + void SetBatchCond(int cond); + bool BatchCondEqualTo(int arg); + void SubBatchCond(int arg); void SetScope(framework::Scope *scope) { scope_ = scope; } @@ -56,6 +58,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } + bool IsRecvQueueEmpty() { return this->var_recv_queue_.IsEmpty(); } + void ShutDown(); protected: @@ -88,6 +92,10 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { mutable int barrier_cond_step_; std::condition_variable barrier_condition_; + // condition of batch barrier + std::mutex batch_barrier_mutex_; + mutable int batch_barrier_cond_; + std::unique_ptr t_send_; std::unique_ptr t_get_; std::unique_ptr t_batch_barrier_; diff --git a/paddle/operators/detail/simple_block_queue.h b/paddle/operators/detail/simple_block_queue.h index c7f5ff4b5f494c..6bb9ccdf74db6d 100644 --- a/paddle/operators/detail/simple_block_queue.h +++ b/paddle/operators/detail/simple_block_queue.h @@ -45,6 +45,11 @@ class SimpleBlockQueue { this->queue_.pop_back(); return rc; } + + bool IsEmpty() { + std::unique_lock lock(this->mutex_); + return this->queue_.empty(); + } }; } // namespace detail diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index d77620e32939bf..73c4550186b2d0 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -95,7 +95,6 @@ class RecvOp : public framework::OperatorBase { auto param_list = Attr>("ParamList"); auto grad_list = Attr>("GradList"); auto fan_in = Attr("Fanin"); - size_t param_count = param_list.size(); auto *block = Attr(kOptimizeBlock); auto *program = block->Program(); @@ -103,13 +102,16 @@ class RecvOp : public framework::OperatorBase { // TODO(typhoonzero): change this to a while_op for every cluster-batch. bool exit_flag = false; - size_t barrier_size = param_count * fan_in; while (!exit_flag) { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. 
- rpc_service_->SetCond(fan_in); - while (rpc_service_->CondEqualTo(0)) { + rpc_service_->SetCond(0); + rpc_service_->SetBatchCond(fan_in); + size_t barrier_size = 0; + while (!rpc_service_->BatchCondEqualTo(0) || + !rpc_service_->IsRecvQueueEmpty()) { const detail::MessageWithName &v = rpc_service_->Get(); + barrier_size++; auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { LOG(INFO) << "received terminate message and exit"; @@ -135,6 +137,7 @@ class RecvOp : public framework::OperatorBase { } detail::DeserializeFromMessage(v.second, dev_ctx, var); } + VLOG(3) << "recv " << barrier_size << " parmeters for one barrier."; // TODO(Yancey1989): merge SelectedRows variables here if (exit_flag) { break; @@ -146,7 +149,8 @@ class RecvOp : public framework::OperatorBase { } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } - rpc_service_->SetCond(fan_in); + rpc_service_->SetBatchCond(fan_in); + rpc_service_->SetCond(0); rpc_service_->WaitClientGet(barrier_size); grads_counter_.clear(); } // while(true) diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index d9d4d9ffa748b8..3eb73a2ec8ab79 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -41,12 +41,14 @@ class SendOp : public framework::OperatorBase { platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); auto& ctx = *pool.Get(place); for (size_t i = 0; i < ins.size(); i++) { - VLOG(3) << "sending " << ins[i]; + VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); } PADDLE_ENFORCE(client_.Wait()); - for (auto& ep : epmap) { + std::set epset(epmap.begin(), epmap.end()); + for (auto& ep : epset) { + VLOG(3) << "batch barrier, ep: " << ep; client_.AsyncBatchBarrier(ep); } PADDLE_ENFORCE(client_.Wait()); diff --git a/python/paddle/v2/fluid/distribute_transpiler.py b/python/paddle/v2/fluid/distribute_transpiler.py index abcad899bfac9b..08002d8c5f2f58 100644 --- a/python/paddle/v2/fluid/distribute_transpiler.py +++ b/python/paddle/v2/fluid/distribute_transpiler.py @@ -221,7 +221,7 @@ def _append_split_op(self, program, gradblocks): if len(splited_vars) <= 1: continue orig_var = program.global_block().vars[varname] - if orig_var == core.VarDesc.VarType.SELECTED_ROWS: + if orig_var.type == core.VarDesc.VarType.SELECTED_ROWS: height_sections = [] for v in splited_vars: height_sections.append(v.shape[0]) @@ -230,7 +230,7 @@ def _append_split_op(self, program, gradblocks): inputs={"X": orig_var}, outputs={"Out": splited_vars}, attrs={"height_sections": height_sections}) - elif orig_var == core.VarDesc.VarType.LOD_TENSOR: + elif orig_var.type == core.VarDesc.VarType.LOD_TENSOR: sections = [] for v in splited_vars: sections.append(v.shape[0]) From 586a06ddfd93d71aab957691648df7b803464bc6 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Sat, 27 Jan 2018 10:31:56 +0800 Subject: [PATCH 4/9] fix batch barrier --- paddle/operators/recv_op.cc | 2 +- paddle/operators/send_op.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 73c4550186b2d0..5de002baae224c 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -150,7 +150,7 @@ class RecvOp : public framework::OperatorBase { LOG(ERROR) << "run sub program error " << e.what(); } rpc_service_->SetBatchCond(fan_in); - rpc_service_->SetCond(0); + rpc_service_->SetCond(1); rpc_service_->WaitClientGet(barrier_size); 
grads_counter_.clear(); } // while(true) diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 3eb73a2ec8ab79..4d13c6dcd99b78 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -54,7 +54,7 @@ class SendOp : public framework::OperatorBase { PADDLE_ENFORCE(client_.Wait()); for (size_t i = 0; i < outs.size(); i++) { - VLOG(3) << "getting " << outs[i]; + VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); } From 0eb9f809e45aaa6287b4123af912a30713ba6ecd Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Sat, 27 Jan 2018 22:52:54 +0800 Subject: [PATCH 5/9] use sendvariable rpc interface to send batch barrier --- paddle/operators/detail/grpc_client.cc | 5 +- paddle/operators/detail/grpc_server.cc | 60 ---------------------- paddle/operators/detail/grpc_server.h | 10 ---- paddle/operators/detail/send_recv.proto | 2 - paddle/operators/detail/sendrecvop_utils.h | 3 ++ paddle/operators/recv_op.cc | 16 +++--- 6 files changed, 15 insertions(+), 81 deletions(-) diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc index 554dc00e60fe56..a1804500742f09 100644 --- a/paddle/operators/detail/grpc_client.cc +++ b/paddle/operators/detail/grpc_client.cc @@ -102,8 +102,9 @@ bool RPCClient::AsyncBatchBarrier(const std::string& ep, int64_t time_out) { BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); s->Prepare(time_out); - sendrecv::VoidMessage req; - auto rpc = s->stub_->AsyncBatchBarrier(s->context_.get(), req, &cq_); + sendrecv::VariableMessage req; + req.set_varname(BATCH_BARRIER_MESSAGE); + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, (void*)s); req_count_++; diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index b162152b0c1d86..d55c89460ebbcd 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -117,34 +117,6 @@ class RequestGet final : public RequestBase { SimpleBlockQueue* queue_; }; -class RequestBatchBarrier final : public RequestBase { - public: - explicit RequestBatchBarrier(AsyncGRPCServer* server, - sendrecv::SendRecvService::AsyncService* service, - grpc::ServerCompletionQueue* cq) - : RequestBase(service, cq), responder_(&ctx_), server_(server) { - service_->RequestBatchBarrier(&ctx_, &request_, &responder_, cq_, cq_, - this); - } - - virtual ~RequestBatchBarrier() {} - - virtual std::string GetReqName() { return "Batch Barrier"; } - - virtual void Process() { - server_->SubBatchCond(1); - // TODO(Yancey1989): sub batch cond - responder_.Finish(reply_, grpc::Status::OK, this); - status_ = FINISH; - } - - protected: - sendrecv::VoidMessage request_; - sendrecv::VoidMessage reply_; - ServerAsyncResponseWriter responder_; - AsyncGRPCServer* server_; -}; - void AsyncGRPCServer::WaitClientGet(int count) { for (int i = 0; i < count; ++i) { var_get_queue_.Pop(); @@ -169,8 +141,6 @@ void AsyncGRPCServer::RunSyncUpdate() { std::bind(&AsyncGRPCServer::TryToRegisterNewSendOne, this); std::function get_register = std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this); - std::function batch_barrier_register = - std::bind(&AsyncGRPCServer::TryToRegisterNewBatchBarrier, this); t_send_.reset( new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, @@ -180,15 +150,10 @@ void AsyncGRPCServer::RunSyncUpdate() { new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, cq_get_.get(), "cq_get", 
get_register))); - t_batch_barrier_.reset(new std::thread( - std::bind(&AsyncGRPCServer::HandleRequest, this, cq_batch_barrier_.get(), - "cq_batch_barrier", batch_barrier_register))); - // wait server server_->Wait(); t_send_->join(); t_get_->join(); - t_batch_barrier_->join(); } void AsyncGRPCServer::ShutdownQueue() { @@ -224,16 +189,6 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() { VLOG(4) << "Create RequestGet status:" << get->Status(); } -void AsyncGRPCServer::TryToRegisterNewBatchBarrier() { - std::unique_lock lock(cq_mutex_); - if (is_shut_down_) { - return; - } - RequestBatchBarrier* r = - new RequestBatchBarrier(this, &service_, cq_batch_barrier_.get()); - VLOG(4) << "Create RequestBatchBarrier status:" << r->Status(); -} - // FIXME(typhoonzero): change cq_name to enum. void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq, std::string cq_name, @@ -297,21 +252,6 @@ void AsyncGRPCServer::SetCond(int cond) { barrier_condition_.notify_all(); } -void AsyncGRPCServer::SetBatchCond(int cond) { - std::unique_lock lock(this->batch_barrier_mutex_); - batch_barrier_cond_ = cond; -} - -void AsyncGRPCServer::SubBatchCond(int arg) { - std::unique_lock lock(this->batch_barrier_mutex_); - batch_barrier_cond_ -= arg; -} - -bool AsyncGRPCServer::BatchCondEqualTo(int arg) { - std::unique_lock lock(this->batch_barrier_mutex_); - return batch_barrier_cond_ == arg; -} - } // namespace detail } // namespace operators } // namespace paddle diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 240f8ce27a9158..0d8bce2a2954ba 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -46,10 +46,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void SetCond(int cond); void WaitClientGet(int count); - void SetBatchCond(int cond); - bool BatchCondEqualTo(int arg); - void SubBatchCond(int arg); - void SetScope(framework::Scope *scope) { scope_ = scope; } void SetDevCtx(const platform::DeviceContext *dev_ctx) { dev_ctx_ = dev_ctx; } @@ -67,7 +63,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { std::function TryToRegisterNewOne); void TryToRegisterNewSendOne(); void TryToRegisterNewGetOne(); - void TryToRegisterNewBatchBarrier(); void ShutdownQueue(); private: @@ -92,13 +87,8 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { mutable int barrier_cond_step_; std::condition_variable barrier_condition_; - // condition of batch barrier - std::mutex batch_barrier_mutex_; - mutable int batch_barrier_cond_; - std::unique_ptr t_send_; std::unique_ptr t_get_; - std::unique_ptr t_batch_barrier_; }; }; // namespace detail diff --git a/paddle/operators/detail/send_recv.proto b/paddle/operators/detail/send_recv.proto index d553af97ef6440..8f962b4c69cc83 100644 --- a/paddle/operators/detail/send_recv.proto +++ b/paddle/operators/detail/send_recv.proto @@ -21,8 +21,6 @@ service SendRecvService { rpc SendVariable(VariableMessage) returns (VoidMessage) {} // Argument VariableMessage for GetVariable should only contain varname. rpc GetVariable(VariableMessage) returns (VariableMessage) {} - // Control the batch barrier - rpc BatchBarrier(VoidMessage) returns (VoidMessage) {} } // VariableMessage is serialized paddle variable message. 
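
Note on patch 5/9: from this point the barrier no longer has its own RPC. The client sends an ordinary SendVariable request whose varname is the reserved BATCH_BARRIER_MESSAGE ("BATCH_BARRIER@RECV", defined in sendrecvop_utils.h just below), so the server needs no extra completion queue or handler thread, and recv_op can tell control messages from gradients by name alone. The following sketch is illustrative only (not Paddle code; the enum and function names are invented) and simply shows that name-based dispatch:

// dispatch_sketch.cc -- illustrative only, not Paddle code.
#include <cassert>
#include <string>

const char* const kTerminateMsg = "TERMINATE@RECV";        // LISTEN_TERMINATE_MESSAGE
const char* const kBatchBarrierMsg = "BATCH_BARRIER@RECV";  // BATCH_BARRIER_MESSAGE

enum class MsgKind { kGradient, kBatchBarrier, kTerminate };

// Classify an incoming SendVariable request purely by its variable name,
// which is how the server-side loop distinguishes messages after this patch.
MsgKind Classify(const std::string& varname) {
  if (varname == kTerminateMsg) return MsgKind::kTerminate;
  if (varname == kBatchBarrierMsg) return MsgKind::kBatchBarrier;
  return MsgKind::kGradient;  // everything else is treated as a gradient variable
}

int main() {
  assert(Classify("w@GRAD.trainer_0") == MsgKind::kGradient);
  assert(Classify(kBatchBarrierMsg) == MsgKind::kBatchBarrier);
  assert(Classify(kTerminateMsg) == MsgKind::kTerminate);
  return 0;
}

Reusing SendVariable keeps the wire protocol unchanged (the proto hunk above reverts to the original two RPCs), at the cost of reserving two magic variable names.
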
diff --git a/paddle/operators/detail/sendrecvop_utils.h b/paddle/operators/detail/sendrecvop_utils.h index bc6581afab93c6..8e66f7299c7b4d 100644 --- a/paddle/operators/detail/sendrecvop_utils.h +++ b/paddle/operators/detail/sendrecvop_utils.h @@ -30,6 +30,9 @@ namespace paddle { namespace operators { namespace detail { +#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" +#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV" + void SerializeToMessage(const std::string& name, const framework::Variable* var, const platform::DeviceContext& ctx, sendrecv::VariableMessage* msg); diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index 5de002baae224c..e1f308f8dbf64f 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -29,8 +29,6 @@ limitations under the License. */ #include "paddle/operators/detail/simple_block_queue.h" #include "paddle/string/printf.h" -#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" - namespace paddle { namespace operators { @@ -106,18 +104,22 @@ class RecvOp : public framework::OperatorBase { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. rpc_service_->SetCond(0); - rpc_service_->SetBatchCond(fan_in); size_t barrier_size = 0; - while (!rpc_service_->BatchCondEqualTo(0) || - !rpc_service_->IsRecvQueueEmpty()) { + int batch_barrier = 0; + while (batch_barrier != fan_in || !rpc_service_->IsRecvQueueEmpty()) { const detail::MessageWithName &v = rpc_service_->Get(); - barrier_size++; auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { LOG(INFO) << "received terminate message and exit"; exit_flag = true; break; } + if (grad_var_name == BATCH_BARRIER_MESSAGE) { + VLOG(3) << "recv batch barrier message"; + batch_barrier++; + continue; + } + barrier_size++; auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name); std::string param_var_name; if (it != grad_list.end()) { @@ -127,6 +129,7 @@ class RecvOp : public framework::OperatorBase { } VLOG(3) << "received grad: " << grad_var_name << " updating param: " << param_var_name; + if (fan_in > 1) { grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); } @@ -149,7 +152,6 @@ class RecvOp : public framework::OperatorBase { } catch (std::exception &e) { LOG(ERROR) << "run sub program error " << e.what(); } - rpc_service_->SetBatchCond(fan_in); rpc_service_->SetCond(1); rpc_service_->WaitClientGet(barrier_size); grads_counter_.clear(); From f9174038fefd6b54d0561946ee7e779ac83d055d Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 29 Jan 2018 10:42:47 +0800 Subject: [PATCH 6/9] fix comment --- paddle/operators/detail/grpc_client.cc | 3 ++- paddle/operators/detail/grpc_client.h | 3 ++- paddle/operators/detail/grpc_server.cc | 1 - paddle/operators/detail/grpc_server.h | 1 - 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/paddle/operators/detail/grpc_client.cc b/paddle/operators/detail/grpc_client.cc index 5a78789800c2d0..e44639d8de88f2 100644 --- a/paddle/operators/detail/grpc_client.cc +++ b/paddle/operators/detail/grpc_client.cc @@ -45,6 +45,7 @@ bool RPCClient::AsyncSendVariable(const std::string& ep, SendProcessor* s = new SendProcessor(ch); s->Prepare(var_h, time_out); s->response_call_back_ = NULL; + auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_); rpc->Finish(&s->reply_, &s->status_, (void*)s); }); @@ -96,7 +97,7 @@ bool RPCClient::AsyncGetVariable(const std::string& ep, return true; } -bool 
RPCClient::AsyncBatchBarrier(const std::string& ep, int64_t time_out) { +bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) { const auto ch = GetChannel(ep); BatchBarrierProcessor* s = new BatchBarrierProcessor(ch); diff --git a/paddle/operators/detail/grpc_client.h b/paddle/operators/detail/grpc_client.h index 769529aab1e839..f9499f6dc70c54 100644 --- a/paddle/operators/detail/grpc_client.h +++ b/paddle/operators/detail/grpc_client.h @@ -151,7 +151,8 @@ class RPCClient { const std::string& var_name, int64_t time_out = 600 * 1000); - bool AsyncBatchBarrier(const std::string& ep, int64_t time_out = 600 * 1000); + bool AsyncSendBatchBarrier(const std::string& ep, + int64_t time_out = 600 * 1000); bool Wait(); diff --git a/paddle/operators/detail/grpc_server.cc b/paddle/operators/detail/grpc_server.cc index d55c89460ebbcd..4f94e1315fbd28 100644 --- a/paddle/operators/detail/grpc_server.cc +++ b/paddle/operators/detail/grpc_server.cc @@ -132,7 +132,6 @@ void AsyncGRPCServer::RunSyncUpdate() { cq_send_ = builder.AddCompletionQueue(); cq_get_ = builder.AddCompletionQueue(); - cq_batch_barrier_ = builder.AddCompletionQueue(); server_ = builder.BuildAndStart(); LOG(INFO) << "Server listening on " << address_ << std::endl; diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 0d8bce2a2954ba..4f7db9868b51a1 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -70,7 +70,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { volatile bool is_shut_down_ = false; std::unique_ptr cq_send_; std::unique_ptr cq_get_; - std::unique_ptr cq_batch_barrier_; sendrecv::SendRecvService::AsyncService service_; std::unique_ptr server_; From 782d0480a49d99de23e4934bc74539b0a55d4636 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 29 Jan 2018 14:35:59 +0800 Subject: [PATCH 7/9] fix method --- paddle/operators/send_op.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 4d13c6dcd99b78..7f1620a49eda54 100644 --- a/paddle/operators/send_op.cc +++ b/paddle/operators/send_op.cc @@ -49,7 +49,7 @@ class SendOp : public framework::OperatorBase { std::set epset(epmap.begin(), epmap.end()); for (auto& ep : epset) { VLOG(3) << "batch barrier, ep: " << ep; - client_.AsyncBatchBarrier(ep); + client_.AsyncSendBatchBarrier(ep); } PADDLE_ENFORCE(client_.Wait()); From 840bd1f9fc6fe0a5b2936b11d80d0e088923b5a5 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 29 Jan 2018 15:24:58 +0800 Subject: [PATCH 8/9] fix by comment --- paddle/operators/detail/grpc_server.h | 2 - paddle/operators/detail/simple_block_queue.h | 5 -- paddle/operators/recv_op.cc | 52 ++++++++++---------- 3 files changed, 27 insertions(+), 32 deletions(-) diff --git a/paddle/operators/detail/grpc_server.h b/paddle/operators/detail/grpc_server.h index 4f7db9868b51a1..3f8b9d93176148 100644 --- a/paddle/operators/detail/grpc_server.h +++ b/paddle/operators/detail/grpc_server.h @@ -54,8 +54,6 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service { void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); } - bool IsRecvQueueEmpty() { return this->var_recv_queue_.IsEmpty(); } - void ShutDown(); protected: diff --git a/paddle/operators/detail/simple_block_queue.h b/paddle/operators/detail/simple_block_queue.h index 6bb9ccdf74db6d..c7f5ff4b5f494c 100644 --- a/paddle/operators/detail/simple_block_queue.h +++ 
b/paddle/operators/detail/simple_block_queue.h @@ -45,11 +45,6 @@ class SimpleBlockQueue { this->queue_.pop_back(); return rc; } - - bool IsEmpty() { - std::unique_lock lock(this->mutex_); - return this->queue_.empty(); - } }; } // namespace detail diff --git a/paddle/operators/recv_op.cc b/paddle/operators/recv_op.cc index e1f308f8dbf64f..080b4d869e6513 100644 --- a/paddle/operators/recv_op.cc +++ b/paddle/operators/recv_op.cc @@ -104,43 +104,45 @@ class RecvOp : public framework::OperatorBase { // Get from multiple trainers, we don't care about the order in which // the gradients arrives, just add suffix 0~n and merge the gradient. rpc_service_->SetCond(0); - size_t barrier_size = 0; + size_t recv_var_cnt = 0; int batch_barrier = 0; - while (batch_barrier != fan_in || !rpc_service_->IsRecvQueueEmpty()) { + while (batch_barrier != fan_in) { const detail::MessageWithName &v = rpc_service_->Get(); auto grad_var_name = v.first; if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { LOG(INFO) << "received terminate message and exit"; exit_flag = true; break; - } - if (grad_var_name == BATCH_BARRIER_MESSAGE) { + } else if (grad_var_name == BATCH_BARRIER_MESSAGE) { VLOG(3) << "recv batch barrier message"; batch_barrier++; continue; - } - barrier_size++; - auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name); - std::string param_var_name; - if (it != grad_list.end()) { - param_var_name = param_list[it - grad_list.begin()]; } else { - LOG(ERROR) << "grad has no paired param:" << grad_var_name; - } - VLOG(3) << "received grad: " << grad_var_name - << " updating param: " << param_var_name; - - if (fan_in > 1) { - grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); - } - auto *var = recv_scope.FindVar(grad_var_name); - if (var == nullptr) { - LOG(ERROR) << "Can not find server side var: " << grad_var_name; - PADDLE_THROW("Can not find server side var"); + // receive a variable + recv_var_cnt++; + auto it = + std::find(grad_list.begin(), grad_list.end(), grad_var_name); + std::string param_var_name; + if (it != grad_list.end()) { + param_var_name = param_list[it - grad_list.begin()]; + } else { + LOG(ERROR) << "grad has no paired param:" << grad_var_name; + } + VLOG(3) << "received grad: " << grad_var_name + << " updating param: " << param_var_name; + + if (fan_in > 1) { + grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); + } + auto *var = recv_scope.FindVar(grad_var_name); + if (var == nullptr) { + LOG(ERROR) << "Can not find server side var: " << grad_var_name; + PADDLE_THROW("Can not find server side var"); + } + detail::DeserializeFromMessage(v.second, dev_ctx, var); } - detail::DeserializeFromMessage(v.second, dev_ctx, var); } - VLOG(3) << "recv " << barrier_size << " parmeters for one barrier."; + VLOG(3) << "recv " << recv_var_cnt << " parmeters for one barrier."; // TODO(Yancey1989): merge SelectedRows variables here if (exit_flag) { break; @@ -153,7 +155,7 @@ class RecvOp : public framework::OperatorBase { LOG(ERROR) << "run sub program error " << e.what(); } rpc_service_->SetCond(1); - rpc_service_->WaitClientGet(barrier_size); + rpc_service_->WaitClientGet(recv_var_cnt); grads_counter_.clear(); } // while(true) } From e4c0de071b15dea112a6d148749750c63104aec4 Mon Sep 17 00:00:00 2001 From: Yancey1989 Date: Mon, 29 Jan 2018 15:49:27 +0800 Subject: [PATCH 9/9] fix by comment --- paddle/operators/send_op.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/paddle/operators/send_op.cc b/paddle/operators/send_op.cc index 
7f1620a49eda54..bb719dc2a8a577 100644
--- a/paddle/operators/send_op.cc
+++ b/paddle/operators/send_op.cc
@@ -37,6 +37,8 @@ class SendOp : public framework::OperatorBase {
     auto ins = Inputs("X");
     auto outs = Outputs("Out");
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
+    std::vector<std::string> endpoints =
+        Attr<std::vector<std::string>>("endpoints");
 
     platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
     auto& ctx = *pool.Get(place);
@@ -46,8 +48,7 @@ class SendOp : public framework::OperatorBase {
     }
     PADDLE_ENFORCE(client_.Wait());
 
-    std::set<std::string> epset(epmap.begin(), epmap.end());
-    for (auto& ep : epset) {
+    for (auto& ep : endpoints) {
       VLOG(3) << "batch barrier, ep: " << ep;
       client_.AsyncSendBatchBarrier(ep);
     }
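
End state of the series: SendOp sends every gradient with AsyncSendVariable, waits, then sends one AsyncSendBatchBarrier per endpoint in the "endpoints" attribute and waits again before fetching the updated parameters; RecvOp keeps draining its receive queue until it has counted fan_in BATCH_BARRIER_MESSAGE entries (one per trainer), runs the optimize block once, then signals the get side with SetCond(1) and WaitClientGet(recv_var_cnt). The sketch below models that counting barrier with plain threads and a blocking queue standing in for the gRPC plumbing; it is illustrative only and none of the names are Paddle APIs:

// barrier_sketch.cc -- illustrative only, not Paddle code.
// Each "trainer" pushes its gradients and then one barrier sentinel; the
// "pserver" loop keeps receiving until it has seen fan_in sentinels, then
// applies the optimize step once.
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

const char* const kBatchBarrier = "BATCH_BARRIER@RECV";

// Simplified stand-in for detail::SimpleBlockQueue<MessageWithName>.
class BlockingQueue {
 public:
  void Push(const std::string& msg) {
    {
      std::lock_guard<std::mutex> lock(mu_);
      queue_.push_back(msg);
    }
    cv_.notify_one();
  }
  std::string Pop() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    std::string msg = queue_.front();
    queue_.pop_front();
    return msg;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<std::string> queue_;
};

int main() {
  const int fan_in = 3;  // number of trainers
  BlockingQueue recv_queue;

  // Each "trainer" mirrors SendOp: send all gradients, then one barrier.
  std::vector<std::thread> trainers;
  for (int t = 0; t < fan_in; ++t) {
    trainers.emplace_back([&recv_queue, t] {
      recv_queue.Push("w@GRAD.trainer_" + std::to_string(t));
      recv_queue.Push("b@GRAD.trainer_" + std::to_string(t));
      recv_queue.Push(kBatchBarrier);
    });
  }

  // The "pserver" mirrors RecvOp: count barrier messages, not gradients.
  int batch_barrier = 0;
  std::size_t recv_var_cnt = 0;
  while (batch_barrier != fan_in) {
    const std::string name = recv_queue.Pop();
    if (name == kBatchBarrier) {
      ++batch_barrier;
    } else {
      ++recv_var_cnt;  // a real server would merge/apply the gradient here
    }
  }
  std::cout << "received " << recv_var_cnt << " gradients for one batch\n";

  for (auto& t : trainers) {
    t.join();
  }
  return 0;
}

Counting barrier sentinels instead of gradient messages is what lets a trainer send any number of (possibly split) variables per batch, as long as it sends exactly one barrier at the end.
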