Accelerate DGL csr neighbor sampling #13588
@aksnzhy Could you please show the benchmark results?
The benchmark results are as follows:
This PR is about 2-3 times faster than the implementation in the master branch.
The speed in this PR:
3ee7bec to 043c1b4
@BullDemonKing Thanks for the contribution! Could you take a look at the failed tests?
@mxnet-label-bot add[Operator, pr-awaiting-review]
std::queue<ver_node> node_queue;
std::unordered_set<dgl_id_t> sub_ver_mp;
std::vector<std::pair<dgl_id_t, dgl_id_t> > sub_vers;
sub_vers.reserve(num_seeds * 10);
Is 10 a generally good constant for sampling?
This number is only used to reserve memory, so it doesn't have to be very accurate. The goal is to balance memory reallocation against memory consumption in std::vector. If num_hops is 1, 10 might be good enough; if num_hops is 2, 10 will be too small. But I don't think we need to make it overly complex. After all, we only want to reduce the overhead of memory reallocation in std::vector.
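The trade-off described in this reply can be made concrete with a small simulation. This is only a sketch in Python, not the PR's C++ code: `count_reallocations` is a hypothetical helper, and the doubling growth policy is an assumption (the actual growth factor of std::vector is implementation-defined).

```python
def count_reallocations(num_pushes, initial_capacity=0, growth=2):
    """Count buffer reallocations for a growable array.

    Assumes a geometric growth policy (capacity multiplied by `growth`
    whenever the buffer is full); real std::vector implementations differ.
    """
    capacity = initial_capacity
    size = 0
    reallocs = 0
    for _ in range(num_pushes):
        if size == capacity:
            # Buffer is full: grow it, which copies all existing elements.
            capacity = max(1, capacity * growth)
            reallocs += 1
        size += 1
    return reallocs


num_seeds = 100
reserved = num_seeds * 10  # mirrors sub_vers.reserve(num_seeds * 10)

# Few sampled vertices: the reservation avoids every reallocation.
print(count_reallocations(300, reserved), count_reallocations(300))
# Many sampled vertices: the reservation is too small but still helps.
print(count_reallocations(5000, reserved), count_reallocations(5000))
```

Under these assumptions an undersized reservation still removes most of the early reallocations, which is consistent with the argument that the constant does not need to be precise.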
This PR looks really good.
Just some minor comments about documentation, which I think are OK to address in a later PR.
  }
- for (dgl_id_t i = num_vertices+1; i <= max_num_vertices; ++i) {
+ for (size_t i = num_vertices+1; i <= max_num_vertices; ++i) {
    indptr_out[i] = indptr_out[i-1];
  }
}
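The loop in the snippet above copies the last valid row pointer into every remaining slot, so the rows beyond the sampled vertices become empty rows in the output CSR. A minimal Python sketch of that padding step follows; `pad_indptr` is a hypothetical name used for illustration, not a function in the PR.

```python
def pad_indptr(indptr, max_num_vertices):
    """Extend a CSR row-pointer array to max_num_vertices + 1 entries.

    Each padded entry repeats the previous one, so every padded row
    spans zero non-zero elements.
    """
    padded = list(indptr)
    while len(padded) < max_num_vertices + 1:
        padded.append(padded[-1])
    return padded


# Three sampled vertices with 2, 1 and 2 edges, padded to five rows.
print(pad_indptr([0, 2, 3, 5], 5))  # [0, 2, 3, 5, 5, 5]
```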
Do you mind also enhancing the documentation for the operator? For example, the number of outputs and what each output is for, and the fact that -1 is filled in for vertex IDs that are not sampled.
Also, printing the output of a.asnumpy() would be helpful:
[[ 0 1 2 3 4]
[ 5 0 6 7 8]
[ 9 10 0 11 12]
[13 14 15 0 16]
[17 18 19 20 0]]
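The printed matrix appears to be the dense view of a 5-vertex graph in which every pair of distinct vertices is connected and the non-zero values are edge IDs numbered row by row. The sketch below rebuilds that matrix with plain Python lists and converts it to CSR arrays; `build_dense` and `dense_to_csr` are hypothetical helpers written for this illustration, and the interpretation of the values as edge IDs is an assumption.

```python
def build_dense(num_vertices):
    """Dense adjacency matrix: 0 on the diagonal, edge IDs 1..E elsewhere."""
    dense = []
    next_edge_id = 1
    for row in range(num_vertices):
        values = []
        for col in range(num_vertices):
            if row == col:
                values.append(0)  # no self-loop
            else:
                values.append(next_edge_id)
                next_edge_id += 1
        dense.append(values)
    return dense


def dense_to_csr(dense):
    """Return (indptr, indices, data) for the non-zero entries of dense."""
    indptr, indices, data = [0], [], []
    for row in dense:
        for col, value in enumerate(row):
            if value != 0:
                indices.append(col)
                data.append(value)
        indptr.append(len(indices))
    return indptr, indices, data


dense = build_dense(5)
indptr, indices, data = dense_to_csr(dense)
print(indptr)       # [0, 4, 8, 12, 16, 20]
print(indices[:4])  # [1, 2, 3, 4]
```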
Regarding the example usage:
out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5)
I don't think num_args has to be set by the user. It should be set automatically: https://github.com/apache/incubator-mxnet/blob/779bdc5e7ee3abd6a2d23e8bd97d47fc08ae6bc5/src/imperative/imperative_utils.h#L311-L313
Also, the current example seed doesn't result in -1 in the output. Maybe we want to pick an input that is more representative.
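To see when -1 would show up, consider a pure-Python sketch of one-hop uniform sampling on the 5-vertex graph. This is not the operator's implementation: `sample_one_hop` is a hypothetical helper, and the sorted output order and the exact layout of the padded array are assumptions made only to illustrate the padding behavior.

```python
import random


def sample_one_hop(indptr, indices, seeds, num_neighbor, max_num_vertices, rng):
    """Sample up to num_neighbor neighbors per seed; pad vertex IDs with -1."""
    sampled = set(seeds)
    for v in seeds:
        neighbors = indices[indptr[v]:indptr[v + 1]]
        k = min(num_neighbor, len(neighbors))
        sampled.update(rng.sample(neighbors, k))
    # Keep at most max_num_vertices IDs, then fill the unused slots with -1.
    vertex_ids = sorted(sampled)[:max_num_vertices]
    return vertex_ids + [-1] * (max_num_vertices - len(vertex_ids))


# CSR arrays of the fully connected 5-vertex graph (no self-loops).
indptr = [0, 4, 8, 12, 16, 20]
indices = [v for u in range(5) for v in range(5) if v != u]

# Seeding with every vertex fills all five slots, so no -1 appears.
all_seeds = sample_one_hop(indptr, indices, [0, 1, 2, 3, 4], 2, 5, random.Random(0))
# A single seed with two neighbors yields three vertices and two -1 entries.
one_seed = sample_one_hop(indptr, indices, [0], 2, 5, random.Random(0))
print(all_seeds)
print(one_seed)
```

An input like the single-seed case would exercise the -1 padding the reviewer asks about.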
// Let's check if there is a vertex that we haven't sampled its neighbors.
for (; idx < sub_vers.size(); idx++) {
  if (sub_vers[idx].second < num_hops) {
    LOG(WARNING)
Is LOG(WARNING) working? I was not aware that this is functional.
It works. I can see warning messages after the number of sampled vertices exceeds the maximum.
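The check under discussion can be sketched in Python as follows. It mirrors the C++ snippet: entries of `sub_vers` are (vertex ID, layer) pairs, and any entry past the last expanded position whose layer is below `num_hops` never had its neighbors sampled. The function name, the return value, and the warning text are hypothetical; Python's `logging` module stands in for `LOG(WARNING)`.

```python
import logging


def has_unexpanded_vertices(sub_vers, start_idx, num_hops):
    """Warn if a vertex from start_idx onward still needed neighbor sampling."""
    for vertex_id, layer in sub_vers[start_idx:]:
        if layer < num_hops:
            logging.warning(
                "vertex %d in layer %d was not expanded; "
                "the sampled vertex limit was probably reached",
                vertex_id, layer)
            return True
    return False


# Vertices 1 and 2 sit in layer 1 but two hops were requested: warns.
print(has_unexpanded_vertices([(0, 0), (1, 1), (2, 1)], 1, 2))  # True
# With a single hop, layer-1 vertices are leaves: no warning.
print(has_unexpanded_vertices([(0, 0), (1, 1)], 1, 1))          # False
```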
Great improvement
@@ -718,20 +691,37 @@ static void SampleSubgraph(const NDArray &csr,
dgl_id_t* indptr_out = sub_csr.aux_data(0).dptr<dgl_id_t>();
nit: fix indentation
What is the problem with the indentation here?
Never mind, GitHub rendered it incorrectly. Looks good.
* Speedup and fix bug in dgl_csr_sampling op
* Update dgl_graph.cc
* simplify functions.
* avoid adding nodes in the last level in the queue.
* remove a hashtable lookup in neigh_pos.
* reduce a hashtable lookup in sub_ver_mp.
* merge copying vids and layers.
* reduce hashtable lookup when writing to output csr.
* fix a bug.
* limit the number of sampled vertices.
* fix lint.
* fix a compile error.
* fix compile error.
* fix compile.
* remove one hashtable lookup per vertex and hashtable iteration.
* remove queue.
* use vector for neigh_pos.
* fix lint
* avoid init output arrays.
* fix tests.
* fix tests.
* update docs.
* retrigger
* retrigger
Description
The DGL CSR neighbor sampling performs many hashtable lookups, which turn out to be expensive operations. This PR accelerates sampling by reducing the number of hashtable lookups.
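As a general illustration of this kind of optimization, and not a reproduction of the PR's C++ changes, the Python sketch below assigns local IDs to sampled vertices in two ways. The first version touches the hashtable up to three times per vertex; the second does a single combined lookup-or-insert. Both function names are hypothetical.

```python
def local_ids_multi_lookup(vertices):
    """Map global vertex IDs to local IDs with separate check, insert, read."""
    ids = {}
    out = []
    for v in vertices:
        if v not in ids:        # lookup 1: membership test
            ids[v] = len(ids)   # lookup 2: insertion
        out.append(ids[v])      # lookup 3: read back
    return out


def local_ids_single_lookup(vertices):
    """Same mapping using one lookup-or-insert per vertex."""
    ids = {}
    # setdefault inserts len(ids) only if v is absent and returns the stored ID.
    return [ids.setdefault(v, len(ids)) for v in vertices]


vertices = [7, 3, 7, 9, 3]
print(local_ids_multi_lookup(vertices))   # [0, 1, 0, 2, 1]
print(local_ids_single_lookup(vertices))  # [0, 1, 0, 2, 1]
```

In C++ the analogous single-lookup pattern would rely on the iterator and flag returned by the container's insert call rather than a separate find followed by an insert.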