
Use torch.cdist for dist #2336

Merged · 2 commits into cornellius-gp:master · May 9, 2023
Conversation

@esantorella (Collaborator) commented May 3, 2023

gpytorch.kernels.kernel.dist is an expensive function that is called heavily because it is used in MaternKernel, which is often used as a default, for example in BoTorch's SingleTaskGP.

torch.cdist computes very nearly the same result as gpytorch.kernels.kernel.dist, but is much faster. However, torch.cdist(x1, x2).pow(2) was not faster than gpytorch.kernels.kernel.sq_dist, so I did not change sq_dist. I also did not change the behavior of dist in the x1_eq_x2 case since this caused numerical issues (see below).
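
For concreteness, here is a minimal sketch of the shape of the change, not the literal PR diff: the x1 != x2 path switches to torch.cdist, while the x1_eq_x2 path keeps the sq_dist-based computation with an explicitly zeroed diagonal. The helper name dist_sketch and the clamp constants are illustrative.

import torch
from gpytorch.kernels.kernel import sq_dist

def dist_sketch(x1, x2, x1_eq_x2=False):
    if x1_eq_x2:
        # Legacy-style path: square the distances via sq_dist, force the
        # self-distances on the diagonal to exactly zero, then take the root.
        res = sq_dist(x1, x2)
        res = res - torch.diag_embed(res.diagonal(dim1=-2, dim2=-1))
        return res.clamp_min(1e-30).sqrt()
    # New fast path: torch.cdist, clamped away from zero so the sqrt's
    # gradient stays finite.
    return torch.cdist(x1, x2).clamp_min(1e-15)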

Timing

I ran the following micro-benchmarks in a notebook.

Many-observations case

import torch
from gpytorch.kernels.kernel import sq_dist, dist

x1 = torch.rand((16, 100, 3000, 10), dtype=float, requires_grad=True)
x2 = torch.rand((16, 100, 500, 10), dtype=float, requires_grad=True)
%%timeit
torch.cdist(x1, x2)

Output: 1.91 s ± 62.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
dist(x1, x2)

Output: 7.09 s ± 502 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

So in this case, torch.cdist is 3.7x faster than dist.

High-dimensional case

x1 = torch.rand((16, 100, 30, 1000), dtype=float, requires_grad=True)
x2 = torch.rand((16, 100, 5, 1000), dtype=float, requires_grad=True)

%%timeit
torch.cdist(x1, x2)

Output: 196 ms ± 6.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

%%timeit
dist(x1, x2)

Output: 269 ms ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

This is a 1.4x speedup.

Internal benchmarks

I also ran more comprehensive benchmarks on Meta-internal code and saw a clear but more moderate speedup, which is to be expected because the benchmarks include other operations. I am currently re-running this because I initially tested a slightly different version.

Accuracy and numerical issues

dist and the proposed replacement don't generate exactly the same results, although they only differ by 1e-15 or so.

sq_dist seems more vulnerable to numerical inaccuracies (on the order of 1e-15 to 1e-14) than squaring torch.cdist. Since dist takes the square root of sq_dist's result, this becomes a much larger error.

Here is an example:

x1 = torch.tensor([[0.], [3 + 1e-14]], dtype=torch.double)
x2 = torch.tensor([[3.]], dtype=torch.double)

expected_sq_dist = 1e-28
print(sq_dist(x1, x2)[1, 0])  # 4.44e-16
print(torch.cdist(x1, x2)[1, 0] ** 2)  # 1.04e-28
print(dist(x1, x2)[1, 0])  # 2.11e-08
print(torch.cdist(x1, x2)[1, 0])  # 1.02e-14

While putting this PR together, I experimented with using torch.cdist in sq_dist and in the x1_eq_x2 case of dist. However, this caused a variety of NotPSDErrors. I am puzzled why the slightly different behavior of dist and torch.cdist would cause this, and why this does not seem to happen with the implementation that's currently in this PR.

@esantorella requested a review from @Balandat on May 3, 2023 22:36
@Balandat (Collaborator) commented May 3, 2023

Thanks!

Have you tested this on a GPU to check whether the scaling is the same there?

Also, have you checked that the gradients are correct with cdist? I vaguely recall that this had caused issues in the past.

Aside:

I was wondering if we could use pdist in the following fashion:

def pdist_sqdist(x):
    # Scatter the condensed output of torch.pdist (upper-triangular entries in
    # row-major order, matching triu_indices with offset=1) into a full matrix.
    n = x.shape[-2]
    half_out = torch.zeros(n, n, device=x.device, dtype=x.dtype)
    row_idxr, col_idxr = torch.triu_indices(n, n, offset=1)
    half_out[row_idxr, col_idxr] = torch.pdist(x).pow(2)
    # Symmetrize; the diagonal stays exactly zero.
    return half_out + half_out.transpose(-1, -2)

Turns out this works, but is quite a bit slower for large n - presumably torch.pdist is just not implemented very well. Also, it doesn't work in batch mode...

Aside 2:

If pdist were faster, it would be quite useful to have implementations of Cholesky and the like that just take in the flattened upper triangular portion of a p.s.d. matrix. Then there would not really be any need to construct the full matrix...

@Balandat (Collaborator) commented May 3, 2023

cc @gpleiss, @jacobrgardner since this touches some pretty fundamental functionality.

@jacobrgardner (Member) commented:

I'm fine in theory with the substitution, but we should very carefully benchmark speed, memory and gradient stability first, especially on the GPU and especially for large matrix sizes.

I'm highly suspicious because I've personally tried to swap in cdist several times, and it has always caused us pain. Like, this comment is probably half the age of the package at this point, and the linked PR is even just upstreaming our distance implementation into pytorch:

# TODO: use torch squared cdist once implemented: /~https://github.com/pytorch/pytorch/pull/25799

@esantorella (Collaborator, Author) commented May 4, 2023

Have you tested this on a GPU to check whether the scaling is the same there?

I tried to re-run the small benchmarks above on a GPU, but had to make the batch dimensions smaller so as not to OOM when using dist. There was a 2.08x speedup for the large m/n case and a 1.26x speedup for the high-dimensional case.
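
For reference, the GPU re-run can be timed along these lines; this is a sketch rather than the exact script behind the numbers above, the shapes and device are illustrative, and torch.utils.benchmark handles CUDA synchronization.

import torch
from torch.utils.benchmark import Timer
from gpytorch.kernels.kernel import dist

device = torch.device("cuda:0")
# Smaller batch dimensions than the CPU benchmark to avoid OOM with dist.
x1 = torch.rand((4, 10, 3000, 10), dtype=torch.double, device=device)
x2 = torch.rand((4, 10, 500, 10), dtype=torch.double, device=device)

for stmt in ("torch.cdist(x1, x2)", "dist(x1, x2)"):
    timer = Timer(stmt=stmt, globals={"torch": torch, "dist": dist, "x1": x1, "x2": x2})
    print(stmt, timer.timeit(10))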

Also, have you checked that the gradients are correct with cdist? I vaguely recall that this had caused issues in the past.

Hmm, I am seeing both dist and torch.cdist(...).clamp_min(1e-15) do poorly for differences near zero according to torch.autograd.gradcheck, but then neither is differentiable in that region. Thoughts on what constitutes "correct" behavior here?

I used Ax’s benchmarking functionality to test a variety of methods (see here) on a variety of problems (see here) and in addition to runtime improving, there was no difference in optimization performance. (On Meta systems, sorry for the lack of transparency.)

@esantorella (Collaborator, Author) commented:

I'm fine in theory with the substitution, but we should very carefully benchmark speed, memory and gradient stability first, especially on the GPU and especially for large matrix sizes.
I'm highly suspicious because I've personally tried to swap in cdist several times, and it has always caused us pain.

I’m not sure why things seem to be working better now than in the past. There don't seem to have been major changes to torch.cdist in the last couple years, but there were some bug fixes in the backward pass.

Numerical issues:

These were a PITA. My TLDR on numerical issues is that when x1_eq_x2=False, cdist seems to be better than dist, but dist is better when x1_eq_x2=True, so this PR preserves that behavior. I initially tried using cdist for everything, but this created downstream not-PSD errors in unit tests for beanmachine, kats, and BoTorch. With the implementation in this PR, everything passes.

I think this has to do with torch #57690: diagonal elements that should be zero are not. (I actually see this as worse on CPU.) dist also has this issue when x1_eq_x2=False and requires_grad=False for x1 and x2, but otherwise fixes it by directly setting the diagonal to zero. I got tests passing by using the legacy behavior in that case. (This would also be fixable by calling torch.cdist with compute_mode='donot_use_mm_for_euclid_dist', but that mode is very slow.)
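
As a small illustration of the diagonal issue in the spirit of torch #57690 (not a verbatim repro; the shapes and scaling are made up): with the default matmul-based compute mode, cdist(x, x) can have a slightly nonzero diagonal, while the slower non-matmul mode returns exact zeros.

import torch

torch.manual_seed(0)
x = torch.rand(200, 50, dtype=torch.double) * 100

d_fast = torch.cdist(x, x)  # default mode may use the matmul-based path
d_slow = torch.cdist(x, x, compute_mode="donot_use_mm_for_euclid_dist")

print(d_fast.diagonal().abs().max())  # can be small but nonzero
print(d_slow.diagonal().abs().max())  # exactly zero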

Accuracy

On large examples with random numbers, the new and old versions give very similar results. Example:

device = "cuda:0"
torch.manual_seed(0)
x1 = torch.rand((16, 10, 300, 100), dtype=torch.double, requires_grad=True, device=device)
x2 = torch.rand((16, 10, 500, 100), dtype=torch.double, requires_grad=True, device=device)
cdist_result = torch.cdist(x1, x2).clamp_min(1e-15)
dist_result = dist(x1, x2)
# 1.1102e-14
torch.abs(cdist_result - dist_result).max()

However, it’s possible to construct small examples with near-zero distances where they differ by more and where torch.cdist is more accurate. This seems to stem from a small floating-point error in the adjustment logic, which can become large once you take a square root. Example:

x1 = torch.tensor([[0.], [3 + 1e-14]], dtype=torch.double, requires_grad=True)
x2 = torch.tensor([[3.]], dtype=torch.double)


# should be 1e-28; actual 4e-16
print(sq_dist(x1, x2)[1, 0])
# 1.04e-28 (correct)
print((torch.cdist(x1, x2) ** 2)[1, 0])
# should be 1e-14; actual 2.1e-8
print(dist(x1, x2)[1, 0])
# is 1.02e-14 (correct)
print(torch.cdist(x1, x2)[1, 0])

GPU

I do not see major numerical differences between CPU and GPU. There still appears to be a speed improvement (see above comment).

Gradients

Both dist and torch.cdist fail torch.autograd.gradcheck near zero, although that's because the finite-difference estimate is wrong there -- their gradients look correct to me. Away from zero, they have very similar gradients and both pass torch.autograd.gradcheck.
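
A sketch of the kind of check described above, with inputs chosen so all pairwise distances are well away from zero (the shapes, seed, and offset are illustrative):

import torch
from gpytorch.kernels.kernel import dist

torch.manual_seed(0)
x1 = torch.rand(5, 3, dtype=torch.double, requires_grad=True)
x2 = (10 + torch.rand(4, 3, dtype=torch.double)).requires_grad_(True)  # well separated from x1

# Both the cdist-based version and the existing dist pass gradcheck here.
assert torch.autograd.gradcheck(lambda a, b: torch.cdist(a, b).clamp_min(1e-15), (x1, x2))
assert torch.autograd.gradcheck(dist, (x1, x2))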

Memory

Looks very similar on CPU. (It would be much lower with torch.cdist if we got rid of the clamp_min.) I'm still figuring out how to measure this on a GPU.
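
For the GPU memory comparison, one option (an assumption about methodology, not something from the PR) is to read CUDA's peak-allocation counter around each call; shapes are illustrative.

import torch
from gpytorch.kernels.kernel import dist

device = torch.device("cuda:0")
x1 = torch.rand((4, 10, 3000, 10), dtype=torch.double, device=device)
x2 = torch.rand((4, 10, 500, 10), dtype=torch.double, device=device)

for name, fn in [("cdist", lambda: torch.cdist(x1, x2).clamp_min(1e-15)),
                 ("dist", lambda: dist(x1, x2))]:
    torch.cuda.reset_peak_memory_stats(device)
    fn()
    torch.cuda.synchronize(device)
    print(name, torch.cuda.max_memory_allocated(device) / 2**30, "GiB")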

@jacobrgardner (Member) commented:

Alright, thanks @esantorella -- I'm pretty convinced by the above that we can once again roll the dice on torch's cdist implementation :-).

@jacobrgardner merged commit 527546e into cornellius-gp:master on May 9, 2023