Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Priority-based parameter propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
anandj91 committed Oct 15, 2019
1 parent 3249e4d commit ebb4317
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 38 deletions.
6 changes: 6 additions & 0 deletions ci/docker/runtime_functions.sh
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,12 @@ integrationtest_ubuntu_gpu_dist_kvstore() {
cd tests/nightly/
../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=init_gpu
export MXNET_KVSTORE_SLICE_THRESHOLD=40000
export DMLC_PS_WATER_MARK=1
../../tools/launch.py -n 7 --launcher local python dist_device_sync_kvstore.py
../../tools/launch.py -n 7 --launcher local python dist_sync_kvstore.py --type=init_gpu
unset MXNET_KVSTORE_SLICE_THRESHOLD
unset DMLC_PS_WATER_MARK
popd
}

Expand Down
5 changes: 5 additions & 0 deletions docs/static_site/src/pages/api/faq/env_var.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,11 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
- Values: 0(false) or 1(true) ```(default=1)```
- If true, weight updates are performed during the communication step, if possible.

* MXNET_KVSTORE_SLICE_THRESHOLD
- Values: Int ```(default=0)```
- The maximum number of parameter elements per slice when a parameter is partitioned for priority-based parameter synchronization.
- Only used when priority-based update in KVStore is enabled.

## Memonger

* MXNET_BACKWARD_DO_MIRROR
Expand Down
189 changes: 151 additions & 38 deletions src/kvstore/kvstore_dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class KVStoreDist : public KVStoreLocal {
}
}
bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000);
slice_threshold_ = dmlc::GetEnv("MXNET_KVSTORE_SLICE_THRESHOLD", 0);
rr_ = (slice_threshold_ >= 40000);
log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}

