-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathconnectivity_utils.py
90 lines (75 loc) · 3.25 KB
/
connectivity_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""Tools to compute the connectivity of the graph."""
import functools
import numpy as np
from sklearn import neighbors
import tensorflow as tf
def _compute_connectivity(positions, radius, add_self_edges):
  """Get the indices of connected edges with radius connectivity.

  Args:
    positions: Positions of nodes in the graph. Shape:
      [num_nodes_in_graph, num_dims].
    radius: Radius of connectivity.
    add_self_edges: Whether to include self edges or not.

  Returns:
    senders indices [num_edges_in_graph]
    receiver indices [num_edges_in_graph]
  """
  # For every node, find all nodes (including itself) within `radius`.
  kd_tree = neighbors.KDTree(positions)
  neighbor_lists = kd_tree.query_radius(positions, r=radius)

  node_count = len(positions)
  edges_per_node = [len(neighbor_ids) for neighbor_ids in neighbor_lists]
  # Node i is the sender of one edge per neighbor it found.
  senders = np.repeat(range(node_count), edges_per_node)
  receivers = np.concatenate(neighbor_lists, axis=0)

  if add_self_edges:
    return senders, receivers

  # Drop edges where a node connects to itself.
  keep = senders != receivers
  return senders[keep], receivers[keep]
def _compute_connectivity_for_batch(
    positions, n_node, radius, add_self_edges):
  """`compute_connectivity` for a batch of graphs.

  Args:
    positions: Positions of nodes in the batch of graphs. Shape:
      [num_nodes_in_batch, num_dims].
    n_node: Number of nodes for each graph in the batch. Shape:
      [num_graphs in batch].
    radius: Radius of connectivity.
    add_self_edges: Whether to include self edges or not.

  Returns:
    senders indices [num_edges_in_batch]
    receiver indices [num_edges_in_batch]
    number of edges per graph [num_graphs_in_batch]
  """
  # Recover the per-graph position arrays from the flat batch array.
  split_points = np.cumsum(n_node[:-1])
  positions_per_graph = np.split(positions, split_points, axis=0)

  senders_per_graph = []
  receivers_per_graph = []
  edge_counts = []
  node_offset = 0
  for graph_positions in positions_per_graph:
    graph_senders, graph_receivers = _compute_connectivity(
        graph_positions, radius, add_self_edges)
    edge_counts.append(len(graph_senders))
    # Indices are local to each graph; shift them by the number of nodes in
    # all preceding graphs so they address the concatenated batch.
    senders_per_graph.append(graph_senders + node_offset)
    receivers_per_graph.append(graph_receivers + node_offset)
    node_offset += len(graph_positions)

  # Stitch the per-graph results back into flat batch arrays.
  senders = np.concatenate(senders_per_graph, axis=0).astype(np.int32)
  receivers = np.concatenate(receivers_per_graph, axis=0).astype(np.int32)
  n_edge = np.stack(edge_counts).astype(np.int32)
  return senders, receivers, n_edge
def compute_connectivity_for_batch_pyfunc(
    positions, n_node, radius, add_self_edges=True):
  """`_compute_connectivity_for_batch` wrapped in a pyfunc."""
  # The numpy implementation requires float32 positions.
  assert positions.dtype == np.float32
  # Bind the python-level flag now; only tensor arguments go through the
  # py_function boundary.
  batch_fn = functools.partial(
      _compute_connectivity_for_batch, add_self_edges=add_self_edges)
  senders, receivers, n_edge = tf.py_function(
      batch_fn,
      [positions, n_node, np.float32(radius)],
      [tf.int32, tf.int32, tf.int32])
  # py_function loses static shape information; restore what we know.
  senders.set_shape([None])
  receivers.set_shape([None])
  n_edge.set_shape(n_node.get_shape())
  return senders, receivers, n_edge