diff --git a/tests/python/unittest/test_dgl_graph.py b/tests/python/unittest/test_dgl_graph.py index f996d7f38de8..069fef6e32f0 100644 --- a/tests/python/unittest/test_dgl_graph.py +++ b/tests/python/unittest/test_dgl_graph.py @@ -63,6 +63,18 @@ def check_non_uniform(out, num_hops, max_num_vertices): for data in layer: assert(data <= num_hops) +def check_compact(csr, id_arr, num_nodes): + compact = mx.nd.contrib.dgl_graph_compact(csr, id_arr, graph_sizes=num_nodes, return_mapping=False) + assert compact.shape[0] == num_nodes + assert compact.shape[1] == num_nodes + assert mx.nd.sum(compact.indptr == csr.indptr[0:(num_nodes + 1)]).asnumpy() == num_nodes + 1 + sub_indices = compact.indices.asnumpy() + indices = csr.indices.asnumpy() + id_arr = id_arr.asnumpy() + for i in range(len(sub_indices)): + sub_id = sub_indices[i] + assert id_arr[sub_id] == indices[i] + def test_uniform_sample(): shape = (5, 5) data_np = np.array([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20], dtype=np.int64) @@ -74,36 +86,64 @@ def test_uniform_sample(): out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5) assert (len(out) == 3) check_uniform(out, num_hops=1, max_num_vertices=5) + num_nodes = out[0][-1].asnumpy() + assert num_nodes > 0 + assert num_nodes < len(out[0]) + check_compact(out[1], out[0], num_nodes) seed = mx.nd.array([0], dtype=np.int64) out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=1, max_num_vertices=4) assert (len(out) == 3) check_uniform(out, num_hops=1, max_num_vertices=4) + num_nodes = out[0][-1].asnumpy() + assert num_nodes > 0 + assert num_nodes < len(out[0]) + check_compact(out[1], out[0], num_nodes) seed = mx.nd.array([0], dtype=np.int64) out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=2, num_neighbor=1, max_num_vertices=4) assert (len(out) == 3) check_uniform(out, num_hops=2, max_num_vertices=4) + num_nodes = out[0][-1].asnumpy() + assert num_nodes > 0 + assert num_nodes < len(out[0]) + check_compact(out[1], out[0], num_nodes) seed = mx.nd.array([0,2,4], dtype=np.int64) out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5) assert (len(out) == 3) check_uniform(out, num_hops=1, max_num_vertices=5) + num_nodes = out[0][-1].asnumpy() + assert num_nodes > 0 + assert num_nodes < len(out[0]) + check_compact(out[1], out[0], num_nodes) seed = mx.nd.array([0,4], dtype=np.int64) out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5) assert (len(out) == 3) check_uniform(out, num_hops=1, max_num_vertices=5) + num_nodes = out[0][-1].asnumpy() + assert num_nodes > 0 + assert num_nodes < len(out[0]) + check_compact(out[1], out[0], num_nodes) seed = mx.nd.array([0,4], dtype=np.int64) out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=2, num_neighbor=2, max_num_vertices=5) assert (len(out) == 3) check_uniform(out, num_hops=2, max_num_vertices=5) + num_nodes = out[0][-1].asnumpy() + assert num_nodes > 0 + assert num_nodes < len(out[0]) + check_compact(out[1], out[0], num_nodes) seed = mx.nd.array([0,4], dtype=np.int64) out = mx.nd.contrib.dgl_csr_neighbor_uniform_sample(a, seed, num_args=2, num_hops=1, num_neighbor=2, max_num_vertices=5) assert (len(out) == 3) check_uniform(out, num_hops=1, max_num_vertices=5) + num_nodes = out[0][-1].asnumpy() + assert num_nodes > 0 + assert num_nodes < len(out[0]) + check_compact(out[1], out[0], num_nodes) def test_non_uniform_sample(): shape = (5, 5)