Expand Down Expand Up @@ -178,6 +180,7 @@ class KVStoreDist : public KVStoreLocal {
* not perfectly divisible.
*/
std::unordered_map<int, PSKV> ps_kv_;
std::unordered_map<int, PSKV> rr_ps_kv_;
std::unordered_map<int, ComprPSKV> compr_ps_kv_;

/**
Expand All @@ -188,11 +191,19 @@ class KVStoreDist : public KVStoreLocal {
void InitImpl(const std::vector<int>& keys,
const std::vector<NDArray>& values) override {
CheckUnique(keys);

for (size_t i = 0; i < keys.size(); ++i) {
comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
if (slice_threshold_ > 0) {
// Initialize the slices as this is the only function that is assured to
// be called in the same order in all the worker machines
EncodeDefaultKey(keys[i], values[i].shape().Size(),
mshadow::mshadow_sizeof(values[i].dtype()));
}
}
if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() == 0) {
Push_(keys, values, 0, false);

// wait until the push is finished
for (const int key : keys) {
comm_buf_[key].WaitToWrite();
Expand Down Expand Up @@ -253,19 +264,39 @@ class KVStoreDist : public KVStoreLocal {

CHECK(gradient_compression_->get_type() == CompressionType::kNone)
<< "Compression not supported with PushPull";
auto pushpull = [this, key, comm_buf](
auto pushpull = [this, key, comm_buf, priority](
RunContext rctx, Engine::CallbackOnComplete cb) {
size_t size = comm_buf.shape().Size();
const int dtype = comm_buf.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
const int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);

PSKV& pskv = EncodeDefaultKey(key, size, num_bytes);
PSKV& pskv = EncodeDefaultKey(key, size, num_bytes, rr_);
char* data = static_cast<char*>(comm_buf.data().dptr_);
auto vals = new ps::SArray<char>(data, size * num_bytes, false);

CHECK_NOTNULL(ps_worker_)->ZPushPull(
pskv.keys, *vals, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
if (rr_) {
size_t off = 0;
auto counter = new std::atomic<int>(pskv.keys.size());
for (size_t idx = 0; idx < pskv.keys.size(); idx++) {
auto ks = pskv.keys.segment(idx, idx+1);
auto ls = pskv.lens.segment(idx, idx+1);
auto vs = vals->segment(off, off + pskv.lens[idx]);
CHECK_NOTNULL(ps_worker_)->ZPushPull(
ks, vs, &vs, &ls, cmd, [vals, counter, cb]() {
if (--(*counter) == 0) {
delete vals;
delete counter;
cb();
}
},
priority);
off += pskv.lens[idx];
}
} else {
CHECK_NOTNULL(ps_worker_)->ZPushPull(
pskv.keys, *vals, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
}
};

CHECK_NOTNULL(Engine::Get())->PushAsync(
Expand Down Expand Up @@ -324,8 +355,27 @@ class KVStoreDist : public KVStoreLocal {
RequestType mode = (gradient_compression_->get_type() != CompressionType::kNone) ?
RequestType::kCompressedPushPull : RequestType::kDefaultPushPull;
const int cmd = GetCommandType(mode, dtype);
CHECK_NOTNULL(ps_worker_)->ZPull(
pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
if (rr_) {
size_t off = 0;
auto counter = new std::atomic<int>(pskv.keys.size());
for (size_t idx = 0; idx < pskv.keys.size(); idx++) {
auto ks = pskv.keys.segment(idx, idx+1);
auto ls = pskv.lens.segment(idx, idx+1);
auto vs = vals->segment(off, off + pskv.lens[idx]);
CHECK_NOTNULL(ps_worker_)->ZPull(
ks, &vs, &ls, cmd, [vals, counter, cb]() {
if (--(*counter) == 0) {
delete vals;
delete counter;
cb();
}
});
off += pskv.lens[idx];
}
} else {
CHECK_NOTNULL(ps_worker_)->ZPull(
pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
}
};

CHECK_NOTNULL(Engine::Get())->PushAsync(
Expand Down Expand Up @@ -493,9 +543,26 @@ class KVStoreDist : public KVStoreLocal {
// do push. false means no delete
ps::SArray<char> vals(data, size, false);
int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
CHECK_NOTNULL(ps_worker_)->ZPush(
pskv.keys, vals, pskv.lens,
cmd, [cb]() { cb(); });
if (rr_) {
size_t off = 0;
auto counter = new std::atomic<int>(pskv.keys.size());
for (size_t idx = 0; idx < pskv.keys.size(); idx++) {
auto ks = pskv.keys.segment(idx, idx+1);
auto ls = pskv.lens.segment(idx, idx+1);
auto vs = vals.segment(off, off + pskv.lens[idx]);
CHECK_NOTNULL(ps_worker_)->ZPush(
ks, vs, ls, cmd, [vals, counter, cb]() {
if (--(*counter) == 0) {
delete counter;
cb();
}
});
}
} else {
CHECK_NOTNULL(ps_worker_)->ZPush(
pskv.keys, vals, pskv.lens,
cmd, [cb]() { cb(); });
}
};
Engine::Get()->PushAsync(
push_to_servers,
Expand Down Expand Up @@ -601,52 +668,93 @@ class KVStoreDist : public KVStoreLocal {
* \param key
* \param num_arr_elems number of elements in the value for key
* \param num_bytes size of each element in number of bytes
* \param rr whether to use round-robin distribution
* \return PSKV used for both push and pull
*/
inline PSKV& EncodeDefaultKey(const int key, const size_t num_arr_elems,
                              const int num_bytes, bool rr = false) {
  // The (keys, lens) layout for a key is computed once, cached, and
  // validated against the requested size on every later call.
  // NOTE(review): the `rr` parameter is currently ignored — the member
  // rr_ alone selects both the cache (rr_ps_kv_ vs ps_kv_) and the
  // distribution strategy. Confirm whether call sites that pass rr_
  // explicitly expect the parameter to take effect.
  // Guard only the map lookup; the returned reference stays valid
  // afterwards because unordered_map references survive rehashing.
  mu_.lock();
  PSKV& pskv = (rr_) ? rr_ps_kv_[key] : ps_kv_[key];
  mu_.unlock();

  size_t pskv_size = num_arr_elems * num_bytes;
  if (!pskv.keys.empty()) {
    // Already encoded: the layout for a key is immutable.
    CHECK_EQ(static_cast<size_t>(pskv.size), pskv_size)
      << "The value size cannot be changed " << pskv_size << ". Key is " << key;
  } else {
    if (rr_) {
      RRKeyDist(&pskv, key, num_arr_elems, num_bytes);
    } else {
      DefaultKeyDist(&pskv, key, num_arr_elems, num_bytes);
    }
    // Whatever the strategy, the slices must cover the value exactly.
    CHECK_EQ(static_cast<size_t>(pskv.size), pskv_size);
  }
  return pskv;
}

/**
 * Round-robin key distribution strategy used by priority-based
 * parameter propagation.
 *
 * Splits the value for `key` into slices of at most slice_threshold_
 * elements (slice_bound bytes) and assigns consecutive slices to
 * servers in round-robin order via a function-static cursor, so
 * successive keys keep rotating through the servers.
 *
 * \param pskv          output PSKV; one entry per slice, lens in bytes
 * \param key           the mxnet key being encoded
 * \param num_arr_elems number of elements in the value for key
 * \param num_bytes     size of each element in number of bytes
 *
 * NOTE(review): the static `server` cursor is shared, unsynchronized
 * state. Workers must therefore encode keys in the same order on every
 * machine (InitImpl is relied on for this) and must not encode new keys
 * concurrently — confirm no concurrent first-time encodes are possible.
 */
void RRKeyDist(PSKV* pskv, const int key, const size_t num_arr_elems,
               const int num_bytes) {
  auto krs = ps::Postoffice::Get()->GetServerKeyRanges();
  const int num_servers = krs.size();
  CHECK_GT(num_servers, 0);

  int64_t num_params = num_arr_elems * num_bytes;      // bytes left to place
  int64_t slice_bound = slice_threshold_ * num_bytes;  // max bytes per slice
  static size_t server = 0;  // round-robin cursor, persists across calls
  // Accumulate from zero: DefaultKeyDist always assigns size before
  // adding to it, and PSKV::size is not guaranteed to start at 0.
  pskv->size = 0;
  while (num_params > 0) {
    // Slices of the same key map to distinct ps keys (key, key+1, ...)
    // once the rotation wraps around the server list.
    ps::Key ps_key = krs[server%num_servers].begin()
                     + (ps::Key)(key + server/num_servers);
    CHECK_LT(ps_key, krs[server%num_servers].end());
    pskv->keys.push_back(ps_key);
    const size_t part_size = static_cast<size_t>((num_params > slice_bound)
                             ? slice_bound : num_params);
    pskv->lens.push_back(part_size);
    pskv->size += part_size;

    num_params -= part_size;
    server++;
  }
}

/**
 * Default key distribution strategy.
 *
 * Values smaller than bigarray_bound_ elements are placed whole on a
 * single server picked by a cheap deterministic hash of the key;
 * larger values are partitioned into near-equal parts, one per server.
 *
 * \param pskv          output PSKV; one entry per server slice, lens in bytes
 * \param key           the mxnet key being encoded
 * \param num_arr_elems number of elements in the value for key
 * \param num_bytes     size of each element in number of bytes
 */
void DefaultKeyDist(PSKV* pskv, const int key, const size_t num_arr_elems,
                    const int num_bytes) {
  const auto key_ranges = ps::Postoffice::Get()->GetServerKeyRanges();
  const int num_servers = key_ranges.size();
  CHECK_GT(num_servers, 0);

  if (num_arr_elems < bigarray_bound_) {
    // Small value: send it to one pseudo-randomly picked server
    // (a simple heuristic for load balance).
    const int server = (key * 9973) % num_servers;
    const ps::Key ps_key = key_ranges[server].begin() + key;
    CHECK_LT(ps_key, key_ranges[server].end());
    const int total_bytes = num_arr_elems * num_bytes;
    pskv->keys.push_back(ps_key);
    pskv->lens.push_back(total_bytes);
    pskv->size = total_bytes;
  } else {
    // Large value: partition it across all servers. Rounding the
    // cumulative boundary keeps parts within one element of each other
    // even when the size is not perfectly divisible.
    pskv->size = 0;
    const double elems_per_server =
        static_cast<double>(num_arr_elems) / num_servers;
    for (int i = 0; i < num_servers; ++i) {
      const size_t begin = static_cast<size_t>(round(elems_per_server * i));
      const size_t end = static_cast<size_t>(round(elems_per_server * (i + 1)));
      const ps::Key ps_key = key_ranges[i].begin() + key;
      CHECK_LT(ps_key, key_ranges[i].end());
      const int part_bytes = (end - begin) * num_bytes;
      pskv->keys.push_back(ps_key);
      pskv->lens.push_back(part_bytes);
      pskv->size += part_bytes;
    }
  }
}

/**
* \brief Convert to PSKV for pushes and pulls when gradient compression is used.
* Divides original array into equal parts for each server.
Expand Down Expand Up @@ -818,6 +926,11 @@ class KVStoreDist : public KVStoreLocal {
* \brief threshold for partition
*/
size_t bigarray_bound_;
/**
* \brief threshold for the parameter slice size
*/
size_t slice_threshold_;
bool rr_;
/**
* \brief buffer for non-compressed data.
* When gradient compression is active, this is used
Expand Down

0 comments on commit ebb4317

Please sign in to comment.