[fleet_executor] Add amplifier interceptor and 1F1B scheduler test #37755

Merged
3 changes: 2 additions & 1 deletion paddle/fluid/distributed/fleet_executor/CMakeLists.txt
@@ -11,14 +11,15 @@ else()
endif()

cc_library(fleet_executor SRCS fleet_executor.cc carrier.cc task_node.cc runtime_graph.cc
interceptor.cc compute_interceptor.cc interceptor_message_service.cc message_bus.cc
interceptor.cc compute_interceptor.cc amplifier_interceptor.cc interceptor_message_service.cc message_bus.cc
DEPS proto_desc fleet_executor_desc_proto interceptor_message_proto collective_helper
${BRPC_DEPS})

if(WITH_DISTRIBUTE)
set(DISTRIBUTE_COMPILE_FLAGS "-Wno-non-virtual-dtor -Wno-error=non-virtual-dtor -Wno-error=delete-non-virtual-dtor")
set_source_files_properties(interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(compute_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.h PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(message_bus.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
set_source_files_properties(fleet_executor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
82 changes: 82 additions & 0 deletions paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc
@@ -0,0 +1,82 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h"

#include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace distributed {

AmplifierInterceptor::AmplifierInterceptor(int64_t interceptor_id,
TaskNode* node)
: ComputeInterceptor(interceptor_id, node) {
run_per_steps_ = node->run_per_steps();
run_at_offset_ = node->run_at_offset();
reply_up_per_steps_ = node->reply_up_per_steps();
send_down_per_steps_ = node->send_down_per_steps();

PADDLE_ENFORCE_GE(
run_per_steps_, 1,
platform::errors::InvalidArgument(
"run_per_steps must >= 1, but now is %ld", run_per_steps_));
PADDLE_ENFORCE_GE(
run_at_offset_, 0,
platform::errors::InvalidArgument(
"run_at_offset must >= 0, but now is %ld", run_at_offset_));
PADDLE_ENFORCE_LT(run_at_offset_, run_per_steps_,
platform::errors::InvalidArgument(
"run_at_offset must < run_per_steps, must now "
"run_at_offset=%ld run_per_steps=%ld",
run_at_offset_, run_per_steps_));
PADDLE_ENFORCE_GE(
reply_up_per_steps_, 1,
platform::errors::InvalidArgument(
"reply_up_per_steps must >= 1, but now is %ld", reply_up_per_steps_));
PADDLE_ENFORCE_GE(send_down_per_steps_, 1,
platform::errors::InvalidArgument(
"send_down_per_steps must >= 1, but now is %ld",
send_down_per_steps_));
}

void AmplifierInterceptor::RunOps() {
// run_per_steps_, run_at_offset_
// 4, 0 --> run at step 0, 4, 8, 12
// 4, 3 --> run at step 3, 7, 11, 15
if ((step_ % run_per_steps_) == run_at_offset_) {
ComputeInterceptor::RunOps();
}
}

void AmplifierInterceptor::SendDataReadyToDownStream() {
// runs multiple times but sends data-ready to downstream only once, i.e.
// consumes input multiple times, produces output once
if (step_ % send_down_per_steps_ == 0) {
ComputeInterceptor::SendDataReadyToDownStream();
}
}

void AmplifierInterceptor::ReplyCompletedToUpStream() {
// runs multiple times but replies completed to upstream only once, i.e.
// consumes input once, produces output multiple times
if (step_ % reply_up_per_steps_ == 0) {
ComputeInterceptor::ReplyCompletedToUpStream();
}
}

REGISTER_INTERCEPTOR(Amplifier, AmplifierInterceptor);

} // namespace distributed
} // namespace paddle
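For intuition, the following is a standalone sketch (not part of this PR) that replays the three modulo gates above for a hypothetical configuration (run_per_steps = 4, run_at_offset = 3, send and reply every 4 steps). It follows the same ordering as ComputeInterceptor::Run(), which increments step_ between RunOps() and the send/reply calls, so with an offset of run_per_steps - 1 the downstream send fires on the same iteration as the run.

// Standalone sketch: replays the amplifier gating logic with hypothetical values.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t run_per_steps = 4, run_at_offset = 3;          // hypothetical values
  const int64_t send_down_per_steps = 4, reply_up_per_steps = 4;
  int64_t step = 0;
  for (int iter = 0; iter < 12; ++iter) {
    const bool runs = (step % run_per_steps) == run_at_offset;  // RunOps() gate
    ++step;                                                     // Run() increments step_ before send/reply
    const bool sends = (step % send_down_per_steps) == 0;       // SendDataReadyToDownStream gate
    const bool replies = (step % reply_up_per_steps) == 0;      // ReplyCompletedToUpStream gate
    std::cout << "iter " << iter << ": run=" << runs << " send_down=" << sends
              << " reply_up=" << replies << "\n";
  }
  return 0;
}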
43 changes: 43 additions & 0 deletions paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h
@@ -0,0 +1,43 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <utility>

#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h"

namespace paddle {
namespace distributed {

class AmplifierInterceptor : public ComputeInterceptor {
public:
AmplifierInterceptor(int64_t interceptor_id, TaskNode* node);

private:
void RunOps() override;
void SendDataReadyToDownStream() override;
void ReplyCompletedToUpStream() override;

int64_t run_per_steps_{1};
int64_t run_at_offset_{0};

// one input produces multiple outputs
int64_t reply_up_per_steps_{1};
// one output needs multiple inputs
int64_t send_down_per_steps_{1};
};

} // namespace distributed
} // namespace paddle
1 change: 1 addition & 0 deletions paddle/fluid/distributed/fleet_executor/carrier.cc
@@ -24,6 +24,7 @@ namespace paddle {
namespace distributed {

USE_INTERCEPTOR(Compute);
USE_INTERCEPTOR(Amplifier);

void Carrier::Init(std::shared_ptr<RuntimeGraph> runtime_graph,
framework::Scope* root_scope,
15 changes: 9 additions & 6 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.cc
@@ -160,23 +160,26 @@ void ComputeInterceptor::ReplyCompletedToUpStream() {
}
}

void ComputeInterceptor::RunOps() {
VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops.";
for (auto op : node_->ops()) {
op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_);
}
}

void ComputeInterceptor::Run() {
while (IsInputReady() && CanWriteOutput() && !ShouldReset()) {
VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running";

// step_ %= node_->max_run_times();
for (auto op : node_->ops()) {
auto* scope = microbatch_scopes_[step_ % node_->max_run_times()];
op->Run(*scope, place_);
}
RunOps();
++step_;

// send to downstream and increase buff used
SendDataReadyToDownStream();
// reply to upstream and decrease ready data
ReplyCompletedToUpStream();
// Try to stop Carrier
if (step_ % node_->max_run_times() == 0 && is_last_) {
if (is_last_ && (step_ % node_->max_run_times() == 0)) {
StopCarrier();
}
}
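The change above extracts the inline op loop into a virtual RunOps() hook so that a subclass can skip work on selected steps while Run() keeps the scheduling loop. A minimal sketch of that shape, using placeholder class names rather than the actual Paddle types:

// Minimal sketch of the hook pattern introduced by this refactor (placeholder names).
#include <cstdint>
#include <iostream>

// Base class drives the loop; RunOps() is the overridable step, as in ComputeInterceptor.
class ComputeLike {
 public:
  virtual ~ComputeLike() = default;
  void Run(int64_t steps) {
    for (int64_t i = 0; i < steps; ++i) {
      RunOps();
      ++step_;
    }
  }

 protected:
  virtual void RunOps() { std::cout << "run ops at step " << step_ << "\n"; }
  int64_t step_{0};
};

// Subclass gates the hook, analogous to AmplifierInterceptor::RunOps().
class AmplifierLike : public ComputeLike {
 protected:
  void RunOps() override {
    if (step_ % 4 == 0) ComputeLike::RunOps();  // run only every 4th step
  }
};

int main() {
  AmplifierLike a;
  a.Run(8);  // prints only for steps 0 and 4
  return 0;
}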
13 changes: 8 additions & 5 deletions paddle/fluid/distributed/fleet_executor/compute_interceptor.h
@@ -25,6 +25,14 @@ class ComputeInterceptor : public Interceptor {
public:
ComputeInterceptor(int64_t interceptor_id, TaskNode* node);

protected:
virtual void RunOps();
virtual void SendDataReadyToDownStream();
virtual void ReplyCompletedToUpStream();

int64_t step_{0};

private:
void PrepareDeps();

void IncreaseReady(int64_t up_id);
@@ -33,19 +41,14 @@ class ComputeInterceptor : public Interceptor {
bool CanWriteOutput();
bool ShouldReset();

void SendDataReadyToDownStream();
void ReplyCompletedToUpStream();

void Run();
void Compute(const InterceptorMessage& msg);

void ReceivedStop(int64_t up_id);
void TryStop();

private:
bool is_source_{false};
bool is_last_{false};
int64_t step_{0};

// upstream_id-->(max_ready_size, ready_size)
std::map<int64_t, std::pair<int64_t, int64_t>> in_readys_{};
17 changes: 17 additions & 0 deletions paddle/fluid/distributed/fleet_executor/task_node.h
@@ -44,12 +44,22 @@
int32_t role() const { return role_; }
int64_t max_run_times() const { return max_run_times_; }
int64_t max_slot_nums() const { return max_slot_nums_; }
int64_t run_per_steps() const { return run_per_steps_; }
int64_t run_at_offset() const { return run_at_offset_; }
int64_t reply_up_per_steps() const { return reply_up_per_steps_; }
int64_t send_down_per_steps() const { return send_down_per_steps_; }
const std::unordered_set<int64_t>& upstream() const { return upstream_; }
const std::unordered_set<int64_t>& downstream() const { return downstream_; }
const std::string& type() const { return type_; }
const paddle::framework::ProgramDesc& program() const { return program_; }
const std::vector<OperatorBase*>& ops() const { return ops_; }

void SetRunPerSteps(int64_t value) { run_per_steps_ = value; }
void SetRunAtOffset(int64_t value) { run_at_offset_ = value; }
void SetReplyUpPerSteps(int64_t value) { reply_up_per_steps_ = value; }
void SetSendDownPerSteps(int64_t value) { send_down_per_steps_ = value; }
void SetType(const std::string& type) { type_ = type; }

bool AddUpstreamTask(int64_t task_id);
bool AddDownstreamTask(int64_t task_id);
std::string DebugString() const;
@@ -76,6 +86,13 @@ class TaskNode final {
int64_t max_run_times_;
int64_t max_slot_nums_;

int64_t run_per_steps_{1};
int64_t run_at_offset_{0};
// one input produces multiple outputs
int64_t reply_up_per_steps_{1};
// one output needs multiple inputs
int64_t send_down_per_steps_{1};

std::string type_;
};

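A hedged usage sketch for the new accessors and setters, assuming the five-argument constructor order noted in the test at the end of this PR (role, rank, task_id, max_run_times, max_slot_nums); the task id, the chosen values, and the idea that SetType() selects the interceptor kind are illustrative assumptions, not taken from this diff:

// Sketch of configuring a TaskNode with the new setters (values are hypothetical).
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

TaskNode* MakeHypotheticalAmplifierNode() {
  // role, rank, task_id, max_run_times, max_slot_nums (argument order per the test below)
  TaskNode* node = new TaskNode(0, 0, 7, 4, 0);
  node->SetRunPerSteps(4);      // run once every 4 steps ...
  node->SetRunAtOffset(3);      // ... at steps 3, 7, 11 (must be < run_per_steps)
  node->SetReplyUpPerSteps(4);  // acknowledge upstream once per 4 runs
  node->SetType("Amplifier");   // assumption: the type string picks the interceptor kind
  return node;
}

}  // namespace distributed
}  // namespace paddle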
6 changes: 6 additions & 0 deletions paddle/fluid/distributed/fleet_executor/test/CMakeLists.txt
@@ -4,6 +4,12 @@ cc_test(interceptor_ping_pong_test SRCS interceptor_ping_pong_test.cc DEPS fleet
set_source_files_properties(compute_interceptor_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_test SRCS compute_interceptor_test.cc DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(interceptor_pipeline_short_path_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pipeline_short_path_test SRCS interceptor_pipeline_short_path_test.cc DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(interceptor_pipeline_long_path_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(interceptor_pipeline_long_path_test SRCS interceptor_pipeline_long_path_test.cc DEPS fleet_executor ${BRPC_DEPS})

set_source_files_properties(compute_interceptor_run_op_test.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS})
cc_test(compute_interceptor_run_op_test SRCS compute_interceptor_run_op_test.cc DEPS fleet_executor ${BRPC_DEPS} op_registry fill_constant_op elementwise_add_op scope device_context)

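Assuming cc_test registers these binaries with CTest in the same way as the existing fleet_executor tests, the two new pipeline tests should be runnable from the build directory with something like ctest -R interceptor_pipeline; the exact invocation is an assumption, not part of this PR.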
@@ -0,0 +1,94 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <iostream>
#include <unordered_map>

#include "gtest/gtest.h"

#include "paddle/fluid/distributed/fleet_executor/carrier.h"
#include "paddle/fluid/distributed/fleet_executor/interceptor.h"
#include "paddle/fluid/distributed/fleet_executor/message_bus.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h"

namespace paddle {
namespace distributed {

void LinkNodes(const std::vector<TaskNode*>& nodes) {
size_t size = nodes.size();
if (size <= 1) return;

{ // i = 0
TaskNode* now = nodes[0];
TaskNode* next = nodes[1];
now->AddDownstreamTask(next->task_id());
}
{ // i = size - 1
TaskNode* prev = nodes[size - 2];
TaskNode* now = nodes[size - 1];
now->AddUpstreamTask(prev->task_id());
}

for (size_t i = 1; i < size - 1; ++i) {
TaskNode* prev = nodes[i - 1];
TaskNode* now = nodes[i];
TaskNode* next = nodes[i + 1];

now->AddUpstreamTask(prev->task_id());
now->AddDownstreamTask(next->task_id());
}
}

TEST(AmplifierInterceptor, Amplifier) {
Carrier& carrier = Carrier::Instance();
MessageBus& msg_bus = MessageBus::Instance();
msg_bus.Init({{0, 0}, {1, 0}, {2, 0}, {3, 0}, {4, 0}, {5, 0}},
{{0, "127.0.0.0:0"}}, "127.0.0.0:0");

int64_t micro_steps = 3;

// NOTE: do not delete these nodes here, otherwise the interceptors would use dangling pointers
TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id, max_run_times, max_slot_nums
TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0);
TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0);
TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0);
TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0);
TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0);

// a->b->c->d->e->f
LinkNodes({node_a, node_b, node_c, node_d, node_e, node_f});

// LR->b(1:3)->F->B->e(3:1)->U
node_b->SetReplyUpPerSteps(micro_steps);
node_e->SetSendDownPerSteps(micro_steps);

carrier.SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a));
carrier.SetInterceptor(1, InterceptorFactory::Create("Amplifier", 1, node_b));
carrier.SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c));
carrier.SetInterceptor(3, InterceptorFactory::Create("Compute", 3, node_d));
carrier.SetInterceptor(4, InterceptorFactory::Create("Amplifier", 4, node_e));
carrier.SetInterceptor(5, InterceptorFactory::Create("Compute", 5, node_f));

carrier.SetCreatingFlag(false);

// start
InterceptorMessage msg;
msg.set_message_type(DATA_IS_READY);
msg.set_src_id(-1);
msg.set_dst_id(0);
carrier.EnqueueInterceptorMessage(msg);
}

} // namespace distributed
} // namespace paddle
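Reading the b(1:3)/e(3:1) annotation together with the amplifier semantics (reply upstream once per micro_steps runs, send downstream once per micro_steps runs), the relative run counts per outer step should come out roughly as in the sketch below. This is an interpretation of the test's intent, not output generated by the PR.

// Sketch of the expected relative run counts for the a->b->c->d->e->f pipeline above.
#include <iostream>
#include <string>
#include <utility>
#include <vector>

int main() {
  const int micro_steps = 3;
  // node_a feeds node_b one data-ready; b replies upstream only once per micro_steps
  // runs, so one run of a drives micro_steps runs of b, c, d and e; e sends downstream
  // only once per micro_steps runs, so f runs once.
  const std::vector<std::pair<std::string, int>> runs = {
      {"node_a (LR)", 1},
      {"node_b (Amplifier, reply_up=3)", micro_steps},
      {"node_c (F)", micro_steps},
      {"node_d (B)", micro_steps},
      {"node_e (Amplifier, send_down=3)", micro_steps},
      {"node_f (U)", 1}};
  for (const auto& kv : runs)
    std::cout << kv.first << " runs " << kv.second << " time(s)\n";
  return 0;
}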