Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Priority-based parameter propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
Anand J committed Jun 5, 2019
1 parent a16009d commit 07f9818
Showing 1 changed file with 98 additions and 56 deletions.
154 changes: 98 additions & 56 deletions src/kvstore/kvstore_dist.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class KVStoreDist : public KVStoreLocal {
ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
}
}
bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000);
bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 40 * 1000);
log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}

Expand Down Expand Up @@ -234,35 +234,62 @@ class KVStoreDist : public KVStoreLocal {
// this may happen the first time a non-rank-0 worker pulls the weight.
recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_,
true, grouped_vals[i][0]->dtype());
}
// Engine operation that pulls the value stored under `key` into recv_buf.
// The whole buffer is requested with a single ZPull call; the completion
// callback `cb` is invoked from the callback handed to ZPull, right after the
// temporary SArray wrapper has been freed.
auto pull_from_servers = [this, key, recv_buf](
RunContext rctx, Engine::CallbackOnComplete cb) {
// convert to ps keys
size_t size = recv_buf.shape().Size();
const int dtype = recv_buf.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
// the key encoding depends on whether gradient compression is enabled
PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ?
EncodeDefaultKey(key, size, num_bytes) :
EncodeCompressedKey(key, size, false, num_bytes);
char* data = static_cast<char*> (recv_buf.data().dptr_);
// false means not to delete data when SArray is deleted
// (the wrapper itself is heap-allocated and deleted in the ZPull callback below)
auto vals = new ps::SArray<char>(data, size * num_bytes, false);
// issue pull
RequestType mode = (gradient_compression_->get_type() != CompressionType::kNone) ?
RequestType::kCompressedPushPull : RequestType::kDefaultPushPull;
const int cmd = GetCommandType(mode, dtype);
// the callback releases the SArray wrapper first, then signals the engine
CHECK_NOTNULL(ps_worker_)->ZPull(
pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
};
auto pull_from_servers = [this, key, recv_buf, priority](
RunContext rctx, Engine::CallbackOnComplete cb) {
// convert to ps keys
size_t size = recv_buf.shape().Size();
const int dtype = recv_buf.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
PSKV& pskv = (gradient_compression_->get_type() == CompressionType::kNone) ?
EncodeDefaultKey(key, size, num_bytes) :
EncodeCompressedKey(key, size, false, num_bytes);
char* data = static_cast<char*> (recv_buf.data().dptr_);
// false means not to delete data when SArray is deleted
auto vals = new ps::SArray<char>(data, size * num_bytes, false);
// issue pull
RequestType mode = (gradient_compression_->get_type() != CompressionType::kNone) ?
RequestType::kCompressedPushPull : RequestType::kDefaultPushPull;
const int cmd = GetCommandType(mode, dtype);

int len = 0;
auto *counter = new std::atomic<int>(pskv.keys.size());
for (size_t i = 0; i < pskv.keys.size(); i++) {
auto vs = new ps::SArray<char>(std::move(vals->segment(len, len+pskv.lens[i])));
auto ls = new ps::SArray<int>(std::move(pskv.lens.segment(i, i+1)));
CHECK_NOTNULL(ps_worker_)->ZPull(
std::move(pskv.keys.segment(i, i+1)), vs, ls, cmd,
[vs, ls, vals, cb, counter]() {
delete vs;
delete ls;
(*counter)--;
if (counter->load() == 0) {
delete vals;
delete counter;
cb();
}
}, priority);
len += pskv.lens[i];
}
};

CHECK_NOTNULL(Engine::Get())->PushAsync(
pull_from_servers,
pinned_ctx_,
{},
{recv_buf.var()},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultStoragePull");
CHECK_NOTNULL(Engine::Get())->PushAsync(
pull_from_servers,
pinned_ctx_,
{},
{recv_buf.var()},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultStoragePull");
} else {
CHECK_NOTNULL(Engine::Get())->PushAsync(
[] (RunContext rctx, Engine::CallbackOnComplete cb) { cb(); },
pinned_ctx_,
{},
{recv_buf.var()},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultStoragePull");
}

comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
}
Expand Down Expand Up @@ -412,17 +439,38 @@ class KVStoreDist : public KVStoreLocal {

void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) {
auto push_to_servers =
[this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
[this, key, pskv, send_buf, priority](RunContext rctx, Engine::CallbackOnComplete cb) {
const int dtype = send_buf.dtype();
// convert to ps keys
const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
char* data = static_cast<char *>(send_buf.data().dptr_);
// do push. false means no delete
ps::SArray<char> vals(data, size, false);
int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
CHECK_NOTNULL(ps_worker_)->ZPush(
pskv.keys, vals, pskv.lens,
cmd, [cb]() { cb(); });

auto *counter = new std::atomic<int>(pskv.keys.size());
int len = 0;
for (size_t i = 0; i < pskv.keys.size(); i++) {
auto ks = new ps::SArray<ps::Key>(std::move(pskv.keys.segment(i, i+1)));
auto vs = new ps::SArray<char>(std::move(vals.segment(len, len+pskv.lens[i])));
auto ls = new ps::SArray<int>(std::move(pskv.lens.segment(i, i+1)));
CHECK_NOTNULL(ps_worker_)->ZPush(*ks, *vs, *ls,
static_cast<int>(RequestType::kDefaultPushPull),
[this, cb, counter, ks, vs, ls, cmd]() {
CHECK_NOTNULL(ps_worker_)->ZPull(*ks, vs, ls, cmd,
[cb, counter, ks, vs, ls]() {
delete ks;
delete vs;
delete ls;
(*counter)--;
if (counter->load() == 0) {
delete counter;
cb();
}
});
}, priority);
len += pskv.lens[i];
}
};
Engine::Get()->PushAsync(
push_to_servers,
Expand Down Expand Up @@ -544,31 +592,25 @@ class KVStoreDist : public KVStoreLocal {
const int num_servers = krs.size();
CHECK_GT(num_servers, 0);

// a simple heuristic for load balance
if (num_arr_elems < bigarray_bound_) {
// send it to a single pseudo-randomly picked server
int server = (key * 9973) % num_servers;
ps::Key ps_key = krs[server].begin() + key;
CHECK_LT(ps_key, krs[server].end());
/**
* Round-Robin key assignment
*/
int64_t params = pskv_size;
int64_t slice_bound = bigarray_bound_ * num_bytes;
static ps::Key server = 0;
while (params > 0) {
ps::Key ps_key = krs[server%num_servers].begin()
+ (ps::Key)(key + server/num_servers);
CHECK_LT(ps_key, krs[server%num_servers].end());
pskv.keys.push_back(ps_key);
const int total_bytes = num_arr_elems * num_bytes;
pskv.lens.push_back(total_bytes);
pskv.size = total_bytes;
} else {
// partition it across all servers
pskv.size = 0;
for (int i = 0; i < num_servers; ++i) {
size_t part_size =
static_cast<size_t>(round(static_cast<double>(num_arr_elems)/num_servers*(i+1))) -
static_cast<size_t>(round(static_cast<double>(num_arr_elems)/num_servers*i));
ps::Key ps_key = krs[i].begin() + key;
CHECK_LT(ps_key, krs[i].end());
pskv.keys.push_back(ps_key);
const int total_bytes = part_size * num_bytes;
pskv.lens.push_back(total_bytes);
pskv.size += total_bytes;
}
const size_t part_size = static_cast<size_t>((params > slice_bound) ? slice_bound : params);
pskv.lens.push_back(part_size);
pskv.size += part_size;

params -= part_size;
server++;
}

CHECK_EQ(static_cast<size_t>(pskv.size), pskv_size);
}
return pskv;
Expand Down

0 comments on commit 07f9818

Please sign in to comment.