Skip to content

Commit

Permalink
makefile check type and types of geometric media
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Mar 19, 2024
1 parent 0168612 commit 736bf83
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
2 changes: 2 additions & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ check-format:
poetry run black --check .
poetry run isort --check-only --diff .

check-type:
poetry run pyright .

test:
make unit-test
Expand Down
17 changes: 10 additions & 7 deletions sae_training/geometric_median.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from types import SimpleNamespace
from typing import Optional

import torch
import tqdm


def weighted_average(points, weights):
def weighted_average(points: torch.Tensor, weights: torch.Tensor):
weights = weights / weights.sum()
return (points * weights.view(-1, 1)).sum(dim=0)


@torch.no_grad()
def geometric_median_objective(median, points, weights):
def geometric_median_objective(
median: torch.Tensor, points: torch.Tensor, weights: torch.Tensor
) -> torch.Tensor:

norms = torch.linalg.norm(points - median.view(1, -1), dim=1)

Expand All @@ -19,11 +22,11 @@ def geometric_median_objective(median, points, weights):

def compute_geometric_median(
points: torch.Tensor,
weights: torch.Tensor = None,
eps=1e-6,
maxiter=100,
ftol=1e-20,
do_log=False,
weights: Optional[torch.Tensor] = None,
eps: float = 1e-6,
maxiter: int = 100,
ftol: float = 1e-20,
do_log: bool = False,
):
"""
:param points: ``torch.Tensor`` of shape ``(n, d)``
Expand Down

0 comments on commit 736bf83

Please sign in to comment.