-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Batch barrier in send/recv op #7847
Changes from 9 commits
6e1fa48
b346c8c
745ec2b
586a06d
0eb9f80
d8551c0
f917403
6b51936
782d048
840bd1f
e4c0de0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,8 +29,6 @@ limitations under the License. */ | |
#include "paddle/operators/detail/simple_block_queue.h" | ||
#include "paddle/string/printf.h" | ||
|
||
#define LISTEN_TERMINATE_MESSAGE "TERMINATE@RECV" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
|
@@ -95,27 +93,33 @@ class RecvOp : public framework::OperatorBase { | |
auto param_list = Attr<std::vector<std::string>>("ParamList"); | ||
auto grad_list = Attr<std::vector<std::string>>("GradList"); | ||
auto fan_in = Attr<int>("Fanin"); | ||
size_t param_count = param_list.size(); | ||
|
||
auto *block = Attr<framework::BlockDesc *>(kOptimizeBlock); | ||
auto *program = block->Program(); | ||
framework::Executor executor(dev_place); | ||
|
||
// TODO(typhoonzero): change this to a while_op for every cluster-batch. | ||
bool exit_flag = false; | ||
size_t barrier_size = param_count * fan_in; | ||
while (!exit_flag) { | ||
// Get from multiple trainers, we don't care about the order in which | ||
// the gradients arrives, just add suffix 0~n and merge the gradient. | ||
rpc_service_->SetCond(0); | ||
for (size_t i = 0; i < barrier_size; ++i) { | ||
size_t barrier_size = 0; | ||
int batch_barrier = 0; | ||
while (batch_barrier != fan_in || !rpc_service_->IsRecvQueueEmpty()) { | ||
const detail::MessageWithName &v = rpc_service_->Get(); | ||
auto grad_var_name = v.first; | ||
if (grad_var_name == LISTEN_TERMINATE_MESSAGE) { | ||
LOG(INFO) << "received terminate message and exit"; | ||
exit_flag = true; | ||
break; | ||
} | ||
if (grad_var_name == BATCH_BARRIER_MESSAGE) { | ||
VLOG(3) << "recv batch barrier message"; | ||
batch_barrier++; | ||
continue; | ||
} | ||
barrier_size++; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
auto it = std::find(grad_list.begin(), grad_list.end(), grad_var_name); | ||
std::string param_var_name; | ||
if (it != grad_list.end()) { | ||
|
@@ -125,6 +129,7 @@ class RecvOp : public framework::OperatorBase { | |
} | ||
VLOG(3) << "received grad: " << grad_var_name | ||
<< " updating param: " << param_var_name; | ||
|
||
if (fan_in > 1) { | ||
grad_var_name = this->GetGradVarNameForTrainer(grad_var_name); | ||
} | ||
|
@@ -135,6 +140,8 @@ class RecvOp : public framework::OperatorBase { | |
} | ||
detail::DeserializeFromMessage(v.second, dev_ctx, var); | ||
} | ||
VLOG(3) << "recv " << barrier_size << " parmeters for one barrier."; | ||
// TODO(Yancey1989): merge SelectedRows variables here | ||
if (exit_flag) { | ||
break; | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,13 +41,20 @@ class SendOp : public framework::OperatorBase { | |
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); | ||
auto& ctx = *pool.Get(place); | ||
for (size_t i = 0; i < ins.size(); i++) { | ||
VLOG(3) << "sending " << ins[i]; | ||
VLOG(3) << "sending " << ins[i] << " to " << epmap[i]; | ||
client_.AsyncSendVariable(epmap[i], ctx, scope, ins[i]); | ||
} | ||
PADDLE_ENFORCE(client_.Wait()); | ||
|
||
std::set<std::string> epset(epmap.begin(), epmap.end()); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done. |
||
for (auto& ep : epset) { | ||
VLOG(3) << "batch barrier, ep: " << ep; | ||
client_.AsyncSendBatchBarrier(ep); | ||
} | ||
PADDLE_ENFORCE(client_.Wait()); | ||
|
||
for (size_t i = 0; i < outs.size(); i++) { | ||
VLOG(3) << "getting " << outs[i]; | ||
VLOG(3) << "getting " << outs[i] << " from " << epmap[i]; | ||
client_.AsyncGetVariable(epmap[i], ctx, scope, outs[i]); | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure why
!rpc_service_->IsRecvQueueEmpty()
is needed. rpc_service_->Get()
is a blocking call which will wait until a new message arrives.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, it's not used :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I deleted
!rpc_service_->IsRecvQueueEmpty()
because the send op sends the barrier signal last: if RecvOp receives a barrier signal, it must be the last message from that trainer.