Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batch barrier in send/recv op #7847

Merged
merged 11 commits into from
Jan 29, 2018
15 changes: 15 additions & 0 deletions paddle/operators/detail/grpc_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,21 @@ bool RPCClient::AsyncGetVariable(const std::string& ep,
return true;
}

bool RPCClient::AsyncSendBatchBarrier(const std::string& ep, int64_t time_out) {
const auto ch = GetChannel(ep);

BatchBarrierProcessor* s = new BatchBarrierProcessor(ch);
s->Prepare(time_out);

sendrecv::VariableMessage req;
req.set_varname(BATCH_BARRIER_MESSAGE);
auto rpc = s->stub_->AsyncSendVariable(s->context_.get(), req, &cq_);
rpc->Finish(&s->reply_, &s->status_, (void*)s);
req_count_++;

return true;
}

bool RPCClient::Wait() {
if (req_count_ <= 0) {
return true;
Expand Down
24 changes: 24 additions & 0 deletions paddle/operators/detail/grpc_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ class ClientBase {
context_->set_deadline(deadline);
}

virtual void Prepare(int64_t time_out) {
context_.reset(new grpc::ClientContext());

std::chrono::system_clock::time_point deadline =
std::chrono::system_clock::now() + std::chrono::milliseconds(time_out);

context_->set_deadline(deadline);
}

virtual void Process() = 0;

std::unique_ptr<sendrecv::SendRecvService::Stub> stub_;
Expand Down Expand Up @@ -117,6 +126,17 @@ class GetProcessor : public ClientBase {
RequestGetCallBack response_call_back_ = ProcGetResponse;
};

class BatchBarrierProcessor : public ClientBase {
public:
explicit BatchBarrierProcessor(std::shared_ptr<grpc::Channel> ch)
: ClientBase(ch) {}

virtual ~BatchBarrierProcessor() {}

virtual void Process() {}
sendrecv::VoidMessage reply_;
};

class RPCClient {
public:
bool AsyncSendVariable(const std::string& ep,
Expand All @@ -130,6 +150,10 @@ class RPCClient {
const framework::Scope& scope,
const std::string& var_name,
int64_t time_out = 600 * 1000);

bool AsyncSendBatchBarrier(const std::string& ep,
int64_t time_out = 600 * 1000);

bool Wait();

private:
Expand Down
13 changes: 7 additions & 6 deletions paddle/operators/detail/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ void AsyncGRPCServer::RunSyncUpdate() {

cq_send_ = builder.AddCompletionQueue();
cq_get_ = builder.AddCompletionQueue();

server_ = builder.BuildAndStart();
LOG(INFO) << "Server listening on " << address_ << std::endl;

Expand All @@ -141,11 +142,11 @@ void AsyncGRPCServer::RunSyncUpdate() {
std::bind(&AsyncGRPCServer::TryToRegisterNewGetOne, this);

t_send_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, false,
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_send_.get(), "cq_send", send_register)));

t_get_.reset(
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this, true,
new std::thread(std::bind(&AsyncGRPCServer::HandleRequest, this,
cq_get_.get(), "cq_get", get_register)));

// wait server
Expand Down Expand Up @@ -174,7 +175,7 @@ void AsyncGRPCServer::TryToRegisterNewSendOne() {
}
RequestSend* send =
new RequestSend(&service_, cq_send_.get(), &var_recv_queue_);
VLOG(4) << "create RequestSend status:" << send->Status();
VLOG(4) << "Create RequestSend status:" << send->Status();
}

void AsyncGRPCServer::TryToRegisterNewGetOne() {
Expand All @@ -184,11 +185,11 @@ void AsyncGRPCServer::TryToRegisterNewGetOne() {
}
RequestGet* get = new RequestGet(&service_, cq_get_.get(), scope_, dev_ctx_,
&var_get_queue_);
VLOG(4) << "create Requestget status:" << get->Status();
VLOG(4) << "Create RequestGet status:" << get->Status();
}

// FIXME(typhoonzero): remove wait argument and change cq_name to enum.
void AsyncGRPCServer::HandleRequest(bool wait, grpc::ServerCompletionQueue* cq,
// FIXME(typhoonzero): change cq_name to enum.
void AsyncGRPCServer::HandleRequest(grpc::ServerCompletionQueue* cq,
std::string cq_name,
std::function<void()> TryToRegisterNewOne) {
TryToRegisterNewOne();
Expand Down
5 changes: 3 additions & 2 deletions paddle/operators/detail/grpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,12 @@ class AsyncGRPCServer final : public sendrecv::SendRecvService::Service {

void Push(const MessageWithName &msg) { this->var_recv_queue_.Push(msg); }

bool IsRecvQueueEmpty() { return this->var_recv_queue_.IsEmpty(); }

void ShutDown();

protected:
void HandleRequest(bool wait, grpc::ServerCompletionQueue *cq,
std::string cq_name,
void HandleRequest(grpc::ServerCompletionQueue *cq, std::string cq_name,
std::function<void()> TryToRegisterNewOne);
void TryToRegisterNewSendOne();
void TryToRegisterNewGetOne();
Expand Down
3 changes: 3 additions & 0 deletions paddle/operators/detail/sendrecvop_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ namespace paddle {
namespace operators {
namespace detail {

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"
#define BATCH_BARRIER_MESSAGE "BATCH_BARRIER@RECV"

void SerializeToMessage(const std::string& name, const framework::Variable* var,
const platform::DeviceContext& ctx,
sendrecv::VariableMessage* msg);
Expand Down
5 changes: 5 additions & 0 deletions paddle/operators/detail/simple_block_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,11 @@ class SimpleBlockQueue {
this->queue_.pop_back();
return rc;
}

bool IsEmpty() {
std::unique_lock<std::mutex> lock(this->mutex_);
return this->queue_.empty();
}
};

} // namespace detail
Expand Down
17 changes: 12 additions & 5 deletions paddle/operators/recv_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ limitations under the License. */
#include "paddle/operators/detail/simple_block_queue.h"
#include "paddle/string/printf.h"

#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV"

namespace paddle {
namespace operators {

Expand Down Expand Up @@ -95,27 +93,33 @@ class RecvOp : public framework::OperatorBase {
auto param_list = Attr<std::vector<std::string>>("ParamList");
auto grad_list = Attr<std::vector<std::string>>("GradList");
auto fan_in = Attr<int>("Fanin");
size_t param_count = param_list.size();

auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock);
auto *program = block->Program();
framework::Executor executor(dev_place);

// TODO(typhoonzero): change this to a while_op for every cluster-batch.
bool exit_flag = false;
size_t barrier_size = param_count * fan_in;
while (!exit_flag) {
// Get from multiple trainers, we don't care about the order in which
// the gradients arrives, just add suffix 0~n and merge the gradient.
rpc_service_->SetCond(0);
for (size_t i = 0; i < barrier_size; ++i) {
size_t barrier_size = 0;
int batch_barrier = 0;
while (batch_barrier != fan_in || !rpc_service_->IsRecvQueueEmpty()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure why !rpc_service_->IsRecvQueueEmpty() is needed. rpc_service_->Get() is a blocking call which will wait until a new message arrives.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, it's not used :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deleted !rpc_service_->IsRecvQueueEmpty(), because send op would send barrier signal by least, if RecvOp received barrier signal, it should be the least message from one trainer.

const detail::MessageWithName &v = rpc_service_->Get();
auto grad_var_name = v.first;
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) {
LOG(INFO) << "received terminate message and exit";
exit_flag = true;
break;
}
if (grad_var_name == BATCH_BARRIER_MESSAGE) {
VLOG(3) << "recv batch barrier message";
batch_barrier++;
continue;
}
barrier_size++;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

barrier_size is used only for printing log, can remove or rename it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name);
std::string param_var_name;
if (it != grad_list.end()) {
Expand All @@ -125,6 +129,7 @@ class RecvOp : public framework::OperatorBase {
}
VLOG(3) << "received grad: " << grad_var_name
<< " updating param: " << param_var_name;

if (fan_in > 1) {
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name);
}
Expand All @@ -135,6 +140,8 @@ class RecvOp : public framework::OperatorBase {
}
detail::DeserializeFromMessage(v.second, dev_ctx, var);
}
VLOG(3) << "recv " << barrier_size << " parmeters for one barrier.";
// TODO(Yancey1989): merge SelectedRows variables here
if (exit_flag) {
break;
}
Expand Down
11 changes: 9 additions & 2 deletions paddle/operators/send_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,20 @@ class SendOp : public framework::OperatorBase {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto& ctx = *pool.Get(place);
for (size_t i = 0; i < ins.size(); i++) {
VLOG(3) << "sending " << ins[i];
VLOG(3) << "sending " << ins[i] << " to " << epmap[i];
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]);
}
PADDLE_ENFORCE(client_.Wait());

std::set<std::string> epset(epmap.begin(), epmap.end());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use endpoints attribute is the same thing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

for (auto& ep : epset) {
VLOG(3) << "batch barrier, ep: " << ep;
client_.AsyncSendBatchBarrier(ep);
}
PADDLE_ENFORCE(client_.Wait());

for (size_t i = 0; i < outs.size(); i++) {
VLOG(3) << "getting " << outs[i];
VLOG(3) << "getting " << outs[i] << " from " << epmap[i];
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]);
}

Expand Down