Skip to content

Commit

Permalink
Replace the numpy.meshgrid() with more efficient torch.meshgrid() (#1475
Browse files Browse the repository at this point in the history
)

Summary:
This PR fixes a performance issue for the model [Super-SloMo](/~https://github.com/pytorch/benchmark/tree/main/torchbenchmark/models/Super_SloMo).  The [`backwarp` class](/~https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/Super_SloMo/slomo_model.py#L213) calls `np.meshgrid()` and `torch.tensor()` to create a grid in [the class constructor](/~https://github.com/pytorch/benchmark/blob/main/torchbenchmark/models/Super_SloMo/slomo_model.py#L232). The `torch` modules provides similar API [`torch.meshgrid()`](https://pytorch.org/docs/stable/generated/torch.meshgrid.html) with far better performance. According to my [example profiling script](https://gist.github.com/CuiJinku/d85436d31aade0f49d13cc7e5f4f844b), the `torch.meshgrid()` has **25X** speedup on a single NVIDIA 3090 GPU.

Pull Request resolved: #1475

Reviewed By: aaronenyeshi

Differential Revision: D43954893

Pulled By: xuzhao9

fbshipit-source-id: 2b38e653594a64364fe299c84a327d5407ba39dc
  • Loading branch information
CuiJinku authored and facebook-github-bot committed Mar 11, 2023
1 parent 0f02ca6 commit d8e5325
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions torchbenchmark/models/Super_SloMo/slomo_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,13 @@ def __init__(self, W, H, device):

super(backWarp, self).__init__()
# create a grid
gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
self.W = W
self.H = H
self.gridX = torch.tensor(gridX, requires_grad=False, device=device)
self.gridY = torch.tensor(gridY, requires_grad=False, device=device)

# Use torch.meshgrid instead of np.meshgrid to imrpove performance
# /~https://github.com/avinashpaliwal/Super-SloMo/pull/111
self.gridX, self.gridY = torch.meshgrid(torch.arange(W, requires_grad=False, device=device),
torch.arange(H, requires_grad=False, device=device), indexing='xy')

def forward(self, img, flow):
"""
Expand Down

0 comments on commit d8e5325

Please sign in to comment.