-
Notifications
You must be signed in to change notification settings - Fork 5.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add align for WorkQueue * add spinlock * merge develop * merge * Add EventsWaiter * Revert "Add EventsWaiter" This reverts commit e206173. * update EventsWater * fix * split workqueue files * add more tests * fix * bugfix * bugfix * update Co-authored-by: liutiexing <liutiexing@google.com>
- Loading branch information
1 parent
4221cd3
commit 198d11b
Showing
17 changed files
with
380 additions
and
165 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
cc_library(workqueue SRCS workqueue.cc workqueue_utils.cc events_waiter.cc DEPS enforce glog) | ||
cc_test(workqueue_test SRCS workqueue_test.cc DEPS workqueue) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
147 changes: 147 additions & 0 deletions
147
paddle/fluid/framework/new_executor/workqueue/events_waiter.cc
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "paddle/fluid/framework/new_executor/workqueue/events_waiter.h" | ||
#include <glog/logging.h> | ||
#include "paddle/fluid/platform/enforce.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
EventsWaiter::EventsWaiter() | ||
: trigger_event_(nullptr), counter_(0), waiting_(false), cv_(1) {} | ||
|
||
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent( | ||
const std::string& name, EventChecker checker) { | ||
auto counter = counter_.fetch_add(1); | ||
auto id = std::hash<std::string>()(name + std::to_string(counter)); | ||
VLOG(10) << "Register event id:" << id << " name:" << name; | ||
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this)); | ||
EventInfo evt{id, name, TriggerType::LevelTriggered, std::move(checker)}; | ||
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); | ||
events_[id] = std::move(evt); | ||
return notifier; | ||
} | ||
|
||
std::shared_ptr<EventsWaiter::EventNotifier> EventsWaiter::RegisterEvent( | ||
const std::string& name) { | ||
auto counter = counter_.fetch_add(1); | ||
auto id = std::hash<std::string>()(name + std::to_string(counter)); | ||
VLOG(10) << "Register event id:" << id << " name:" << name; | ||
auto notifier = std::shared_ptr<EventNotifier>(new EventNotifier(id, this)); | ||
EventInfo evt{id, name, TriggerType::EdgeTriggered, []() { return false; }}; | ||
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); | ||
events_[id] = std::move(evt); | ||
return notifier; | ||
} | ||
|
||
void EventsWaiter::UnregisterEvent(const EventId& id) { | ||
VLOG(10) << "Unregister event id:" << id; | ||
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); | ||
events_.erase(id); | ||
} | ||
|
||
std::string EventsWaiter::WaitEvent() { | ||
// only one user can wait at any time | ||
bool waiting = false; | ||
if (!waiting_.compare_exchange_strong(waiting, true, | ||
std::memory_order_seq_cst, | ||
std::memory_order_relaxed)) { | ||
PADDLE_THROW( | ||
platform::errors::ResourceExhausted("Another thread is waiting.")); | ||
} | ||
auto w = cv_.GetWaiter(0); | ||
cv_.Prewait(); | ||
std::string* triggered = trigger_event_; | ||
if (triggered == nullptr) { | ||
// checkers | ||
{ | ||
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); | ||
for (auto& kv : events_) { | ||
auto& evt = kv.second; | ||
if (TriggerType::LevelTriggered == evt.type && evt.checker()) { | ||
triggered = new std::string(evt.name); | ||
break; | ||
} | ||
} | ||
} | ||
if (triggered != nullptr) { | ||
std::string* prev = nullptr; | ||
if (!trigger_event_.compare_exchange_strong(prev, triggered, | ||
std::memory_order_seq_cst, | ||
std::memory_order_relaxed)) { | ||
delete triggered; | ||
triggered = prev; | ||
} | ||
} | ||
} | ||
if (triggered) { | ||
cv_.CancelWait(); | ||
} else { | ||
cv_.CommitWait(w); | ||
triggered = trigger_event_; | ||
} | ||
trigger_event_.store(nullptr, std::memory_order_relaxed); | ||
waiting_.store(false); | ||
auto trigger_event = *triggered; | ||
delete triggered; | ||
return trigger_event; | ||
} | ||
|
||
int EventsWaiter::Clear() { | ||
bool waiting = false; | ||
if (!waiting_.compare_exchange_strong(waiting, true, | ||
std::memory_order_seq_cst, | ||
std::memory_order_relaxed)) { | ||
return -1; | ||
} | ||
trigger_event_.store(nullptr, std::memory_order_relaxed); | ||
waiting_.store(false); | ||
return 0; | ||
} | ||
|
||
void EventsWaiter::TriggerEvent(const EventId& id) { | ||
VLOG(10) << "Try to trigger event id:" << id; | ||
std::string* trigger_event = new std::string; | ||
{ | ||
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); | ||
auto iter = events_.find(id); | ||
if (iter == events_.end()) { | ||
delete trigger_event; | ||
return; | ||
} | ||
*trigger_event = iter->second.name; | ||
} | ||
std::string* prev = nullptr; | ||
if (!trigger_event_.compare_exchange_strong(prev, trigger_event, | ||
std::memory_order_seq_cst, | ||
std::memory_order_relaxed)) { | ||
delete trigger_event; | ||
return; | ||
} | ||
VLOG(10) << "Triggered event id:" << id << " name:" << *trigger_event; | ||
cv_.Notify(true); | ||
} | ||
|
||
std::string EventsWaiter::GetEventName(const EventId& id) { | ||
std::lock_guard<paddle::memory::SpinLock> guard(events_lock_); | ||
auto iter = events_.find(id); | ||
if (iter == events_.end()) { | ||
return "Unregistered"; | ||
} | ||
return iter->second.name; | ||
} | ||
|
||
} // namespace framework | ||
} // namespace paddle |
111 changes: 111 additions & 0 deletions
111
paddle/fluid/framework/new_executor/workqueue/events_waiter.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#pragma once | ||
|
||
#include <atomic> | ||
#include <cstddef> | ||
#include <functional> | ||
#include <string> | ||
#include <unordered_map> | ||
#include "paddle/fluid/framework/new_executor/workqueue/event_count.h" | ||
#include "paddle/fluid/memory/allocation/spin_lock.h" | ||
|
||
namespace paddle { | ||
namespace framework { | ||
|
||
// A multiplexing waiter, be able to wait multiple kinds of events | ||
// simultaneously. | ||
// Muti-Producer single-consumer single-slot message-queue. | ||
class EventsWaiter { | ||
public: | ||
using EventId = std::size_t; | ||
|
||
using EventChecker = std::function<bool()>; | ||
|
||
// Make sure EventsWaiter has a longer lifetime than EventNotifier. | ||
class EventNotifier { | ||
public: | ||
void NotifyEvent() { waiter_.TriggerEvent(id_); } | ||
|
||
void UnregisterEvent() { waiter_.UnregisterEvent(id_); } | ||
|
||
EventId GetEventId() { return id_; } | ||
|
||
// return "Unregistered" if the corresponding event was unregistered. | ||
std::string GetEventName() { return waiter_.GetEventName(id_); } | ||
|
||
private: | ||
friend EventsWaiter; | ||
EventNotifier(EventId id, EventsWaiter* waiter) | ||
: id_(id), waiter_(*waiter) {} | ||
EventNotifier(const EventNotifier&) = delete; | ||
void operator=(const EventNotifier&) = delete; | ||
|
||
EventId id_; | ||
EventsWaiter& waiter_; | ||
}; | ||
|
||
EventsWaiter(); | ||
EventsWaiter(const EventsWaiter&) = delete; | ||
EventsWaiter& operator=(const EventsWaiter&) = delete; | ||
|
||
// Register a level-triggered event. If the checker returns true or | ||
// EventNotifier::NotifyEvent is called, the corresponding event will be | ||
// distributed. | ||
std::shared_ptr<EventNotifier> RegisterEvent(const std::string& name, | ||
EventChecker checker); | ||
|
||
// Register an edge-triggered event. The corresponding event will be | ||
// distributed when EventNotifier::NotifyEvent is called. | ||
std::shared_ptr<EventNotifier> RegisterEvent(const std::string& name); | ||
|
||
void UnregisterEvent(const EventId& id); | ||
|
||
// Blocking the calling thread to wait any of the registered events. | ||
std::string WaitEvent(); | ||
|
||
// Nonblocking. | ||
// Clear the slot, no matter whether there is an event. | ||
// Return value: | ||
// -1 : another thread is waiting. | ||
// 0 : succ. | ||
int Clear(); | ||
|
||
private: | ||
friend EventNotifier; | ||
|
||
enum class TriggerType { LevelTriggered, EdgeTriggered }; | ||
|
||
struct EventInfo { | ||
EventId id; | ||
std::string name; | ||
TriggerType type; | ||
EventChecker checker; | ||
}; | ||
|
||
void TriggerEvent(const EventId& id); | ||
|
||
std::string GetEventName(const EventId& id); | ||
|
||
std::unordered_map<EventId, EventInfo> events_; | ||
paddle::memory::SpinLock events_lock_; | ||
std::atomic<std::string*> trigger_event_; | ||
std::atomic<uint64_t> counter_; | ||
std::atomic<bool> waiting_; | ||
EventCount cv_; | ||
}; | ||
|
||
} // namespace framework | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.