Skip to content

Commit

Permalink
test out einstein notation for indexing, using einx.get_at
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jan 17, 2024
1 parent f3cb662 commit 2e33a53
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 24 deletions.
18 changes: 11 additions & 7 deletions equiformer_pytorch/equiformer_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from equiformer_pytorch.utils import (
exists,
default,
batched_index_select,
masked_mean,
to_order,
cast_tuple,
Expand All @@ -37,6 +36,8 @@
pad_for_centering_y_to_x
)

from einx import get_at

from einops import rearrange, repeat, reduce, einsum, pack, unpack
from einops.layers.torch import Rearrange

Expand Down Expand Up @@ -336,7 +337,9 @@ def forward(

xi, xj = source[degree_in], target[degree_in]

x = batched_index_select(xj, neighbor_indices, dim = 1)
flattened_neighbor_indices, ps = pack_one(neighbor_indices, 'b *')
x = get_at('b [i] d m, b k -> b k d m', xj, flattened_neighbor_indices)
x = unpack_one(x, ps, 'b * d m')

if self.project_xi_xj:
xi = rearrange(xi, 'b i d m -> b i 1 d m')
Expand Down Expand Up @@ -1215,15 +1218,16 @@ def forward(
dist_values, nearest_indices = modified_rel_dist.topk(total_neighbors, dim = -1, largest = False)
neighbor_mask = dist_values <= valid_radius

neighbor_rel_dist = batched_index_select(rel_dist, nearest_indices, dim = 2)
neighbor_rel_pos = batched_index_select(rel_pos, nearest_indices, dim = 2)
neighbor_indices = batched_index_select(indices, nearest_indices, dim = 2)
neighbor_rel_dist = get_at('b i [j], b i k -> b i k', rel_dist, nearest_indices)
neighbor_rel_pos = get_at('b i [j] c, b i k -> b i k c', rel_pos, nearest_indices)
neighbor_indices = get_at('b i [j], b i k -> b i k', indices, nearest_indices)

if exists(mask):
neighbor_mask = neighbor_mask & batched_index_select(mask, nearest_indices, dim = 2)
nearest_mask = get_at('b i [j], b i k -> b i k', mask, nearest_indices)
neighbor_mask = neighbor_mask & nearest_mask

if exists(edges):
edges = batched_index_select(edges, nearest_indices, dim = 2)
edges = get_at('b i [j] d, b i k -> b i k d', edges, nearest_indices)

# embed relative distances

Expand Down
16 changes: 0 additions & 16 deletions equiformer_pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,22 +51,6 @@ def safe_cat(arr, el, dim):
def cast_tuple(val, depth = 1):
return val if isinstance(val, tuple) else (val,) * depth

def batched_index_select(values, indices, dim = 1):
value_dims = values.shape[(dim + 1):]
values_shape, indices_shape = map(lambda t: list(t.shape), (values, indices))
indices = indices[(..., *((None,) * len(value_dims)))]
indices = indices.expand(*((-1,) * len(indices_shape)), *value_dims)
value_expand_len = len(indices_shape) - (dim + 1)
values = values[(*((slice(None),) * dim), *((None,) * value_expand_len), ...)]

value_expand_shape = [-1] * len(values.shape)
expand_slice = slice(dim, (dim + value_expand_len))
value_expand_shape[expand_slice] = indices.shape[expand_slice]
values = values.expand(*value_expand_shape)

dim += value_expand_len
return values.gather(dim, indices)

def fast_split(arr, splits, dim=0):
axis_len = arr.shape[dim]
splits = min(axis_len, max(splits, 1))
Expand Down
2 changes: 1 addition & 1 deletion equiformer_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.5.1'
__version__ = '0.5.2'
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
install_requires=[
'beartype',
'einops>=0.6',
'einx',
'filelock',
'opt-einsum',
'taylor-series-linear-attention>=0.1.4',
Expand Down

0 comments on commit 2e33a53

Please sign in to comment.