Update the metrics to have a components dimension #162

Merged: 12 commits, Dec 20, 2023
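This PR replaces the flat `LossLogType` dict (one key per component) with `MetricResult` objects that carry the component dimension directly. A minimal sketch of the new logging unit, with the constructor arguments taken from the diff below; anything beyond those is an assumption:

```python
import torch

from sparse_autoencoder.metrics.abstract_metric import MetricLocation, MetricResult

# One loss value per component (three components here), instead of one
# dict entry per component as with the removed LossLogType.
component_losses = torch.tensor([1.0, 2.0, 3.0])

result = MetricResult(
    location=MetricLocation.TRAIN,
    name="loss",
    postfix="l2_reconstruction_loss",
    component_wise_values=component_losses,
)

print(result.component_wise_values)  # tensor([1., 2., 3.])
```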
3 changes: 1 addition & 2 deletions sparse_autoencoder/__init__.py
@@ -3,7 +3,7 @@
from sparse_autoencoder.activation_store.disk_store import DiskActivationStore
from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore
from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.loss.abstract_loss import LossLogType, LossReductionType
from sparse_autoencoder.loss.abstract_loss import LossReductionType
from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss
from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss
from sparse_autoencoder.loss.reducer import LossReducer
@@ -65,7 +65,6 @@
"L2ReconstructionLoss",
"LearnedActivationsL1Loss",
"LossHyperparameters",
"LossLogType",
"LossReducer",
"LossReductionType",
"Method",
32 changes: 15 additions & 17 deletions sparse_autoencoder/loss/abstract_loss.py
@@ -1,13 +1,14 @@
"""Abstract loss."""
from abc import ABC, abstractmethod
from typing import TypeAlias, final
from typing import final

from jaxtyping import Float
from strenum import LowercaseStrEnum
import torch
from torch import Tensor
from torch.nn import Module

from sparse_autoencoder.metrics.abstract_metric import MetricLocation, MetricResult
from sparse_autoencoder.tensor_types import Axis


@@ -21,10 +22,6 @@ class LossReductionType(LowercaseStrEnum):
NONE = "none"


LossLogType: TypeAlias = dict[str, int | float | str]
"""Loss log dict."""


class AbstractLoss(Module, ABC):
"""Abstract loss interface.

@@ -122,7 +119,7 @@ def scalar_loss_with_log(
],
batch_reduction: LossReductionType = LossReductionType.MEAN,
component_reduction: LossReductionType = LossReductionType.NONE,
) -> tuple[Float[Tensor, Axis.COMPONENT_OPTIONAL], LossLogType]:
) -> tuple[Float[Tensor, Axis.COMPONENT_OPTIONAL], list[MetricResult]]:
"""Scalar loss (reduced across the batch and component axis) with logging.

Args:
@@ -138,7 +135,7 @@ def scalar_loss_with_log(
Tuple of the batch scalar loss and a list of metric results to log.
"""
children_loss_scalars: list[Float[Tensor, Axis.COMPONENT_OPTIONAL]] = []
metrics: LossLogType = {}
metrics: list[MetricResult] = []

# If the loss module has children (e.g. it is a reducer):
if len(self._modules) > 0:
@@ -152,7 +149,7 @@
# component-wise losses in reducers.
)
children_loss_scalars.append(child_loss)
metrics.update(child_metrics)
metrics.extend(child_metrics)

# Get the total loss & metric
current_module_loss = torch.stack(children_loss_scalars).sum(0)
@@ -162,15 +159,16 @@
current_module_loss = self.batch_loss(
source_activations, learned_activations, decoded_activations, batch_reduction
)

# Add in the current loss module's metric
log_name = "train/loss/" + self.log_name()
loss_to_log: list | float = current_module_loss.tolist()
if isinstance(loss_to_log, float):
metrics[log_name] = loss_to_log
else:
for component_idx, component_loss in enumerate(loss_to_log):
metrics[log_name + f"/component_{component_idx}"] = component_loss
log = MetricResult(
location=MetricLocation.TRAIN,
name="loss",
postfix=self.log_name(),
component_wise_values=current_module_loss.unsqueeze(0)
if current_module_loss.ndim == 0
else current_module_loss,
)
metrics.append(log)

# Reduce the current module loss across the component dimension
match component_reduction:
@@ -196,7 +194,7 @@ def __call__(
Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE)
],
reduction: LossReductionType = LossReductionType.MEAN,
) -> tuple[Float[Tensor, Axis.SINGLE_ITEM], LossLogType]:
) -> tuple[Float[Tensor, Axis.SINGLE_ITEM], list[MetricResult]]:
"""Batch scalar loss.

Args:
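For orientation, a hedged sketch of how the reworked `scalar_loss_with_log` composes logs when the loss module has children: each child's `list[MetricResult]` is concatenated with `extend`, and the parent appends its own total-loss entry. The imports are grounded in the `__init__.py` exports above, but the `LossReducer` positional arguments and the L1 coefficient argument are assumptions:

```python
import torch

from sparse_autoencoder import (
    L2ReconstructionLoss,
    LearnedActivationsL1Loss,
    LossReducer,
)

# A reducer is a loss module with children, so scalar_loss_with_log sums
# the children's losses and concatenates their MetricResult lists
# (via extend) before appending its own total-loss entry.
loss = LossReducer(
    LearnedActivationsL1Loss(0.01),
    L2ReconstructionLoss(),
)

source = decoded = torch.rand(4, 2)  # batch of 4, 2 input/output features
learned = torch.rand(4, 8)           # 8 learned features

total_loss, metrics = loss.scalar_loss_with_log(source, learned, decoded)
for metric in metrics:
    print(metric.name, metric.postfix, metric.component_wise_values)
```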
4 changes: 2 additions & 2 deletions sparse_autoencoder/loss/decoded_activations_l2.py
@@ -26,8 +26,8 @@ class L2ReconstructionLoss(AbstractLoss):
>>> output_activations = torch.tensor([[1.0, 5], [1.0, 5]])
>>> unused_activations = torch.zeros_like(input_activations)
>>> # Outputs the loss for each batch item
>>> loss(input_activations, unused_activations, output_activations)
(tensor(5.5000), {'train/loss/l2_reconstruction_loss': 5.5})
>>> loss.forward(input_activations, unused_activations, output_activations)
tensor([8.5000, 2.5000])
"""

_reduction: LossReductionType
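The doctest change above reflects a split in behaviour: `forward` now returns the unreduced per-item loss tensor, while calling the module directly (the `__call__` path typed in `abstract_loss.py`) still returns the reduced scalar together with the list of `MetricResult` entries. A hedged sketch of the distinction, using example tensors of the same shape as the doctest:

```python
import torch

from sparse_autoencoder import L2ReconstructionLoss

loss = L2ReconstructionLoss()
input_activations = torch.tensor([[5.0, 4.0], [3.0, 4.0]])
output_activations = torch.tensor([[1.0, 5.0], [1.0, 5.0]])
unused_activations = torch.zeros_like(input_activations)

# forward: one (unreduced) loss value per batch item, nothing logged.
per_item_loss = loss.forward(input_activations, unused_activations, output_activations)

# __call__: batch-reduced scalar plus a list[MetricResult] to log.
scalar_loss, metrics = loss(input_activations, unused_activations, output_activations)
```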
42 changes: 22 additions & 20 deletions sparse_autoencoder/loss/learned_activations_l1.py
@@ -5,7 +5,8 @@
import torch
from torch import Tensor

from sparse_autoencoder.loss.abstract_loss import AbstractLoss, LossLogType, LossReductionType
from sparse_autoencoder.loss.abstract_loss import AbstractLoss, LossReductionType
from sparse_autoencoder.metrics.abstract_metric import MetricLocation, MetricResult
from sparse_autoencoder.tensor_types import Axis


@@ -21,7 +22,7 @@ class LearnedActivationsL1Loss(AbstractLoss):
>>> learned_activations = torch.tensor([[2.0, -3], [2.0, -3]])
>>> unused_activations = torch.zeros_like(learned_activations)
>>> # Returns the loss for each batch item
>>> l1_loss(unused_activations, learned_activations, unused_activations)[0]
>>> l1_loss.forward(unused_activations, learned_activations, unused_activations)[0]
tensor(0.5000)
"""

@@ -128,7 +129,7 @@ def scalar_loss_with_log(
],
batch_reduction: LossReductionType = LossReductionType.MEAN,
component_reduction: LossReductionType = LossReductionType.NONE,
) -> tuple[Float[Tensor, Axis.COMPONENT_OPTIONAL], LossLogType]:
) -> tuple[Float[Tensor, Axis.COMPONENT_OPTIONAL], list[MetricResult]]:
"""Scalar L1 loss (reduced across the batch and component axis) with logging.

Args:
@@ -162,24 +163,25 @@
error_message = "Batch reduction type NONE not supported."
raise ValueError(error_message)

batch_loss_to_log: list | float = batch_scalar_loss.tolist()
batch_loss_penalty_to_log: list | float = batch_scalar_loss_penalty.tolist()

# Create the log
metrics = {}
if isinstance(batch_loss_to_log, float):
metrics["train/loss/learned_activations_l1_loss"] = batch_loss_to_log
metrics[f"train/loss/{self.log_name()}"] = batch_loss_penalty_to_log
else:
for component_idx, (component_loss, component_loss_penalty) in enumerate(
zip(batch_loss_to_log, batch_loss_penalty_to_log)
):
metrics[
f"train/loss/learned_activations_l1_loss/component_{component_idx}"
] = component_loss
metrics[
f"train/loss/{self.log_name()}/component_{component_idx}"
] = component_loss_penalty
metrics: list[MetricResult] = [
MetricResult(
name="loss",
postfix="learned_activations_l1",
component_wise_values=batch_scalar_loss.unsqueeze(0)
if batch_scalar_loss.ndim == 0
else batch_scalar_loss,
location=MetricLocation.TRAIN,
),
MetricResult(
name="loss",
postfix=self.log_name(),
component_wise_values=batch_scalar_loss_penalty.unsqueeze(0)
if batch_scalar_loss_penalty.ndim == 0
else batch_scalar_loss_penalty,
location=MetricLocation.TRAIN,
),
]

match component_reduction:
case LossReductionType.MEAN:
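As the rewritten block shows, the L1 module now emits two `MetricResult` entries: the raw `learned_activations_l1` loss and the coefficient-scaled penalty (logged under `self.log_name()`) that actually enters training. A short sketch, assuming the positional coefficient argument:

```python
import torch

from sparse_autoencoder import LearnedActivationsL1Loss

l1_loss = LearnedActivationsL1Loss(0.01)
learned_activations = torch.tensor([[2.0, -3.0], [2.0, -3.0]])
unused_activations = torch.zeros_like(learned_activations)

_loss, metrics = l1_loss.scalar_loss_with_log(
    unused_activations, learned_activations, unused_activations
)

# Two entries: the unscaled L1 loss, then the coefficient-scaled penalty.
for metric in metrics:
    print(metric.postfix, metric.component_wise_values)
```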
6 changes: 3 additions & 3 deletions sparse_autoencoder/loss/tests/test_abstract_loss.py
@@ -83,7 +83,7 @@ def test_batch_loss_with_log(dummy_loss: DummyLoss) -> None:
source_activations, learned_activations, decoded_activations
)
expected = 2.0 # Mean of [1.0, 2.0, 3.0]
assert log["train/loss/dummy"] == expected
assert log[0].component_wise_values[0] == expected


def test_scalar_loss_with_log_and_component_axis(dummy_loss: DummyLoss) -> None:
@@ -96,12 +96,12 @@ def test_scalar_loss_with_log_and_component_axis(dummy_loss: DummyLoss) -> None:
)
expected = 2.0 # Mean of [1.0, 2.0, 3.0]
for component_idx in range(num_components):
assert log[f"train/loss/dummy/component_{component_idx}"] == expected
assert log[0].component_wise_values[component_idx] == expected


def test_call_method(dummy_loss: DummyLoss) -> None:
"""Test the call method."""
source_activations = learned_activations = decoded_activations = torch.ones((1, 3))
_loss, log = dummy_loss(source_activations, learned_activations, decoded_activations)
expected = 2.0 # Mean of [1.0, 2.0, 3.0]
assert log["train/loss/dummy"] == expected
assert log[0].component_wise_values[0] == expected