Use torch.cdist for dist #2336
Conversation
Thanks! Have you tested this on a GPU to check whether the scaling is the same there? Also, have you checked that the gradients are correct with

Aside: I was wondering if we could use
Turns out this works, but is quite a lot slower for large

Aside 2: If
cc @gpleiss, @jacobrgardner since this touches some pretty fundamental functionality.
I'm fine in theory with the substitution, but we should very carefully benchmark speed, memory, and gradient stability first, especially on the GPU and especially for large matrix sizes. I'm highly suspicious because I've personally tried to swap out for cdist several times, and it has always caused us pain. Like, this comment is probably half the age of the package at this point, and the linked-to PR is even just upstreaming our distance implementation into PyTorch:

gpytorch/gpytorch/kernels/kernel.py, line 27 (at ee35601)
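For context on the implementation being discussed, here is a minimal sketch of the quadratic-expansion trick that hand-rolled squared-distance code commonly relies on, compared against torch.cdist. The function name and shapes below are illustrative; the GPyTorch code at the linked line may differ in details.

```python
import torch

def sq_dist_via_expansion(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b, computed with one matmul.
    # Fast, but the cancellation between the three terms is what causes the
    # numerical issues discussed in this thread.
    x1_norm = x1.pow(2).sum(dim=-1, keepdim=True)   # (..., n, 1)
    x2_norm = x2.pow(2).sum(dim=-1, keepdim=True)   # (..., m, 1)
    res = x1_norm + x2_norm.transpose(-2, -1) - 2.0 * (x1 @ x2.transpose(-2, -1))
    return res.clamp_min_(0)                        # clamp tiny negatives from round-off

x1 = torch.randn(8, 50, 3, dtype=torch.double)
x2 = torch.randn(8, 60, 3, dtype=torch.double)
# Should agree with cdist up to round-off for well-separated random points.
print((sq_dist_via_expansion(x1, x2) - torch.cdist(x1, x2).pow(2)).abs().max())
```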
I tried to re-run the small benchmarks above on a GPU, but had to make the batch dimensions smaller so as not to OOM when using
Hmm, I am seeing both

I used Ax's benchmarking functionality to test a variety of methods (see here) on a variety of problems (see here), and in addition to runtime improving, there was no difference in optimization performance. (On Meta systems, sorry for the lack of transparency.)
I'm not sure why things seem to be working better now than in the past. There don't seem to have been major changes to torch.cdist in the last couple of years, but there were some bug fixes in the backward pass.

Numerical issues
These were a PITA. My TLDR on numerical issues is that when
I think this has to do with torch #57690: diagonal elements that should be zero are not. (I actually see this as worse on CPU.)

Accuracy
On large examples with random numbers, the new and old versions give very similar results. Example:

```python
import torch
from gpytorch.kernels.kernel import dist  # imports added; the existing GPyTorch helper

device = "cuda:0"
torch.manual_seed(0)
x1 = torch.rand((16, 10, 300, 100), dtype=torch.double, requires_grad=True, device=device)
x2 = torch.rand((16, 10, 500, 100), dtype=torch.double, requires_grad=True, device=device)
cdist_result = torch.cdist(x1, x2).clamp_min(1e-15)
dist_result = dist(x1, x2)
# 1.1102e-14
torch.abs(cdist_result - dist_result).max()
```

However, it's possible to construct small examples with near-zero distances where they differ by more and where

```python
import torch
from gpytorch.kernels.kernel import dist, sq_dist  # imports added; existing GPyTorch helpers

x1 = torch.tensor([[0.], [3 + 1e-14]], dtype=torch.double, requires_grad=True)
x2 = torch.tensor([[3.]], dtype=torch.double)
# should be 1e-28; actual 4e-16
print(sq_dist(x1, x2)[1, 0])
# 1.04e-28 (correct)
print((torch.cdist(x1, x2) ** 2)[1, 0])
# should be 1e-14; actual 2.1e-8
print(dist(x1, x2)[1, 0])
# is 1.02e-14 (correct)
print(torch.cdist(x1, x2)[1, 0])
```

GPU
I do not see major numerical differences between CPU and GPU. There still appears to be a speed improvement (see above comment).

Gradients
Both

Memory
Looks very similar on CPU. (It would be much lower with torch.cdist if we got rid of the
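On the gradients question raised earlier in the thread, a minimal check could use torch.autograd.gradcheck; the shapes and the clamp below are illustrative assumptions rather than the exact check that was run.

```python
import torch
from torch.autograd import gradcheck

torch.manual_seed(0)
# gradcheck wants float64 inputs with requires_grad=True.
x1 = torch.rand(5, 3, dtype=torch.double, requires_grad=True)
x2 = torch.rand(4, 3, dtype=torch.double, requires_grad=True)

# Compare analytical and numerical gradients of the cdist-based distance,
# clamped away from zero as in this PR (sqrt is non-differentiable at 0).
assert gradcheck(lambda a, b: torch.cdist(a, b).clamp_min(1e-15), (x1, x2))
```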
Alright, thanks @esantorella -- I'm pretty convinced by the above that we can once again roll the dice on torch's cdist implementation :-).
gpytorch.kernels.kernel.dist is an expensive function that is used heavily because it is used in MaternKernel. MaternKernel is often used as a default, for example in BoTorch's SingleTaskGP. torch.cdist computes very nearly the same result as gpytorch.kernels.kernel.dist, but is much faster. However, torch.cdist(x1, x2).pow(2) was not faster than gpytorch.kernels.kernel.sq_dist, so I did not change sq_dist. I also did not change the behavior of dist in the x1_eq_x2 case since this caused numerical issues (see below).

Timing
I ran the following micro-benchmarks in a notebook.
Many observations
Output: 1.91 s ± 62.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Output: 7.09 s ± 502 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
So in this case, dist is 3.7x faster.

High-dimensional case
Output: 196 ms ± 6.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Output: 269 ms ± 10 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
This is a 1.4x speedup.
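The notebook code behind these timings was not captured above. A sketch of how such a micro-benchmark could be set up; the shapes, repeat count, and use of timeit below are illustrative assumptions, not the configuration that produced the numbers above.

```python
import timeit

import torch
from gpytorch.kernels.kernel import dist  # pre-PR GPyTorch implementation

torch.manual_seed(0)
# Illustrative "many observations, few dimensions" case.
x1 = torch.rand(4, 1000, 5, dtype=torch.double)
x2 = torch.rand(4, 1000, 5, dtype=torch.double)

t_old = timeit.timeit(lambda: dist(x1, x2), number=10)
t_new = timeit.timeit(lambda: torch.cdist(x1, x2).clamp_min(1e-15), number=10)
print(f"dist: {t_old:.3f}s  cdist-based: {t_new:.3f}s  ratio: {t_old / t_new:.1f}x")
```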
Internal benchmarks
I also ran more comprehensive benchmarks on Meta-internal code and saw a clear but more moderate speedup, which is to be expected because the benchmarks include other operations. I am currently re-running this because I initially tested a slightly different version.
Accuracy and numerical issues
dist and the proposed replacement don't generate exactly the same results, although they only differ by 1e-15 or so. sq_dist seems more vulnerable to numerical inaccuracies (1e-15 to 1e-14) than squaring torch.cdist. Since dist currently uses a square root, this becomes a much larger error. Here is an example:
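The example below is a sketch consistent with the near-zero-distance case shown in the review thread above; the example originally posted here may have differed.

```python
import torch
from gpytorch.kernels.kernel import dist, sq_dist  # pre-change GPyTorch helpers

# Two points whose true distance is 1e-14 (true squared distance 1e-28).
x1 = torch.tensor([[0.0], [3 + 1e-14]], dtype=torch.double)
x2 = torch.tensor([[3.0]], dtype=torch.double)

print(sq_dist(x1, x2)[1, 0])             # ~4e-16 rather than 1e-28 (cancellation)
print(dist(x1, x2)[1, 0])                # ~2e-8 rather than 1e-14 (sqrt amplifies it)
print((torch.cdist(x1, x2) ** 2)[1, 0])  # ~1e-28, close to the true value
print(torch.cdist(x1, x2)[1, 0])         # ~1e-14, close to the true value
```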
While putting this PR together, I experimented with using torch.cdist in sq_dist and in the x1_eq_x2 case of dist. However, this caused a variety of NotPSDErrors. I am puzzled why the slightly different behavior of dist and torch.cdist would cause this, and why this does not seem to happen with the implementation that's currently in this PR.
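For reference, a minimal sketch of the kind of change described in this PR: route the general case of dist through torch.cdist with the clamp used in the snippets above, while leaving the x1_eq_x2 case on the existing sq_dist-based path. The function name and the 1e-30 clamp are assumptions for illustration, not the exact diff.

```python
import torch
from gpytorch.kernels.kernel import sq_dist  # existing GPyTorch helper

def dist_sketch(x1: torch.Tensor, x2: torch.Tensor, x1_eq_x2: bool = False) -> torch.Tensor:
    # Sketch of the change this PR describes; not the exact implementation.
    if x1_eq_x2:
        # Keep the existing sq_dist-based path: swapping in torch.cdist here
        # led to NotPSDErrors, per the discussion above.
        # (The 1e-30 clamp is an assumption about the existing code.)
        return sq_dist(x1, x2).clamp_min(1e-30).sqrt()
    # General case: torch.cdist, clamped away from exact zeros
    # (as in the accuracy snippets above).
    return torch.cdist(x1, x2).clamp_min(1e-15)
```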