Skip to content

Commit

Permalink
use device, don't use cuda if not there
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Mar 16, 2024
1 parent ce49658 commit 20334cb
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions sae_training/geometric_median.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,14 +94,14 @@ def compute_geometric_median(

dim1 = 10000
dim2 = 768
device = "cuda"
device = "cuda" if torch.cuda.is_available() else "cpu"

sample = (
torch.randn((dim1, dim2), device="cuda") * 100
torch.randn((dim1, dim2), device=device) * 100
) # seems to be the order of magnitude of the actual use case
weights = torch.randn((dim1,), device="cuda")
weights = torch.randn((dim1,), device=device)

torch.tensor(weights, device="cuda")
torch.tensor(weights, device=device)

tic = time.perf_counter()
new = compute_geometric_median(sample, weights=weights, maxiter=100)
Expand Down

0 comments on commit 20334cb

Please sign in to comment.