diff --git a/sparse_autoencoder/__init__.py b/sparse_autoencoder/__init__.py index d19212da..6cad6d04 100644 --- a/sparse_autoencoder/__init__.py +++ b/sparse_autoencoder/__init__.py @@ -3,7 +3,7 @@ from sparse_autoencoder.activation_store.disk_store import DiskActivationStore from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore from sparse_autoencoder.autoencoder.model import SparseAutoencoder -from sparse_autoencoder.loss.abstract_loss import LossLogType, LossReductionType +from sparse_autoencoder.loss.abstract_loss import LossReductionType from sparse_autoencoder.loss.decoded_activations_l2 import L2ReconstructionLoss from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss from sparse_autoencoder.loss.reducer import LossReducer @@ -65,7 +65,6 @@ "L2ReconstructionLoss", "LearnedActivationsL1Loss", "LossHyperparameters", - "LossLogType", "LossReducer", "LossReductionType", "Method", diff --git a/sparse_autoencoder/loss/abstract_loss.py b/sparse_autoencoder/loss/abstract_loss.py index 72f850ac..e6dcc77c 100644 --- a/sparse_autoencoder/loss/abstract_loss.py +++ b/sparse_autoencoder/loss/abstract_loss.py @@ -1,6 +1,6 @@ """Abstract loss.""" from abc import ABC, abstractmethod -from typing import TypeAlias, final +from typing import final from jaxtyping import Float from strenum import LowercaseStrEnum @@ -8,6 +8,7 @@ from torch import Tensor from torch.nn import Module +from sparse_autoencoder.metrics.abstract_metric import MetricLocation, MetricResult from sparse_autoencoder.tensor_types import Axis @@ -21,10 +22,6 @@ class LossReductionType(LowercaseStrEnum): NONE = "none" -LossLogType: TypeAlias = dict[str, int | float | str] -"""Loss log dict.""" - - class AbstractLoss(Module, ABC): """Abstract loss interface. @@ -122,7 +119,7 @@ def scalar_loss_with_log( ], batch_reduction: LossReductionType = LossReductionType.MEAN, component_reduction: LossReductionType = LossReductionType.NONE, - ) -> tuple[Float[Tensor, Axis.COMPONENT_OPTIONAL], LossLogType]: + ) -> tuple[Float[Tensor, Axis.COMPONENT_OPTIONAL], list[MetricResult]]: """Scalar loss (reduced across the batch and component axis) with logging. Args: @@ -138,7 +135,7 @@ def scalar_loss_with_log( Tuple of the batch scalar loss and a dict of any properties to log. """ children_loss_scalars: list[Float[Tensor, Axis.COMPONENT_OPTIONAL]] = [] - metrics: LossLogType = {} + metrics: list[MetricResult] = [] # If the loss module has children (e.g. it is a reducer): if len(self._modules) > 0: @@ -152,7 +149,7 @@ def scalar_loss_with_log( # component-wise losses in reducers. 
) children_loss_scalars.append(child_loss) - metrics.update(child_metrics) + metrics.extend(child_metrics) # Get the total loss & metric current_module_loss = torch.stack(children_loss_scalars).sum(0) @@ -162,15 +159,16 @@ def scalar_loss_with_log( current_module_loss = self.batch_loss( source_activations, learned_activations, decoded_activations, batch_reduction ) - # Add in the current loss module's metric - log_name = "train/loss/" + self.log_name() - loss_to_log: list | float = current_module_loss.tolist() - if isinstance(loss_to_log, float): - metrics[log_name] = loss_to_log - else: - for component_idx, component_loss in enumerate(loss_to_log): - metrics[log_name + f"/component_{component_idx}"] = component_loss + log = MetricResult( + location=MetricLocation.TRAIN, + name="loss", + postfix=self.log_name(), + component_wise_values=current_module_loss.unsqueeze(0) + if current_module_loss.ndim == 0 + else current_module_loss, + ) + metrics.append(log) # Reduce the current module loss across the component dimension match component_reduction: @@ -196,7 +194,7 @@ def __call__( Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) ], reduction: LossReductionType = LossReductionType.MEAN, - ) -> tuple[Float[Tensor, Axis.SINGLE_ITEM], LossLogType]: + ) -> tuple[Float[Tensor, Axis.SINGLE_ITEM], list[MetricResult]]: """Batch scalar loss. Args: diff --git a/sparse_autoencoder/loss/decoded_activations_l2.py b/sparse_autoencoder/loss/decoded_activations_l2.py index d4feaca0..2109dc3b 100644 --- a/sparse_autoencoder/loss/decoded_activations_l2.py +++ b/sparse_autoencoder/loss/decoded_activations_l2.py @@ -26,8 +26,8 @@ class L2ReconstructionLoss(AbstractLoss): >>> output_activations = torch.tensor([[1.0, 5], [1.0, 5]]) >>> unused_activations = torch.zeros_like(input_activations) >>> # Outputs both loss and metrics to log - >>> loss(input_activations, unused_activations, output_activations) - (tensor(5.5000), {'train/loss/l2_reconstruction_loss': 5.5}) + >>> loss.forward(input_activations, unused_activations, output_activations) + tensor([8.5000, 2.5000]) """ _reduction: LossReductionType diff --git a/sparse_autoencoder/loss/learned_activations_l1.py b/sparse_autoencoder/loss/learned_activations_l1.py index 84a25012..e917aea2 100644 --- a/sparse_autoencoder/loss/learned_activations_l1.py +++ b/sparse_autoencoder/loss/learned_activations_l1.py @@ -5,7 +5,8 @@ import torch from torch import Tensor -from sparse_autoencoder.loss.abstract_loss import AbstractLoss, LossLogType, LossReductionType +from sparse_autoencoder.loss.abstract_loss import AbstractLoss, LossReductionType +from sparse_autoencoder.metrics.abstract_metric import MetricLocation, MetricResult from sparse_autoencoder.tensor_types import Axis @@ -21,7 +22,7 @@ class LearnedActivationsL1Loss(AbstractLoss): >>> learned_activations = torch.tensor([[2.0, -3], [2.0, -3]]) >>> unused_activations = torch.zeros_like(learned_activations) >>> # Returns loss and metrics to log - >>> l1_loss(unused_activations, learned_activations, unused_activations)[0] + >>> l1_loss.forward(unused_activations, learned_activations, unused_activations)[0] tensor(0.5000) """ @@ -128,7 +129,7 @@ def scalar_loss_with_log( ], batch_reduction: LossReductionType = LossReductionType.MEAN, component_reduction: LossReductionType = LossReductionType.NONE, - ) -> tuple[Float[Tensor, Axis.COMPONENT_OPTIONAL], LossLogType]: + ) -> tuple[Float[Tensor, Axis.COMPONENT_OPTIONAL], list[MetricResult]]: """Scalar L1 loss (reduced across the batch and 
component axis) with logging. Args: @@ -162,24 +163,25 @@ def scalar_loss_with_log( error_message = "Batch reduction type NONE not supported." raise ValueError(error_message) - batch_loss_to_log: list | float = batch_scalar_loss.tolist() - batch_loss_penalty_to_log: list | float = batch_scalar_loss_penalty.tolist() - # Create the log - metrics = {} - if isinstance(batch_loss_to_log, float): - metrics["train/loss/learned_activations_l1_loss"] = batch_loss_to_log - metrics[f"train/loss/{self.log_name()}"] = batch_loss_penalty_to_log - else: - for component_idx, (component_loss, component_loss_penalty) in enumerate( - zip(batch_loss_to_log, batch_loss_penalty_to_log) - ): - metrics[ - f"train/loss/learned_activations_l1_loss/component_{component_idx}" - ] = component_loss - metrics[ - f"train/loss/{self.log_name()}/component_{component_idx}" - ] = component_loss_penalty + metrics: list[MetricResult] = [ + MetricResult( + name="loss", + postfix="learned_activations_l1", + component_wise_values=batch_scalar_loss.unsqueeze(0) + if batch_scalar_loss.ndim == 0 + else batch_scalar_loss, + location=MetricLocation.TRAIN, + ), + MetricResult( + name="loss", + postfix=self.log_name(), + component_wise_values=batch_scalar_loss_penalty.unsqueeze(0) + if batch_scalar_loss_penalty.ndim == 0 + else batch_scalar_loss_penalty, + location=MetricLocation.TRAIN, + ), + ] match component_reduction: case LossReductionType.MEAN: diff --git a/sparse_autoencoder/loss/tests/test_abstract_loss.py b/sparse_autoencoder/loss/tests/test_abstract_loss.py index 390400ba..8dd8f58e 100644 --- a/sparse_autoencoder/loss/tests/test_abstract_loss.py +++ b/sparse_autoencoder/loss/tests/test_abstract_loss.py @@ -83,7 +83,7 @@ def test_batch_loss_with_log(dummy_loss: DummyLoss) -> None: source_activations, learned_activations, decoded_activations ) expected = 2.0 # Mean of [1.0, 2.0, 3.0] - assert log["train/loss/dummy"] == expected + assert log[0].component_wise_values[0] == expected def test_scalar_loss_with_log_and_component_axis(dummy_loss: DummyLoss) -> None: @@ -96,7 +96,7 @@ def test_scalar_loss_with_log_and_component_axis(dummy_loss: DummyLoss) -> None: ) expected = 2.0 # Mean of [1.0, 2.0, 3.0] for component_idx in range(num_components): - assert log[f"train/loss/dummy/component_{component_idx}"] == expected + assert log[0].component_wise_values[component_idx] == expected def test_call_method(dummy_loss: DummyLoss) -> None: @@ -104,4 +104,4 @@ def test_call_method(dummy_loss: DummyLoss) -> None: source_activations = learned_activations = decoded_activations = torch.ones((1, 3)) _loss, log = dummy_loss(source_activations, learned_activations, decoded_activations) expected = 2.0 # Mean of [1.0, 2.0, 3.0] - assert log["train/loss/dummy"] == expected + assert log[0].component_wise_values[0] == expected diff --git a/sparse_autoencoder/metrics/abstract_metric.py b/sparse_autoencoder/metrics/abstract_metric.py new file mode 100644 index 00000000..62a520b8 --- /dev/null +++ b/sparse_autoencoder/metrics/abstract_metric.py @@ -0,0 +1,325 @@ +"""Abstract metric. + +Defines the shared functionality across all types of metrics. Note that for creating your own +metric, you probably want to extend one of the subclasses such as `TrainMetric` or `ValidateMetric`. +These subclasses define the interface for metrics that can be implemented at different points in the +training pipeline. 
+""" +from abc import ABC, abstractmethod +from collections.abc import Sequence +from enum import auto +from typing import Any, TypeAlias, cast, final + +from jaxtyping import Float, Int +import numpy as np +from strenum import LowercaseStrEnum, SnakeCaseStrEnum +from torch import Tensor +from wandb import data_types + +from sparse_autoencoder.tensor_types import Axis + + +class MetricLocation(SnakeCaseStrEnum): + """Metric location. + + Metrics can be logged at different stages of the training pipeline. This enum is used to define + when the metric was logged. + """ + + GENERATE = auto() + TRAIN = auto() + RESAMPLE = auto() + VALIDATE = auto() + SAVE = auto() + + +class ComponentAggregationApproach(LowercaseStrEnum): + """Component aggregation method. + + When training multiple SAEs on multiple components (e.g. every MLP layer in a source model), it + can be useful to see summary statistics across all components as well. This enum is used to + define how the component-wise values should be aggregated. + """ + + MEAN = auto() + """Mean of the component-wise values.""" + + SUM = auto() + """Sum of the component-wise values.""" + + TABLE = auto() + """Table of all component-wise values in one place.""" + + +WandbSupportedLogTypes: TypeAlias = ( + bool + | data_types.Audio + | data_types.Bokeh + | data_types.Histogram + | data_types.Html + | data_types.Image + | data_types.Molecule + | data_types.Object3D + | data_types.Plotly + | data_types.Table + | data_types.Video + | data_types.WBTraceTree + | float + | Float[Tensor, Axis.names(Axis.SINGLE_ITEM)] + | int + | Int[Tensor, Axis.names(Axis.SINGLE_ITEM)] + | list["WandbSupportedLogTypes"] + | np.ndarray +) +"""All supported component-wise W&B log types.""" + + +class MetricResult: + """Metric result. + + Every metric (and loss module) should return a list of metric results (a list so that it can + return more than one metric result if needed). Each metric result defines the name of the + result, as well as the component-wise values and how they should be aggregated. + """ + + location: MetricLocation + name: str + postfix: str | None + _component_names: list[str] + component_wise_values: Sequence[WandbSupportedLogTypes] | Float[ + Tensor, Axis.names(Axis.COMPONENT) + ] | Int[Tensor, Axis.names(Axis.COMPONENT)] + aggregate_approach: ComponentAggregationApproach | None + _aggregate_value: Any | None + + def __init__( + self, + component_wise_values: Sequence[WandbSupportedLogTypes] + | Float[Tensor, Axis.names(Axis.COMPONENT)] + | Int[Tensor, Axis.names(Axis.COMPONENT)], + name: str, + location: MetricLocation, + aggregate_approach: ComponentAggregationApproach + | None = ComponentAggregationApproach.TABLE, + aggregate_value: Any | None = None, # noqa: ANN401 + postfix: str | None = None, + ) -> None: + """Initialize a metric result. + + Example: + >>> metric_result = MetricResult( + ... location=MetricLocation.TRAIN, + ... name="loss", + ... component_wise_values=[1.0, 2.0, 3.0], + ... aggregate_approach=ComponentAggregationApproach.MEAN, + ... ) + >>> for k, v in metric_result.wandb_log.items(): + ... print(f"{k}: {v}") + component_0/train/loss: 1.0 + component_1/train/loss: 2.0 + component_2/train/loss: 3.0 + train/loss: 2.0 + + + Args: + component_wise_values: Values for each component. + name: Metric name (e.g. `l2_loss`). This will be combined with the component name and + metric locations, as well as an optional postfix, to create a Weights and Biases + name of the form `component_name/metric_location/metric_name/metric_postfix`. 
+ location: Metric location. + aggregate_approach: Component aggregation approach. + aggregate_value: Override the aggregate value across components. For most metric results + you can instead just specify the `aggregate_approach` and it will be automatically + calculated. + postfix: Metric name postfix. + """ + self.location = location + self.name = name + self.component_wise_values = component_wise_values + self.aggregate_approach = aggregate_approach + self._aggregate_value = aggregate_value + self.postfix = postfix + self._component_names = [f"component_{i}" for i in range(len(component_wise_values))] + + @final + @property + def n_components(self) -> int: + """Number of components.""" + return len(self.component_wise_values) + + @final + @property + def aggregate_value( # noqa: PLR0911 + self, + ) -> ( + WandbSupportedLogTypes + | Float[Tensor, Axis.names(Axis.COMPONENT)] + | Int[Tensor, Axis.names(Axis.COMPONENT)] + ): + """Aggregate value across components. + + Returns: + Aggregate value (defaults to the initialised aggregate value if set, or otherwise + attempts to automatically aggregate the component-wise values). + + Raises: + ValueError: If the component-wise values cannot be automatically aggregated. + """ + # Allow overriding + if self._aggregate_value is not None: + return self._aggregate_value + + if self.n_components == 1: + return self.component_wise_values[0] + + cannot_aggregate_error_message = "Cannot aggregate component-wise values." + + # Automatically aggregate number lists/sequences/tuples/sets + if (isinstance(self.component_wise_values, (Sequence, list, tuple, set))) and all( + isinstance(x, (int, float)) for x in self.component_wise_values + ): + values: list = cast(list[float], self.component_wise_values) + match self.aggregate_approach: + case ComponentAggregationApproach.MEAN: + return sum(values) / len(values) + case ComponentAggregationApproach.SUM: + return sum(values) + case ComponentAggregationApproach.TABLE: + return values + case _: + raise ValueError(cannot_aggregate_error_message) + + # Automatically aggregate number tensors + if ( + isinstance(self.component_wise_values, Tensor) + and self.component_wise_values.shape[0] == self.n_components + ): + match self.aggregate_approach: + case ComponentAggregationApproach.MEAN: + return self.component_wise_values.mean(dim=0) + case ComponentAggregationApproach.SUM: + return self.component_wise_values.sum(dim=0) + case ComponentAggregationApproach.TABLE: + return self.component_wise_values + case _: + raise ValueError(cannot_aggregate_error_message) + + # Raise otherwise + raise ValueError(cannot_aggregate_error_message) + + @final + def create_wandb_name(self, component_name: str | None = None) -> str: + """Weights and Biases Metric Name. + + Note Weights and Biases categorises metrics using a forward slash (`/`) in the name string. + + Example: + >>> metric_result = MetricResult( + ... location=MetricLocation.VALIDATE, + ... name="loss", + ... component_wise_values=[1.0, 2.0, 3.0], + ... ) + >>> metric_result.create_wandb_name() + 'validate/loss' + + >>> metric_result.create_wandb_name(component_name="component_0") + 'component_0/validate/loss' + + Args: + component_name: Component name, if creating a Weights and Biases name for a specific + component. + + Returns: + Weights and Biases metric name. 
+ """ + name_parts = [] + + if component_name is not None: + name_parts.append(component_name) + + name_parts.extend([self.location.value, self.name]) + + if self.postfix is not None: + name_parts.append(self.postfix) + + return "/".join(name_parts) + + @final + @property + def wandb_log(self) -> dict[str, WandbSupportedLogTypes]: + """Create the Weights and Biases Log data. + + For use with `wandb.log()`. + + https://docs.wandb.ai/ref/python/log + + Examples: + With just one component: + + >>> metric_result = MetricResult( + ... location=MetricLocation.VALIDATE, + ... name="loss", + ... component_wise_values=[1.5], + ... ) + >>> for k, v in metric_result.wandb_log.items(): + ... print(f"{k}: {v}") + validate/loss: 1.5 + + With multiple components: + + >>> metric_result = MetricResult( + ... location=MetricLocation.VALIDATE, + ... name="loss", + ... component_wise_values=[1.0, 2.0], + ... aggregate_approach=ComponentAggregationApproach.MEAN, + ... ) + >>> for k, v in metric_result.wandb_log.items(): + ... print(f"{k}: {v}") + component_0/validate/loss: 1.0 + component_1/validate/loss: 2.0 + validate/loss: 1.5 + + Returns: + Weights and Biases log data. + """ + # Create the component wise logs if there is more than one component + component_wise_logs = {} + if self.n_components > 1: + for component_name, value in zip(self._component_names, self.component_wise_values): + component_wise_logs[self.create_wandb_name(component_name=component_name)] = value + + # Create the aggregate log if there is an aggregate value + aggregate_log = {} + if self.aggregate_approach is not None or self._aggregate_value is not None: + aggregate_log = {self.create_wandb_name(): self.aggregate_value} + + return {**component_wise_logs, **aggregate_log} + + def __str__(self) -> str: + """String representation.""" + return str(self.wandb_log) + + def __repr__(self) -> str: + """Representation.""" + class_name = self.__class__.__name__ + return f"""{class_name}( + location={self.location}, + name={self.name}, + postfix={self.postfix}, + component_wise_values={self.component_wise_values}, + aggregate_approach={self.aggregate_approach}, + aggregate_value={self._aggregate_value}, + )""" + + +class AbstractMetric(ABC): + """Abstract metric.""" + + @property + @abstractmethod + def location(self) -> MetricLocation: + """Metric location.""" + + @abstractmethod + def calculate(self, data) -> list[MetricResult]: # type: ignore # noqa: ANN001 (type to be narrowed by abstract subclasses) + """Calculate metrics.""" diff --git a/sparse_autoencoder/metrics/generate/__init__.py b/sparse_autoencoder/metrics/generate/__init__.py deleted file mode 100644 index 261aee81..00000000 --- a/sparse_autoencoder/metrics/generate/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Generate step metrics.""" -from sparse_autoencoder.metrics.generate.abstract_generate_metric import AbstractGenerateMetric - - -__all__ = ["AbstractGenerateMetric"] diff --git a/sparse_autoencoder/metrics/generate/abstract_generate_metric.py b/sparse_autoencoder/metrics/generate/abstract_generate_metric.py deleted file mode 100644 index 98bdc9b2..00000000 --- a/sparse_autoencoder/metrics/generate/abstract_generate_metric.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Abstract generate metric.""" -from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any - -from jaxtyping import Float -from torch import Tensor - -from sparse_autoencoder.tensor_types import Axis - - -@dataclass -class GenerateMetricData: - """Generate metric data.""" - - 
generated_activations: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] - - -class AbstractGenerateMetric(ABC): - """Abstract generate metric.""" - - @abstractmethod - def calculate(self, data: GenerateMetricData) -> dict[str, Any]: - """Calculate any metrics.""" diff --git a/sparse_autoencoder/metrics/metrics_container.py b/sparse_autoencoder/metrics/metrics_container.py index 70283f30..1979e8da 100644 --- a/sparse_autoencoder/metrics/metrics_container.py +++ b/sparse_autoencoder/metrics/metrics_container.py @@ -1,7 +1,6 @@ """Metrics container.""" from dataclasses import dataclass, field -from sparse_autoencoder.metrics.generate.abstract_generate_metric import AbstractGenerateMetric from sparse_autoencoder.metrics.train.abstract_train_metric import AbstractTrainMetric from sparse_autoencoder.metrics.train.capacity import CapacityMetric from sparse_autoencoder.metrics.train.feature_density import TrainBatchFeatureDensityMetric @@ -15,12 +14,10 @@ class MetricsContainer: """Metrics container. - Stores all metrics used in a pipeline. + Stores all metrics used in a pipeline, and allows updating of the component names for all at + once. """ - generate_metrics: list[AbstractGenerateMetric] = field(default_factory=list) - """Metrics for the generate section.""" - train_metrics: list[AbstractTrainMetric] = field(default_factory=list) """Metrics for the train section.""" diff --git a/sparse_autoencoder/metrics/train/abstract_train_metric.py b/sparse_autoencoder/metrics/train/abstract_train_metric.py index 04bc517d..3598028c 100644 --- a/sparse_autoencoder/metrics/train/abstract_train_metric.py +++ b/sparse_autoencoder/metrics/train/abstract_train_metric.py @@ -1,31 +1,76 @@ """Abstract train metric.""" from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any +from typing import final from jaxtyping import Float from torch import Tensor +from sparse_autoencoder.metrics.abstract_metric import ( + AbstractMetric, + MetricLocation, + MetricResult, +) +from sparse_autoencoder.metrics.utils.add_component_axis_if_missing import ( + add_component_axis_if_missing, +) from sparse_autoencoder.tensor_types import Axis +@final @dataclass class TrainMetricData: """Train metric data.""" - input_activations: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] + input_activations: Float[ + Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.INPUT_OUTPUT_FEATURE) + ] + """Input activations.""" - learned_activations: Float[Tensor, Axis.names(Axis.BATCH, Axis.LEARNT_FEATURE)] + learned_activations: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE)] + """Learned activations.""" - decoded_activations: Float[Tensor, Axis.names(Axis.BATCH, Axis.INPUT_OUTPUT_FEATURE)] + decoded_activations: Float[ + Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.INPUT_OUTPUT_FEATURE) + ] + """Decoded activations.""" + def __init__( + self, + input_activations: Float[ + Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) + ], + learned_activations: Float[ + Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE) + ], + decoded_activations: Float[ + Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT_OPTIONAL, Axis.INPUT_OUTPUT_FEATURE) + ], + ) -> None: + """Initialize the train metric data.""" + self.input_activations = add_component_axis_if_missing( + input_activations, dimensions_without_component=2 + ).detach() + self.learned_activations = add_component_axis_if_missing( + learned_activations, 
dimensions_without_component=2 + ).detach() + self.decoded_activations = add_component_axis_if_missing( + decoded_activations, dimensions_without_component=2 + ).detach() -class AbstractTrainMetric(ABC): + +class AbstractTrainMetric(AbstractMetric, ABC): """Abstract train metric.""" + @final + @property + def location(self) -> MetricLocation: + """Metric type name.""" + return MetricLocation.TRAIN + @abstractmethod - def calculate(self, data: TrainMetricData) -> dict[str, Any]: - """Calculate any metrics. + def calculate(self, data: TrainMetricData) -> list[MetricResult]: + """Calculate any metrics component wise. Args: data: Train metric data. diff --git a/sparse_autoencoder/metrics/train/capacity.py b/sparse_autoencoder/metrics/train/capacity.py index 3013b444..b6a708ed 100644 --- a/sparse_autoencoder/metrics/train/capacity.py +++ b/sparse_autoencoder/metrics/train/capacity.py @@ -1,15 +1,14 @@ """Capacity Metrics.""" -from typing import Any import einops from jaxtyping import Float import numpy as np from numpy import histogram -from numpy.typing import NDArray import torch from torch import Tensor import wandb +from sparse_autoencoder.metrics.abstract_metric import MetricResult from sparse_autoencoder.metrics.train.abstract_train_metric import ( AbstractTrainMetric, TrainMetricData, @@ -20,7 +19,8 @@ class CapacityMetric(AbstractTrainMetric): """Capacities Metrics for Learned Features. - Measure the capacity of a set of features as defined in [Polysemanticity and Capacity in Neural Networks](https://arxiv.org/pdf/2210.01892.pdf). + Measure the capacity of a set of features as defined in [Polysemanticity and Capacity in Neural + Networks](https://arxiv.org/pdf/2210.01892.pdf). Capacity is intuitively measuring the 'proportion of a dimension' assigned to a feature. Formally it's the ratio of the squared dot product of a feature with itself to the sum of its @@ -32,16 +32,16 @@ class CapacityMetric(AbstractTrainMetric): @staticmethod def capacities( - features: Float[Tensor, Axis.names(Axis.BATCH, Axis.LEARNT_FEATURE)], - ) -> Float[Tensor, Axis.BATCH]: - """Calculate capacities. + features: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE)], + ) -> Float[Tensor, Axis.names(Axis.COMPONENT, Axis.BATCH)]: + r"""Calculate capacities. Example: >>> import torch - >>> orthogonal_features = torch.tensor([[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]]) + >>> orthogonal_features = torch.tensor([[[1., 0., 0.]], [[0., 1., 0.]], [[0., 0., 1.]]]) >>> orthogonal_caps = CapacityMetric.capacities(orthogonal_features) >>> orthogonal_caps - tensor([1., 1., 1.]) + tensor([[1., 1., 1.]]) Args: features: A collection of features. @@ -50,19 +50,31 @@ def capacities( A 1D tensor of capacities, where each element is the capacity of the corresponding feature. 
""" - squared_dot_products = ( + squared_dot_products: Float[Tensor, Axis.names(Axis.BATCH, Axis.BATCH, Axis.COMPONENT)] = ( einops.einsum( - features, features, "n_feats1 feat_dim, n_feats2 feat_dim -> n_feats1 n_feats2" + features, + features, + f"batch_1 {Axis.COMPONENT} {Axis.LEARNT_FEATURE}, \ + batch_2 {Axis.COMPONENT} {Axis.LEARNT_FEATURE} \ + -> {Axis.COMPONENT} batch_1 batch_2", ) ** 2 ) - sum_of_sq_dot = squared_dot_products.sum(dim=-1) - return torch.diag(squared_dot_products) / sum_of_sq_dot + + sum_of_sq_dot: Float[ + Tensor, Axis.names(Axis.COMPONENT, Axis.BATCH) + ] = squared_dot_products.sum(dim=-1) + + diagonal: Float[Tensor, Axis.names(Axis.COMPONENT, Axis.BATCH)] = torch.diagonal( + squared_dot_products, dim1=1, dim2=2 + ) + + return diagonal / sum_of_sq_dot @staticmethod def wandb_capacities_histogram( - capacities: Float[Tensor, Axis.BATCH], - ) -> wandb.Histogram: + capacities: Float[Tensor, Axis.names(Axis.COMPONENT, Axis.BATCH)], + ) -> list[wandb.Histogram]: """Create a W&B histogram of the capacities. This can be logged with Weights & Biases using e.g. `wandb.log({"capacities_histogram": @@ -74,16 +86,25 @@ def wandb_capacities_histogram( Returns: Weights & Biases histogram for logging with `wandb.log`. """ - numpy_capacities: NDArray[np.float_] = capacities.detach().cpu().numpy() + np_capacities: Float[ + np.ndarray, Axis.names(Axis.COMPONENT, Axis.BATCH) + ] = capacities.cpu().numpy() - bins, values = histogram(numpy_capacities, bins=20, range=(0, 1)) - return wandb.Histogram(np_histogram=(bins, values)) + np_histograms = [histogram(capacity, bins=20, range=(0, 1)) for capacity in np_capacities] - def calculate(self, data: TrainMetricData) -> dict[str, Any]: + return [wandb.Histogram(np_histogram=np_histogram) for np_histogram in np_histograms] + + def calculate(self, data: TrainMetricData) -> list[MetricResult]: """Calculate the capacities for a training batch.""" train_batch_capacities = self.capacities(data.learned_activations) - train_batch_capacities_histogram = self.wandb_capacities_histogram(train_batch_capacities) - return { - "train/batch_capacities_histogram": train_batch_capacities_histogram, - } + histograms = self.wandb_capacities_histogram(train_batch_capacities) + + return [ + MetricResult( + name="capacities", + component_wise_values=histograms, + location=self.location, + aggregate_approach=None, # Don't aggregate histograms + ) + ] diff --git a/sparse_autoencoder/metrics/train/feature_density.py b/sparse_autoencoder/metrics/train/feature_density.py index a16c8f4d..9922d94d 100644 --- a/sparse_autoencoder/metrics/train/feature_density.py +++ b/sparse_autoencoder/metrics/train/feature_density.py @@ -1,15 +1,13 @@ """Train batch feature density.""" -from typing import Any - import einops from jaxtyping import Float import numpy as np from numpy import histogram -from numpy.typing import NDArray import torch from torch import Tensor import wandb +from sparse_autoencoder.metrics.abstract_metric import MetricResult from sparse_autoencoder.metrics.train.abstract_train_metric import ( AbstractTrainMetric, TrainMetricData, @@ -34,7 +32,10 @@ class TrainBatchFeatureDensityMetric(AbstractTrainMetric): threshold: float - def __init__(self, threshold: float = 0.0) -> None: + def __init__( + self, + threshold: float = 0.0, + ) -> None: """Initialise the train batch feature density metric. 
Args: @@ -45,17 +46,18 @@ def __init__(self, threshold: float = 0.0) -> None: self.threshold = threshold def feature_density( - self, activations: Float[Tensor, Axis.names(Axis.BATCH, Axis.LEARNT_FEATURE)] - ) -> Float[Tensor, Axis.LEARNT_FEATURE]: + self, + activations: Float[Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE)], + ) -> Float[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)]: """Count how many times each feature was active. Percentage of samples in which each feature was active (i.e. the neuron has "fired"). Example: >>> import torch - >>> activations = torch.tensor([[0.5, 0.5, 0.0], [0.5, 0.0, 0.0001]]) + >>> activations = torch.tensor([[[0.5, 0.5, 0.0]], [[0.5, 0.0, 0.0001]]]) >>> TrainBatchFeatureDensityMetric(0.001).feature_density(activations).tolist() - [1.0, 0.5, 0.0] + [[1.0, 0.5, 0.0]] Args: activations: Sample of cached activations (the Autoencoder's learned features). @@ -63,18 +65,23 @@ def feature_density( Returns: Number of times each feature was active in a sample. """ - has_fired: Float[Tensor, Axis.names(Axis.BATCH, Axis.LEARNT_FEATURE)] = torch.gt( - activations, self.threshold - ).to( + has_fired: Float[ + Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE) + ] = torch.gt(activations, self.threshold).to( dtype=torch.float # Move to float so it can be averaged ) - return einops.reduce(has_fired, "sample activation -> activation", "mean") + return einops.reduce( + has_fired, + f"{Axis.BATCH} {Axis.COMPONENT} {Axis.LEARNT_FEATURE} \ + -> {Axis.COMPONENT} {Axis.LEARNT_FEATURE}", + "mean", + ) @staticmethod def wandb_feature_density_histogram( - feature_density: Float[Tensor, Axis.LEARNT_FEATURE], - ) -> wandb.Histogram: + feature_density: Float[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)], + ) -> list[wandb.Histogram]: """Create a W&B histogram of the feature density. This can be logged with Weights & Biases using e.g. `wandb.log({"feature_density_histogram": @@ -87,12 +94,18 @@ def wandb_feature_density_histogram( Returns: Weights & Biases histogram for logging with `wandb.log`. """ - numpy_feature_density: NDArray[np.float_] = feature_density.detach().cpu().numpy() + numpy_feature_density: Float[ + np.ndarray, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE) + ] = feature_density.cpu().numpy() - bins, values = histogram(numpy_feature_density, bins=50) - return wandb.Histogram(np_histogram=(bins, values)) + np_histograms = [ + histogram(component_feature_density, bins=50) + for component_feature_density in numpy_feature_density + ] - def calculate(self, data: TrainMetricData) -> dict[str, Any]: + return [wandb.Histogram(np_histogram=np_histogram) for np_histogram in np_histograms] + + def calculate(self, data: TrainMetricData) -> list[MetricResult]: """Calculate the train batch feature density metrics. Args: @@ -102,14 +115,19 @@ def calculate(self, data: TrainMetricData) -> dict[str, Any]: Dictionary with the train batch feature density metric, and a histogram of the feature density. 
""" - train_batch_feature_density: Float[Tensor, Axis.LEARNT_FEATURE] = self.feature_density( - data.learned_activations - ) + train_batch_feature_density: Float[ + Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE) + ] = self.feature_density(data.learned_activations) - train_batch_feature_density_histogram: wandb.Histogram = ( - self.wandb_feature_density_histogram(train_batch_feature_density) + component_wise_histograms = self.wandb_feature_density_histogram( + train_batch_feature_density ) - return { - "train/batch_feature_density_histogram": train_batch_feature_density_histogram, - } + return [ + MetricResult( + name="feature_density", + component_wise_values=component_wise_histograms, + location=self.location, + aggregate_approach=None, # Don't aggregate the histograms + ) + ] diff --git a/sparse_autoencoder/metrics/train/l0_norm_metric.py b/sparse_autoencoder/metrics/train/l0_norm_metric.py index 4084ed77..2e6ebb4b 100644 --- a/sparse_autoencoder/metrics/train/l0_norm_metric.py +++ b/sparse_autoencoder/metrics/train/l0_norm_metric.py @@ -1,12 +1,17 @@ """L0 norm sparsity metric.""" from typing import final +import einops +from jaxtyping import Float import torch +from torch import Tensor +from sparse_autoencoder.metrics.abstract_metric import MetricResult from sparse_autoencoder.metrics.train.abstract_train_metric import ( AbstractTrainMetric, TrainMetricData, ) +from sparse_autoencoder.tensor_types import Axis @final @@ -17,9 +22,24 @@ class TrainBatchLearnedActivationsL0(AbstractTrainMetric): this over the batch. """ - def calculate(self, data: TrainMetricData) -> dict[str, float]: - """Create a log item for Weights and Biases.""" - batch_size = data.learned_activations.size(0) - n_non_zero_activations = torch.count_nonzero(data.learned_activations) - batch_average = n_non_zero_activations / batch_size - return {"train/learned_activations_l0_norm": batch_average.item()} + def calculate(self, data: TrainMetricData) -> list[MetricResult]: + """Create the L0 norm sparsity metric, component wise..""" + learned_activations: Float[ + Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT, Axis.LEARNT_FEATURE) + ] = data.learned_activations + + n_non_zero_activations: Float[ + Tensor, Axis.names(Axis.BATCH, Axis.COMPONENT) + ] = torch.count_nonzero(learned_activations, dim=-1).to(dtype=torch.float) + + batch_average: Float[Tensor, Axis.COMPONENT] = einops.reduce( + n_non_zero_activations, f"{Axis.BATCH} {Axis.COMPONENT} -> {Axis.COMPONENT}", "mean" + ) + + return [ + MetricResult( + location=self.location, + name="learned_activations_l0_norm", + component_wise_values=batch_average, + ) + ] diff --git a/sparse_autoencoder/metrics/train/neuron_activity_metric.py b/sparse_autoencoder/metrics/train/neuron_activity_metric.py index 8d87b93f..60f60020 100644 --- a/sparse_autoencoder/metrics/train/neuron_activity_metric.py +++ b/sparse_autoencoder/metrics/train/neuron_activity_metric.py @@ -3,15 +3,16 @@ Logs the number of dead and alive neurons at various horizons. Also logs histograms of neuron activity, and the number of neurons that are almost dead. 
""" -from typing import Any - -from jaxtyping import Int64 +from jaxtyping import Float, Int, Int64 import numpy as np -from numpy.typing import NDArray import torch from torch import Tensor import wandb +from sparse_autoencoder.metrics.abstract_metric import ( + MetricLocation, + MetricResult, +) from sparse_autoencoder.metrics.train.abstract_train_metric import ( AbstractTrainMetric, TrainMetricData, @@ -19,15 +20,15 @@ from sparse_autoencoder.tensor_types import Axis -DEFAULT_HORIZONS = [10_000, 100_000, 500_000, 1_000_000, 10_000_000] -"""Default horizons.""" +DEFAULT_HORIZONS = [10_000, 100_000, 1_000_000, 10_000_000] +"""Default horizons (in number of logged activations).""" DEFAULT_THRESHOLDS = [1e-5, 1e-6] """Default thresholds for determining if a neuron is almost dead.""" class NeuronActivityHorizonData: - """Neuron activity data for a single horizon. + """Neuron activity data for a specific horizon (number of activations seen). For each time horizon we store some data (e.g. the number of times each neuron fired inside this time horizon). This class also contains some helper methods for then calculating metrics from @@ -43,105 +44,156 @@ class NeuronActivityHorizonData: _steps_since_last_calculated: int """Steps since last calculated.""" - _neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE] + _neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)] """Neuron activity since inception.""" _thresholds: list[float] """Thresholds for almost dead neurons.""" + _n_components: int + """Number of components.""" + + _n_learned_features: int + """Number of learned features.""" + @property - def _dead_count(self) -> int: + def _dead_count(self) -> Int[Tensor, Axis.COMPONENT]: """Dead count.""" - dead_bool_mask: Int64[Tensor, Axis.LEARNT_FEATURE] = self._neuron_activity == 0 - count_dead: Int64[Tensor, Axis.SINGLE_ITEM] = dead_bool_mask.sum() - return int(count_dead.item()) + dead_bool_mask: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)] = ( + self._neuron_activity == 0 + ) + return dead_bool_mask.sum(-1) @property - def _dead_fraction(self) -> float: + def _dead_fraction(self) -> Float[Tensor, Axis.COMPONENT]: """Dead fraction.""" - return self._dead_count / self._neuron_activity.shape[-1] + return self._dead_count / self._n_learned_features @property - def _alive_count(self) -> int: + def _alive_count(self) -> Int[Tensor, Axis.COMPONENT]: """Alive count.""" - alive_bool_mask: Int64[Tensor, Axis.LEARNT_FEATURE] = self._neuron_activity > 0 - count_alive: Int64[Tensor, Axis.SINGLE_ITEM] = alive_bool_mask.sum() - return int(count_alive.item()) + alive_bool_mask: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)] = ( + self._neuron_activity > 0 + ) + + return alive_bool_mask.sum(-1) - def _almost_dead(self, threshold: float) -> int | None: + def _almost_dead(self, threshold: float) -> Int[Tensor, Axis.COMPONENT]: """Almost dead count.""" threshold_in_activations: float = threshold * self._horizon_number_activations - if threshold_in_activations < 1: - return None - almost_dead_bool_mask: Int64[Tensor, Axis.LEARNT_FEATURE] = ( + almost_dead_bool_mask: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)] = ( self._neuron_activity < threshold_in_activations ) - count_almost_dead: Int64[Tensor, Axis.SINGLE_ITEM] = almost_dead_bool_mask.sum() - return int(count_almost_dead.item()) + + return almost_dead_bool_mask.sum(-1) @property - def _activity_histogram(self) -> wandb.Histogram: + def _activity_histogram(self) -> list[wandb.Histogram]: 
"""Activity histogram.""" - numpy_neuron_activity: NDArray[np.float_] = self._neuron_activity.detach().cpu().numpy() - bins, values = np.histogram(numpy_neuron_activity, bins=50) - return wandb.Histogram(np_histogram=(bins, values)) + numpy_neuron_activity: Float[ + np.ndarray, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE) + ] = self._neuron_activity.cpu().numpy() + + np_histograms = [np.histogram(activity) for activity in numpy_neuron_activity] + + return [wandb.Histogram(np_histogram=histogram) for histogram in np_histograms] @property - def _log_activity_histogram(self) -> wandb.Histogram: + def _log_activity_histogram(self) -> list[wandb.Histogram]: """Log activity histogram.""" - numpy_neuron_activity: NDArray[np.float_] = self._neuron_activity.detach().cpu().numpy() log_epsilon = 0.1 # To avoid log(0) - log_neuron_activity = np.log(numpy_neuron_activity + log_epsilon) - bins, values = np.histogram(log_neuron_activity, bins=50) - return wandb.Histogram(np_histogram=(bins, values)) + log_neuron_activity: Float[ + Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE) + ] = torch.log(self._neuron_activity + log_epsilon) - @property - def name(self) -> str: - """Name.""" - return f"over_{self._horizon_number_activations}_activations" + numpy_log_neuron_activity: Float[ + np.ndarray, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE) + ] = log_neuron_activity.cpu().numpy() + + np_histograms = [np.histogram(activity) for activity in numpy_log_neuron_activity] + + return [wandb.Histogram(np_histogram=histogram) for histogram in np_histograms] @property - def wandb_log_values(self) -> dict[str, Any]: - """Wandb log values.""" - log = { - f"train/activity/{self.name}/dead_count": self._dead_count, - f"train/activity/{self.name}/alive_count": self._alive_count, - f"train/activity/{self.name}/activity_histogram": self._activity_histogram, - f"train/activity/{self.name}/log_activity_histogram": self._log_activity_histogram, - } - - for threshold in self._thresholds: - almost_dead_count = self._almost_dead(threshold) - if almost_dead_count is not None: - log[f"train/activity/{self.name}/almost_dead_{threshold}"] = almost_dead_count - - return log + def metric_results(self) -> list[MetricResult]: + """Metric results.""" + metric_location = MetricLocation.TRAIN + name = "learned_neuron_activity" + + results = [ + MetricResult( + component_wise_values=self._dead_count, + location=metric_location, + name=name, + postfix=f"dead_over_{self._horizon_number_activations}_activations", + ), + MetricResult( + component_wise_values=self._alive_count, + location=metric_location, + name=name, + postfix=f"alive_over_{self._horizon_number_activations}_activations", + ), + MetricResult( + component_wise_values=self._activity_histogram, + location=metric_location, + name=name, + postfix=f"activity_histogram_over_{self._horizon_number_activations}_activations", + aggregate_approach=None, # Don't show aggregate across components + ), + MetricResult( + component_wise_values=self._log_activity_histogram, + location=metric_location, + name=name, + postfix=f"log_activity_histogram_over_{self._horizon_number_activations}_activations", + aggregate_approach=None, # Don't show aggregate across components + ), + ] + + threshold_results = [ + MetricResult( + component_wise_values=self._almost_dead(threshold), + location=metric_location, + name=name, + postfix=f"almost_dead_{threshold:.1e}_over_{self._horizon_number_activations}_activations", + ) + for threshold in self._thresholds + ] + + return results + threshold_results 
def __init__( self, approximate_activation_horizon: int, - train_batch_size: int, + number_components: int, number_learned_features: int, thresholds: list[float], + train_batch_size: int, ) -> None: """Initialise the neuron activity horizon data. Args: approximate_activation_horizon: Approximate activation horizon. - train_batch_size: Train batch size. + number_components: Number of components. number_learned_features: Number of learned features. thresholds: Thresholds for almost dead neurons. + train_batch_size: Train batch size. """ self._steps_since_last_calculated = 0 - self._neuron_activity = torch.zeros(number_learned_features, dtype=torch.int64) + self._neuron_activity = torch.zeros( + (number_components, number_learned_features), dtype=torch.int64 + ) self._thresholds = thresholds + self._n_components = number_components + self._n_learned_features = number_learned_features # Get a precise activation_horizon self._horizon_steps = approximate_activation_horizon // train_batch_size self._horizon_number_activations = self._horizon_steps * train_batch_size - def step(self, neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE]) -> dict[str, Any]: + def step( + self, neuron_activity: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)] + ) -> list[MetricResult]: """Step the neuron activity horizon data. Args: @@ -151,26 +203,23 @@ def step(self, neuron_activity: Int64[Tensor, Axis.LEARNT_FEATURE]) -> dict[str, Dictionary of metrics (or empty dictionary if no metrics are ready to be logged). """ self._steps_since_last_calculated += 1 - self._neuron_activity += neuron_activity + self._neuron_activity += neuron_activity.cpu() if self._steps_since_last_calculated >= self._horizon_steps: - result = {**self.wandb_log_values} + result = [*self.metric_results] self._steps_since_last_calculated = 0 self._neuron_activity = torch.zeros_like(self._neuron_activity) return result - return {} + return [] class NeuronActivityMetric(AbstractTrainMetric): """Neuron activity metric.""" _approximate_horizons: list[int] - _data: list[NeuronActivityHorizonData] - _initialised: bool = False - _thresholds: list[float] def __init__( @@ -199,6 +248,7 @@ def initialise_horizons(self, data: TrainMetricData) -> None: """ train_batch_size = data.learned_activations.shape[0] number_learned_features = data.learned_activations.shape[-1] + number_components = data.learned_activations.shape[-2] for horizon in self._approximate_horizons: # Don't add horizons that are smaller than the train batch size @@ -208,15 +258,16 @@ def initialise_horizons(self, data: TrainMetricData) -> None: self._data.append( NeuronActivityHorizonData( approximate_activation_horizon=horizon, - train_batch_size=train_batch_size, + number_components=number_components, number_learned_features=number_learned_features, thresholds=self._thresholds, + train_batch_size=train_batch_size, ) ) self._initialised = True - def calculate(self, data: TrainMetricData) -> dict[str, Any]: + def calculate(self, data: TrainMetricData) -> list[MetricResult]: """Calculate the neuron activity metrics. 
Args: @@ -228,13 +279,13 @@ def calculate(self, data: TrainMetricData) -> dict[str, Any]: if not self._initialised: self.initialise_horizons(data) - log = {} + fired_count: Int64[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)] = ( + data.learned_activations > 0 + ).sum(dim=0) - for horizon_data in self._data: - fired_count: Int64[Tensor, Axis.LEARNT_FEATURE] = ( - (data.learned_activations > 0).sum(dim=0).detach().cpu() - ) - horizon_specific_log = horizon_data.step(fired_count) - log.update(horizon_specific_log) + horizon_specific_logs: list[list[MetricResult]] = [ + horizon_data.step(fired_count) for horizon_data in self._data + ] - return log + # Flatten and return + return [log for logs in horizon_specific_logs for log in logs] diff --git a/sparse_autoencoder/metrics/train/tests/__snapshots__/test_capacities.ambr b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_capacities.ambr index 5e6ad8bc..b7aa5b1a 100644 --- a/sparse_autoencoder/metrics/train/tests/__snapshots__/test_capacities.ambr +++ b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_capacities.ambr @@ -1,25 +1,295 @@ # serializer version: 1 -# name: test_wandb_capacity_histogram +# name: test_weights_biases_log_matches_snapshot list([ - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 1, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 0, - 3, + dict({ + 'component_0/train/capacities': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.0, + 0.05000000074505806, + 0.10000000149011612, + 0.15000000596046448, + 0.20000000298023224, + 0.25, + 0.30000001192092896, + 0.3499999940395355, + 0.4000000059604645, + 0.44999998807907104, + 0.5, + 0.550000011920929, + 0.6000000238418579, + 0.6499999761581421, + 0.699999988079071, + 0.75, + 0.800000011920929, + 0.8500000238418579, + 0.8999999761581421, + 0.949999988079071, + 1.0, + ]), + histogram=list([ + 0, + 1, + 5, + 3, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_1/train/capacities': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.0, + 0.05000000074505806, + 0.10000000149011612, + 0.15000000596046448, + 0.20000000298023224, + 0.25, + 0.30000001192092896, + 0.3499999940395355, + 0.4000000059604645, + 0.44999998807907104, + 0.5, + 0.550000011920929, + 0.6000000238418579, + 0.6499999761581421, + 0.699999988079071, + 0.75, + 0.800000011920929, + 0.8500000238418579, + 0.8999999761581421, + 0.949999988079071, + 1.0, + ]), + histogram=list([ + 0, + 1, + 3, + 5, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_2/train/capacities': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.0, + 0.05000000074505806, + 0.10000000149011612, + 0.15000000596046448, + 0.20000000298023224, + 0.25, + 0.30000001192092896, + 0.3499999940395355, + 0.4000000059604645, + 0.44999998807907104, + 0.5, + 0.550000011920929, + 0.6000000238418579, + 0.6499999761581421, + 0.699999988079071, + 0.75, + 0.800000011920929, + 0.8500000238418579, + 0.8999999761581421, + 0.949999988079071, + 1.0, + ]), + histogram=list([ + 0, + 0, + 5, + 3, + 2, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_3/train/capacities': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.0, + 0.05000000074505806, + 0.10000000149011612, + 0.15000000596046448, + 0.20000000298023224, + 0.25, + 0.30000001192092896, + 0.3499999940395355, + 0.4000000059604645, + 0.44999998807907104, + 0.5, + 0.550000011920929, + 0.6000000238418579, + 0.6499999761581421, + 0.699999988079071, + 0.75, + 
0.800000011920929, + 0.8500000238418579, + 0.8999999761581421, + 0.949999988079071, + 1.0, + ]), + histogram=list([ + 0, + 0, + 3, + 6, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_4/train/capacities': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.0, + 0.05000000074505806, + 0.10000000149011612, + 0.15000000596046448, + 0.20000000298023224, + 0.25, + 0.30000001192092896, + 0.3499999940395355, + 0.4000000059604645, + 0.44999998807907104, + 0.5, + 0.550000011920929, + 0.6000000238418579, + 0.6499999761581421, + 0.699999988079071, + 0.75, + 0.800000011920929, + 0.8500000238418579, + 0.8999999761581421, + 0.949999988079071, + 1.0, + ]), + histogram=list([ + 0, + 1, + 4, + 5, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_5/train/capacities': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.0, + 0.05000000074505806, + 0.10000000149011612, + 0.15000000596046448, + 0.20000000298023224, + 0.25, + 0.30000001192092896, + 0.3499999940395355, + 0.4000000059604645, + 0.44999998807907104, + 0.5, + 0.550000011920929, + 0.6000000238418579, + 0.6499999761581421, + 0.699999988079071, + 0.75, + 0.800000011920929, + 0.8500000238418579, + 0.8999999761581421, + 0.949999988079071, + 1.0, + ]), + histogram=list([ + 0, + 1, + 3, + 5, + 1, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + }), ]) # --- diff --git a/sparse_autoencoder/metrics/train/tests/__snapshots__/test_feature_density.ambr b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_feature_density.ambr new file mode 100644 index 00000000..af8dccf0 --- /dev/null +++ b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_feature_density.ambr @@ -0,0 +1,655 @@ +# serializer version: 1 +# name: test_weights_biases_log_matches_snapshot + list([ + dict({ + 'component_0/train/feature_density': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.5, + 0.5199999809265137, + 0.5400000214576721, + 0.5600000023841858, + 0.5799999833106995, + 0.6000000238418579, + 0.6200000047683716, + 0.6399999856948853, + 0.6600000262260437, + 0.6800000071525574, + 0.699999988079071, + 0.7200000286102295, + 0.7400000095367432, + 0.7599999904632568, + 0.7799999713897705, + 0.800000011920929, + 0.8199999928474426, + 0.8399999737739563, + 0.8600000143051147, + 0.8799999952316284, + 0.8999999761581421, + 0.9200000166893005, + 0.9399999976158142, + 0.9599999785423279, + 0.9800000190734863, + 1.0, + 1.0199999809265137, + 1.0399999618530273, + 1.059999942779541, + 1.0800000429153442, + 1.100000023841858, + 1.1200000047683716, + 1.1399999856948853, + 1.159999966621399, + 1.1799999475479126, + 1.2000000476837158, + 1.2200000286102295, + 1.2400000095367432, + 1.2599999904632568, + 1.2799999713897705, + 1.2999999523162842, + 1.3200000524520874, + 1.340000033378601, + 1.3600000143051147, + 1.3799999952316284, + 1.399999976158142, + 1.4199999570846558, + 1.440000057220459, + 1.4600000381469727, + 1.4800000190734863, + 1.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_1/train/feature_density': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.5, + 0.5199999809265137, + 0.5400000214576721, + 0.5600000023841858, + 0.5799999833106995, + 0.6000000238418579, + 0.6200000047683716, + 
0.6399999856948853, + 0.6600000262260437, + 0.6800000071525574, + 0.699999988079071, + 0.7200000286102295, + 0.7400000095367432, + 0.7599999904632568, + 0.7799999713897705, + 0.800000011920929, + 0.8199999928474426, + 0.8399999737739563, + 0.8600000143051147, + 0.8799999952316284, + 0.8999999761581421, + 0.9200000166893005, + 0.9399999976158142, + 0.9599999785423279, + 0.9800000190734863, + 1.0, + 1.0199999809265137, + 1.0399999618530273, + 1.059999942779541, + 1.0800000429153442, + 1.100000023841858, + 1.1200000047683716, + 1.1399999856948853, + 1.159999966621399, + 1.1799999475479126, + 1.2000000476837158, + 1.2200000286102295, + 1.2400000095367432, + 1.2599999904632568, + 1.2799999713897705, + 1.2999999523162842, + 1.3200000524520874, + 1.340000033378601, + 1.3600000143051147, + 1.3799999952316284, + 1.399999976158142, + 1.4199999570846558, + 1.440000057220459, + 1.4600000381469727, + 1.4800000190734863, + 1.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_2/train/feature_density': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.5, + 0.5199999809265137, + 0.5400000214576721, + 0.5600000023841858, + 0.5799999833106995, + 0.6000000238418579, + 0.6200000047683716, + 0.6399999856948853, + 0.6600000262260437, + 0.6800000071525574, + 0.699999988079071, + 0.7200000286102295, + 0.7400000095367432, + 0.7599999904632568, + 0.7799999713897705, + 0.800000011920929, + 0.8199999928474426, + 0.8399999737739563, + 0.8600000143051147, + 0.8799999952316284, + 0.8999999761581421, + 0.9200000166893005, + 0.9399999976158142, + 0.9599999785423279, + 0.9800000190734863, + 1.0, + 1.0199999809265137, + 1.0399999618530273, + 1.059999942779541, + 1.0800000429153442, + 1.100000023841858, + 1.1200000047683716, + 1.1399999856948853, + 1.159999966621399, + 1.1799999475479126, + 1.2000000476837158, + 1.2200000286102295, + 1.2400000095367432, + 1.2599999904632568, + 1.2799999713897705, + 1.2999999523162842, + 1.3200000524520874, + 1.340000033378601, + 1.3600000143051147, + 1.3799999952316284, + 1.399999976158142, + 1.4199999570846558, + 1.440000057220459, + 1.4600000381469727, + 1.4800000190734863, + 1.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_3/train/feature_density': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.5, + 0.5199999809265137, + 0.5400000214576721, + 0.5600000023841858, + 0.5799999833106995, + 0.6000000238418579, + 0.6200000047683716, + 0.6399999856948853, + 0.6600000262260437, + 0.6800000071525574, + 0.699999988079071, + 0.7200000286102295, + 0.7400000095367432, + 0.7599999904632568, + 0.7799999713897705, + 0.800000011920929, + 0.8199999928474426, + 0.8399999737739563, + 0.8600000143051147, + 0.8799999952316284, + 0.8999999761581421, + 0.9200000166893005, + 0.9399999976158142, + 0.9599999785423279, + 0.9800000190734863, + 1.0, + 1.0199999809265137, + 1.0399999618530273, + 1.059999942779541, + 1.0800000429153442, + 1.100000023841858, + 1.1200000047683716, + 1.1399999856948853, + 1.159999966621399, + 1.1799999475479126, + 1.2000000476837158, + 1.2200000286102295, + 1.2400000095367432, + 
1.2599999904632568, + 1.2799999713897705, + 1.2999999523162842, + 1.3200000524520874, + 1.340000033378601, + 1.3600000143051147, + 1.3799999952316284, + 1.399999976158142, + 1.4199999570846558, + 1.440000057220459, + 1.4600000381469727, + 1.4800000190734863, + 1.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_4/train/feature_density': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.5, + 0.5199999809265137, + 0.5400000214576721, + 0.5600000023841858, + 0.5799999833106995, + 0.6000000238418579, + 0.6200000047683716, + 0.6399999856948853, + 0.6600000262260437, + 0.6800000071525574, + 0.699999988079071, + 0.7200000286102295, + 0.7400000095367432, + 0.7599999904632568, + 0.7799999713897705, + 0.800000011920929, + 0.8199999928474426, + 0.8399999737739563, + 0.8600000143051147, + 0.8799999952316284, + 0.8999999761581421, + 0.9200000166893005, + 0.9399999976158142, + 0.9599999785423279, + 0.9800000190734863, + 1.0, + 1.0199999809265137, + 1.0399999618530273, + 1.059999942779541, + 1.0800000429153442, + 1.100000023841858, + 1.1200000047683716, + 1.1399999856948853, + 1.159999966621399, + 1.1799999475479126, + 1.2000000476837158, + 1.2200000286102295, + 1.2400000095367432, + 1.2599999904632568, + 1.2799999713897705, + 1.2999999523162842, + 1.3200000524520874, + 1.340000033378601, + 1.3600000143051147, + 1.3799999952316284, + 1.399999976158142, + 1.4199999570846558, + 1.440000057220459, + 1.4600000381469727, + 1.4800000190734863, + 1.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + 'component_5/train/feature_density': Histogram( + MAX_LENGTH=512, + bins=list([ + 0.5, + 0.5199999809265137, + 0.5400000214576721, + 0.5600000023841858, + 0.5799999833106995, + 0.6000000238418579, + 0.6200000047683716, + 0.6399999856948853, + 0.6600000262260437, + 0.6800000071525574, + 0.699999988079071, + 0.7200000286102295, + 0.7400000095367432, + 0.7599999904632568, + 0.7799999713897705, + 0.800000011920929, + 0.8199999928474426, + 0.8399999737739563, + 0.8600000143051147, + 0.8799999952316284, + 0.8999999761581421, + 0.9200000166893005, + 0.9399999976158142, + 0.9599999785423279, + 0.9800000190734863, + 1.0, + 1.0199999809265137, + 1.0399999618530273, + 1.059999942779541, + 1.0800000429153442, + 1.100000023841858, + 1.1200000047683716, + 1.1399999856948853, + 1.159999966621399, + 1.1799999475479126, + 1.2000000476837158, + 1.2200000286102295, + 1.2400000095367432, + 1.2599999904632568, + 1.2799999713897705, + 1.2999999523162842, + 1.3200000524520874, + 1.340000033378601, + 1.3600000143051147, + 1.3799999952316284, + 1.399999976158142, + 1.4199999570846558, + 1.440000057220459, + 1.4600000381469727, + 1.4800000190734863, + 1.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + ]), + ), + }), + ]) +# --- diff --git a/sparse_autoencoder/metrics/train/tests/__snapshots__/test_l0_norm_metric.ambr 
b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_l0_norm_metric.ambr new file mode 100644 index 00000000..8f2f6149 --- /dev/null +++ b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_l0_norm_metric.ambr @@ -0,0 +1,14 @@ +# serializer version: 1 +# name: test_weights_biases_log_matches_snapshot + list([ + dict({ + 'component_0/train/learned_activations_l0_norm': tensor(8.), + 'component_1/train/learned_activations_l0_norm': tensor(8.), + 'component_2/train/learned_activations_l0_norm': tensor(8.), + 'component_3/train/learned_activations_l0_norm': tensor(8.), + 'component_4/train/learned_activations_l0_norm': tensor(8.), + 'component_5/train/learned_activations_l0_norm': tensor(8.), + 'train/learned_activations_l0_norm': tensor([8., 8., 8., 8., 8., 8.]), + }), + ]) +# --- diff --git a/sparse_autoencoder/metrics/train/tests/__snapshots__/test_neuron_activity_metric.ambr b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_neuron_activity_metric.ambr new file mode 100644 index 00000000..55bc5ffb --- /dev/null +++ b/sparse_autoencoder/metrics/train/tests/__snapshots__/test_neuron_activity_metric.ambr @@ -0,0 +1,381 @@ +# serializer version: 1 +# name: test_weights_biases_log_matches_snapshot + list([ + dict({ + 'component_0/train/learned_neuron_activity/dead_over_10_activations': tensor(0), + 'component_1/train/learned_neuron_activity/dead_over_10_activations': tensor(0), + 'component_2/train/learned_neuron_activity/dead_over_10_activations': tensor(0), + 'component_3/train/learned_neuron_activity/dead_over_10_activations': tensor(0), + 'component_4/train/learned_neuron_activity/dead_over_10_activations': tensor(0), + 'component_5/train/learned_neuron_activity/dead_over_10_activations': tensor(0), + 'train/learned_neuron_activity/dead_over_10_activations': tensor([0, 0, 0, 0, 0, 0]), + }), + dict({ + 'component_0/train/learned_neuron_activity/alive_over_10_activations': tensor(8), + 'component_1/train/learned_neuron_activity/alive_over_10_activations': tensor(8), + 'component_2/train/learned_neuron_activity/alive_over_10_activations': tensor(8), + 'component_3/train/learned_neuron_activity/alive_over_10_activations': tensor(8), + 'component_4/train/learned_neuron_activity/alive_over_10_activations': tensor(8), + 'component_5/train/learned_neuron_activity/alive_over_10_activations': tensor(8), + 'train/learned_neuron_activity/alive_over_10_activations': tensor([8, 8, 8, 8, 8, 8]), + }), + dict({ + 'component_0/train/learned_neuron_activity/activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 9.5, + 9.6, + 9.7, + 9.8, + 9.9, + 10.0, + 10.1, + 10.2, + 10.3, + 10.4, + 10.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_1/train/learned_neuron_activity/activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 9.5, + 9.6, + 9.7, + 9.8, + 9.9, + 10.0, + 10.1, + 10.2, + 10.3, + 10.4, + 10.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_2/train/learned_neuron_activity/activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 9.5, + 9.6, + 9.7, + 9.8, + 9.9, + 10.0, + 10.1, + 10.2, + 10.3, + 10.4, + 10.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_3/train/learned_neuron_activity/activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 9.5, + 9.6, + 9.7, + 9.8, + 9.9, + 10.0, + 10.1, + 10.2, + 10.3, + 
10.4, + 10.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_4/train/learned_neuron_activity/activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 9.5, + 9.6, + 9.7, + 9.8, + 9.9, + 10.0, + 10.1, + 10.2, + 10.3, + 10.4, + 10.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_5/train/learned_neuron_activity/activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 9.5, + 9.6, + 9.7, + 9.8, + 9.9, + 10.0, + 10.1, + 10.2, + 10.3, + 10.4, + 10.5, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + }), + dict({ + 'component_0/train/learned_neuron_activity/log_activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 1.8125355243682861, + 1.912535548210144, + 2.012535572052002, + 2.1125354766845703, + 2.2125356197357178, + 2.312535524368286, + 2.4125354290008545, + 2.512535572052002, + 2.6125354766845703, + 2.7125356197357178, + 2.812535524368286, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_1/train/learned_neuron_activity/log_activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 1.8125355243682861, + 1.912535548210144, + 2.012535572052002, + 2.1125354766845703, + 2.2125356197357178, + 2.312535524368286, + 2.4125354290008545, + 2.512535572052002, + 2.6125354766845703, + 2.7125356197357178, + 2.812535524368286, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_2/train/learned_neuron_activity/log_activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 1.8125355243682861, + 1.912535548210144, + 2.012535572052002, + 2.1125354766845703, + 2.2125356197357178, + 2.312535524368286, + 2.4125354290008545, + 2.512535572052002, + 2.6125354766845703, + 2.7125356197357178, + 2.812535524368286, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_3/train/learned_neuron_activity/log_activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 1.8125355243682861, + 1.912535548210144, + 2.012535572052002, + 2.1125354766845703, + 2.2125356197357178, + 2.312535524368286, + 2.4125354290008545, + 2.512535572052002, + 2.6125354766845703, + 2.7125356197357178, + 2.812535524368286, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_4/train/learned_neuron_activity/log_activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 1.8125355243682861, + 1.912535548210144, + 2.012535572052002, + 2.1125354766845703, + 2.2125356197357178, + 2.312535524368286, + 2.4125354290008545, + 2.512535572052002, + 2.6125354766845703, + 2.7125356197357178, + 2.812535524368286, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + 'component_5/train/learned_neuron_activity/log_activity_histogram_over_10_activations': Histogram( + MAX_LENGTH=512, + bins=list([ + 1.8125355243682861, + 1.912535548210144, + 2.012535572052002, + 2.1125354766845703, + 2.2125356197357178, + 2.312535524368286, + 2.4125354290008545, + 2.512535572052002, + 2.6125354766845703, + 2.7125356197357178, + 2.812535524368286, + ]), + histogram=list([ + 0, + 0, + 0, + 0, + 0, + 8, + 0, + 0, + 0, + 0, + ]), + ), + }), + dict({ + 'component_0/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': 
tensor(0), + 'component_1/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0), + 'component_2/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0), + 'component_3/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0), + 'component_4/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0), + 'component_5/train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor(0), + 'train/learned_neuron_activity/almost_dead_1.0e-05_over_10_activations': tensor([0, 0, 0, 0, 0, 0]), + }), + dict({ + 'component_0/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), + 'component_1/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), + 'component_2/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), + 'component_3/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), + 'component_4/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), + 'component_5/train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor(0), + 'train/learned_neuron_activity/almost_dead_1.0e-06_over_10_activations': tensor([0, 0, 0, 0, 0, 0]), + }), + ]) +# --- diff --git a/sparse_autoencoder/metrics/train/tests/test_abstract_train_metric.py b/sparse_autoencoder/metrics/train/tests/test_abstract_train_metric.py new file mode 100644 index 00000000..966bdc3f --- /dev/null +++ b/sparse_autoencoder/metrics/train/tests/test_abstract_train_metric.py @@ -0,0 +1,39 @@ +"""Tests for the abstract train metric class.""" +import torch + +from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData + + +def test_adds_component_dimension() -> None: + """Test that it adds a component dimension if not initialised with one.""" + d_batch: int = 2 + n_input_output_features: int = 4 + n_learned_features: int = 8 + + metric_data = TrainMetricData( + input_activations=torch.randn(d_batch, n_input_output_features), + learned_activations=torch.randn(d_batch, n_learned_features), + decoded_activations=torch.randn(d_batch, n_input_output_features), + ) + + assert metric_data.input_activations.shape == (d_batch, 1, n_input_output_features) + assert metric_data.learned_activations.shape == (d_batch, 1, n_learned_features) + assert metric_data.decoded_activations.shape == (d_batch, 1, n_input_output_features) + + +def test_no_changes_with_component_dimension_already_added() -> None: + """Test that it does not change the input if it already has a component dimension.""" + d_batch: int = 2 + n_components: int = 3 + n_input_output_features: int = 4 + n_learned_features: int = 8 + + metric_data = TrainMetricData( + input_activations=torch.randn(d_batch, n_components, n_input_output_features), + learned_activations=torch.randn(d_batch, n_components, n_learned_features), + decoded_activations=torch.randn(d_batch, n_components, n_input_output_features), + ) + + assert metric_data.input_activations.shape == (d_batch, n_components, n_input_output_features) + assert metric_data.learned_activations.shape == (d_batch, n_components, n_learned_features) + assert metric_data.decoded_activations.shape == (d_batch, n_components, n_input_output_features) diff --git a/sparse_autoencoder/metrics/train/tests/test_capacities.py b/sparse_autoencoder/metrics/train/tests/test_capacities.py index 7400df2f..3777ab62 100644 --- 
a/sparse_autoencoder/metrics/train/tests/test_capacities.py +++ b/sparse_autoencoder/metrics/train/tests/test_capacities.py @@ -10,25 +10,35 @@ from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData from sparse_autoencoder.metrics.train.capacity import CapacityMetric +from sparse_autoencoder.metrics.utils.find_metric_result import find_metric_result from sparse_autoencoder.tensor_types import Axis @pytest.mark.parametrize( ("features", "expected_capacities"), [ - ( - torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]), - torch.tensor([1.0, 1.0]), + pytest.param( + torch.tensor([[[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]]), + torch.tensor([[1.0, 1.0]]), + id="orthogonal", ), - ( - torch.tensor([[-0.8, -0.8, -0.8], [-0.8, -0.8, -0.8]]), - torch.ones(2) / 2, + pytest.param( + torch.tensor( + [[[1.0, 0.0, 0.0], [-0.8, -0.8, -0.8]], [[0.0, 1.0, 0.0], [-0.8, -0.8, -0.8]]] + ), + torch.tensor([[1.0, 1.0], [0.5, 0.5]]), + id="orthogonal_2_components", ), - ( + pytest.param( + torch.tensor([[[-0.8, -0.8, -0.8]], [[-0.8, -0.8, -0.8]]]), + torch.ones(2).unsqueeze(0) / 2, + id="same_feature", + ), + pytest.param( torch.tensor( - [[1.0, 0.0, 0], [1 / math.sqrt(2), 1 / math.sqrt(2), 0.0], [0.0, 0.0, 1.0]] + [[[1.0, 0.0, 0]], [[1 / math.sqrt(2), 1 / math.sqrt(2), 0.0]], [[0.0, 0.0, 1.0]]] ), - torch.tensor([2 / 3, 2 / 3, 1.0]), + torch.tensor([2 / 3, 2 / 3, 1.0]).unsqueeze(0), ), ], ) @@ -43,14 +53,6 @@ def test_calc_capacities( ), "Capacity calculation is incorrect." -def test_wandb_capacity_histogram(snapshot: SnapshotSession) -> None: - """Check the Weights & Biases Histogram is created correctly.""" - capacities = torch.tensor([0.5, 0.1, 1, 1, 1]) - res = CapacityMetric.wandb_capacities_histogram(capacities) - - assert res.histogram == snapshot - - def test_calculate_returns_histogram() -> None: """Check the calculate function returns a histogram.""" metric = CapacityMetric() @@ -62,4 +64,31 @@ def test_calculate_returns_histogram() -> None: decoded_activations=activations, ) ) - assert "train/batch_capacities_histogram" in res + find_metric_result(res, name="capacities") + + +def test_weights_biases_log_matches_snapshot(snapshot: SnapshotSession) -> None: + """Test the log function for Weights & Biases.""" + n_batches = 10 + n_components = 6 + n_input_features = 4 + n_learned_features = 8 + + # Create some data + torch.manual_seed(0) + data = TrainMetricData( + input_activations=torch.rand((n_batches, n_components, n_input_features)), + learned_activations=torch.rand((n_batches, n_components, n_learned_features)), + decoded_activations=torch.rand((n_batches, n_components, n_input_features)), + ) + + # Get the wandb log + metric = CapacityMetric() + results = metric.calculate(data) + weights_biases_logs = [result.wandb_log for result in results] + + assert len(weights_biases_logs) == 1, """Should only be one metric result.""" + assert ( + len(results[0].component_wise_values) == n_components + ), """Should be one histogram per component.""" + assert weights_biases_logs == snapshot diff --git a/sparse_autoencoder/metrics/train/tests/test_feature_density.py b/sparse_autoencoder/metrics/train/tests/test_feature_density.py index ae2460c5..76286162 100644 --- a/sparse_autoencoder/metrics/train/tests/test_feature_density.py +++ b/sparse_autoencoder/metrics/train/tests/test_feature_density.py @@ -1,14 +1,16 @@ """Test the feature density metric.""" +from syrupy.session import SnapshotSession import torch from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData 
from sparse_autoencoder.metrics.train.feature_density import TrainBatchFeatureDensityMetric +from sparse_autoencoder.metrics.utils.find_metric_result import find_metric_result def test_calc_feature_density() -> None: """Check that the feature density matches an alternative way of doing the calc.""" - activations = torch.tensor([[0.5, 0.5, 0.0], [0.5, 0.0, 0.0001], [0.0, 0.1, 0.0]]) + activations = torch.tensor([[[0.5, 0.5, 0.0]], [[0.5, 0.0, 0.0001]], [[0.0, 0.1, 0.0]]]) # Use different approach to check threshold = 0.01 @@ -21,17 +23,17 @@ def test_calc_feature_density() -> None: def test_wandb_feature_density_histogram() -> None: """Check the Weights & Biases Histogram is created correctly.""" - feature_density = torch.tensor([0.001, 0.001, 0.001, 0.5, 0.5, 1.0]) + feature_density = torch.tensor([[0.001, 0.001, 0.001, 0.5, 0.5, 1.0]]) res = TrainBatchFeatureDensityMetric().wandb_feature_density_histogram(feature_density) # Check 0.001 is in the first bin 3 times expected_first_bin_value = 3 - assert res.histogram[0] == expected_first_bin_value + assert res[0].histogram[0] == expected_first_bin_value def test_calculate_aggregates() -> None: """Check that the metrics are aggregated in the calculate method.""" - activations = torch.tensor([[0.5, 0.5, 0.0], [0.5, 0.0, 0.0001], [0.0, 0.1, 0.0]]) + activations = torch.tensor([[[0.5, 0.5, 0.0]], [[0.5, 0.0, 0.0001]], [[0.0, 0.1, 0.0]]]) res = TrainBatchFeatureDensityMetric().calculate( TrainMetricData( input_activations=activations, @@ -40,5 +42,31 @@ def test_calculate_aggregates() -> None: ) ) - # Check both metrics are in the result - assert "train/batch_feature_density_histogram" in res + find_metric_result(res, name="feature_density") + + +def test_weights_biases_log_matches_snapshot(snapshot: SnapshotSession) -> None: + """Test the log function for Weights & Biases.""" + n_batches = 10 + n_components = 6 + n_input_features = 4 + n_learned_features = 8 + + # Create some data + torch.manual_seed(0) + data = TrainMetricData( + input_activations=torch.rand((n_batches, n_components, n_input_features)), + learned_activations=torch.rand((n_batches, n_components, n_learned_features)), + decoded_activations=torch.rand((n_batches, n_components, n_input_features)), + ) + + # Get the wandb log + metric = TrainBatchFeatureDensityMetric() + results = metric.calculate(data) + weights_biases_logs = [result.wandb_log for result in results] + + assert len(weights_biases_logs) == 1, """Should only be one metric result.""" + assert ( + len(results[0].component_wise_values) == n_components + ), """Should be one histogram per component.""" + assert weights_biases_logs == snapshot diff --git a/sparse_autoencoder/metrics/train/tests/test_l0_norm_metric.py b/sparse_autoencoder/metrics/train/tests/test_l0_norm_metric.py index dafee027..eb16c4aa 100644 --- a/sparse_autoencoder/metrics/train/tests/test_l0_norm_metric.py +++ b/sparse_autoencoder/metrics/train/tests/test_l0_norm_metric.py @@ -1,4 +1,5 @@ """Tests for the L0NormMetric class.""" +from syrupy.session import SnapshotSession import torch from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData @@ -7,7 +8,7 @@ def test_l0_norm_metric() -> None: """Test the L0NormMetric.""" - learned_activations = torch.tensor([[1.0, 0.0, 0.0], [0.0, 0.01, 2.0]]) + learned_activations = torch.tensor([[[1.0, 0.0, 0.0]], [[0.0, 0.01, 2.0]]]) l0_norm_metric = TrainBatchLearnedActivationsL0() data = TrainMetricData( input_activations=torch.zeros_like(learned_activations), @@ -16,4 +17,31 @@ def 
test_l0_norm_metric() -> None: ) log = l0_norm_metric.calculate(data) expected = 3 / 2 - assert log["train/learned_activations_l0_norm"] == expected + assert log[0].component_wise_values == expected + + +def test_weights_biases_log_matches_snapshot(snapshot: SnapshotSession) -> None: + """Test the log function for Weights & Biases.""" + n_batches = 10 + n_components = 6 + n_input_features = 4 + n_learned_features = 8 + + # Create some data + torch.manual_seed(0) + data = TrainMetricData( + input_activations=torch.rand((n_batches, n_components, n_input_features)), + learned_activations=torch.rand((n_batches, n_components, n_learned_features)), + decoded_activations=torch.rand((n_batches, n_components, n_input_features)), + ) + + # Get the wandb log + metric = TrainBatchLearnedActivationsL0() + results = metric.calculate(data) + weights_biases_logs = [result.wandb_log for result in results] + + assert len(weights_biases_logs) == 1, """Should only be one metric result.""" + assert ( + len(results[0].component_wise_values) == n_components + ), """Should be one histogram per component.""" + assert weights_biases_logs == snapshot diff --git a/sparse_autoencoder/metrics/train/tests/test_neuron_activity_metric.py b/sparse_autoencoder/metrics/train/tests/test_neuron_activity_metric.py index 727bb8c2..666b1691 100644 --- a/sparse_autoencoder/metrics/train/tests/test_neuron_activity_metric.py +++ b/sparse_autoencoder/metrics/train/tests/test_neuron_activity_metric.py @@ -1,12 +1,17 @@ """Tests for the NeuronActivityMetric class.""" +from jaxtyping import Float, Int64 import pytest +from syrupy.session import SnapshotSession import torch +from torch import Tensor from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData from sparse_autoencoder.metrics.train.neuron_activity_metric import ( NeuronActivityHorizonData, NeuronActivityMetric, ) +from sparse_autoencoder.metrics.utils.find_metric_result import find_metric_result +from sparse_autoencoder.tensor_types import Axis @pytest.fixture() @@ -30,13 +35,21 @@ def sample_neuron_activity_data() -> TrainMetricData: class TestNeuronActivityHorizonData: """Test the NeuronActivityHorizonData class.""" - def test_initialisation(self) -> None: + @pytest.mark.parametrize( + ("number_components"), + [ + pytest.param(1, id="1 component"), + pytest.param(2, id="2 components"), + ], + ) + def test_initialisation(self, number_components: int) -> None: """Test it initialises without errors.""" NeuronActivityHorizonData( approximate_activation_horizon=5, train_batch_size=2, number_learned_features=10, thresholds=[0.5], + number_components=number_components, ) def test_step_calculates_when_at_horizon( @@ -51,51 +64,103 @@ def test_step_calculates_when_at_horizon( train_batch_size=train_batch_size, number_learned_features=4, thresholds=[0.5], + number_components=1, ) for step in range(1, 10): - data = torch.randint(0, 2, (1, 4)).squeeze() + data = torch.randint(0, 2, (1, 4)) res = threshold_data_store.step(data) if step % horizon_in_steps == 0: - assert len(res.keys()) > 0 + assert len(res) > 0 else: - assert len(res.keys()) == 0 + assert len(res) == 0 - def test_results(self) -> None: + def test_results_match_expectations(self) -> None: """Test that the results are calculated correctly.""" threshold_data_store = NeuronActivityHorizonData( approximate_activation_horizon=30, train_batch_size=30, number_learned_features=5, thresholds=[0.5], + number_components=1, ) - data = torch.tensor([0, 30, 4, 1, 0]) + data = torch.tensor([[0, 30, 4, 1, 0]]) 
res = threshold_data_store.step(data) expected_dead = 2 expected_alive = 3 expected_almost_dead = 4 - assert res["train/activity/over_30_activations/dead_count"] == expected_dead - assert res["train/activity/over_30_activations/alive_count"] == expected_alive - assert res["train/activity/over_30_activations/almost_dead_0.5"] == expected_almost_dead + dead_over_30_activations = find_metric_result(res, postfix="dead_over_30_activations") + assert dead_over_30_activations.component_wise_values[0] == expected_dead + + alive_over_30_activations = find_metric_result(res, postfix="alive_over_30_activations") + assert alive_over_30_activations.component_wise_values[0] == expected_alive + + almost_dead_over_30_activations = find_metric_result( + res, postfix="almost_dead_5.0e-01_over_30_activations" + ) + assert almost_dead_over_30_activations.component_wise_values[0] == expected_almost_dead class TestNeuronActivityMetric: """Test the NeuronActivityMetric class.""" - def test_dead_neuron_count(self, sample_neuron_activity_data: TrainMetricData) -> None: - """Test if dead neuron count is correctly calculated. - - Args: - sample_neuron_activity_data: The sample neuron activity data for testing. - """ + @pytest.mark.parametrize( + ("learned_activations", "expected_dead_count", "expected_alive_count"), + [ + pytest.param( + torch.tensor([[0.0, 0, 0, 0, 0]]), + torch.tensor([5]), + torch.tensor([0]), + id="All dead", + ), + pytest.param( + torch.tensor([[1.0, 1, 1, 1, 1]]), + torch.tensor([0]), + torch.tensor([5]), + id="All alive", + ), + pytest.param( + torch.tensor([[0.0, 1, 0, 1, 0]]), + torch.tensor([3]), + torch.tensor([2]), + id="Some dead", + ), + pytest.param( + torch.tensor([[[0.0, 1, 0, 1, 0], [0.0, 0, 0, 0, 0]]]), + torch.tensor([3, 5]), + torch.tensor([2, 0]), + id="Multiple components with some dead", + ), + ], + ) + def test_dead_neuron_count( + self, + learned_activations: Float[Tensor, Axis.names(Axis.COMPONENT, Axis.LEARNT_FEATURE)], + expected_dead_count: Int64[Tensor, Axis.names(Axis.COMPONENT)], + expected_alive_count: Int64[Tensor, Axis.names(Axis.COMPONENT)], + ) -> None: + """Test if dead neuron count is correctly calculated.""" + input_activations = torch.zeros_like(learned_activations, dtype=torch.float) + data = TrainMetricData( + learned_activations=learned_activations, + # Input and decoded activations are not used in this metric + input_activations=input_activations, + decoded_activations=input_activations, + ) neuron_activity_metric = NeuronActivityMetric(approximate_horizons=[1]) - metrics = neuron_activity_metric.calculate(sample_neuron_activity_data) - expected_dead_neuron_count = 2 - assert metrics["train/activity/over_1_activations/dead_count"] == expected_dead_neuron_count + metrics = neuron_activity_metric.calculate(data) + + dead_over_1_activations = find_metric_result(metrics, postfix="dead_over_1_activations") + alive_over_1_activations = find_metric_result(metrics, postfix="alive_over_1_activations") + + assert isinstance(dead_over_1_activations.component_wise_values, torch.Tensor) + assert isinstance(alive_over_1_activations.component_wise_values, torch.Tensor) + assert torch.allclose(dead_over_1_activations.component_wise_values, expected_dead_count) + assert torch.allclose(alive_over_1_activations.component_wise_values, expected_alive_count) def test_alive_neuron_count(self, sample_neuron_activity_data: TrainMetricData) -> None: """Test if alive neuron count is correctly calculated. 
@@ -106,9 +171,8 @@ def test_alive_neuron_count(self, sample_neuron_activity_data: TrainMetricData) neuron_activity_metric = NeuronActivityMetric(approximate_horizons=[1]) metrics = neuron_activity_metric.calculate(sample_neuron_activity_data) expected_alive_neuron_count = 8 - assert ( - metrics["train/activity/over_1_activations/alive_count"] == expected_alive_neuron_count - ) + alive_over_1_activations = find_metric_result(metrics, postfix="alive_over_1_activations") + assert alive_over_1_activations.component_wise_values[0] == expected_alive_neuron_count def test_histogram_generation(self, sample_neuron_activity_data: TrainMetricData) -> None: """Test if histogram is correctly generated in the metrics. @@ -119,8 +183,34 @@ def test_histogram_generation(self, sample_neuron_activity_data: TrainMetricData neuron_activity_metric = NeuronActivityMetric(approximate_horizons=[5]) for _ in range(4): metrics = neuron_activity_metric.calculate(sample_neuron_activity_data) - assert metrics == {} metrics = neuron_activity_metric.calculate(sample_neuron_activity_data) - assert "train/activity/over_5_activations/activity_histogram" in metrics - assert "train/activity/over_5_activations/log_activity_histogram" in metrics + + find_metric_result(metrics, postfix="activity_histogram_over_5_activations") + find_metric_result(metrics, postfix="log_activity_histogram_over_5_activations") + + +def test_weights_biases_log_matches_snapshot(snapshot: SnapshotSession) -> None: + """Test the log function for Weights & Biases.""" + n_batches = 10 + n_components = 6 + n_input_features = 4 + n_learned_features = 8 + + # Create some data + torch.manual_seed(0) + data = TrainMetricData( + input_activations=torch.rand((n_batches, n_components, n_input_features)), + learned_activations=torch.rand((n_batches, n_components, n_learned_features)), + decoded_activations=torch.rand((n_batches, n_components, n_input_features)), + ) + + # Get the wandb log + metric = NeuronActivityMetric(approximate_horizons=[n_batches]) + results = metric.calculate(data) + weights_biases_logs = [result.wandb_log for result in results] + + assert ( + len(results[0].component_wise_values) == n_components + ), """Should be one histogram per component.""" + assert weights_biases_logs == snapshot diff --git a/sparse_autoencoder/metrics/utils/__init__.py b/sparse_autoencoder/metrics/utils/__init__.py new file mode 100644 index 00000000..1036c7db --- /dev/null +++ b/sparse_autoencoder/metrics/utils/__init__.py @@ -0,0 +1 @@ +"""Metric utils.""" diff --git a/sparse_autoencoder/metrics/utils/add_component_axis_if_missing.py b/sparse_autoencoder/metrics/utils/add_component_axis_if_missing.py new file mode 100644 index 00000000..13fa77de --- /dev/null +++ b/sparse_autoencoder/metrics/utils/add_component_axis_if_missing.py @@ -0,0 +1,50 @@ +"""Util to add a component axis (dimension) if missing to a tensor.""" +from torch import Tensor + + +def add_component_axis_if_missing( + input_tensor: Tensor, + unsqueeze_dim: int = 1, + dimensions_without_component: int = 1, +) -> Tensor: + """Add component axis if missing. 
+ + Examples: + If the component axis is missing, add it: + + >>> import torch + >>> input = torch.tensor([1.0, 2.0, 3.0]) + >>> add_component_axis_if_missing(input) + tensor([[1.], + [2.], + [3.]]) + + If the component axis is present, do nothing: + + >>> import torch + >>> input = torch.tensor([[1.0], [2.0], [3.0]]) + >>> add_component_axis_if_missing(input) + tensor([[1.], + [2.], + [3.]]) + + Args: + input_tensor: Tensor with or without a component axis. + unsqueeze_dim: The dimension to unsqueeze the component axis. + dimensions_without_component: The number of dimensions of the input tensor without a + component axis. + + Returns: + Tensor with a component axis. + + Raises: + ValueError: If the number of dimensions of the input tensor is not supported. + """ + if input_tensor.ndim == dimensions_without_component: + return input_tensor.unsqueeze(unsqueeze_dim) + + if input_tensor.ndim == dimensions_without_component + 1: + return input_tensor + + error_message = f"Unexpected number of dimensions: {input_tensor.ndim}" + raise ValueError(error_message) diff --git a/sparse_autoencoder/metrics/utils/find_metric_result.py b/sparse_autoencoder/metrics/utils/find_metric_result.py new file mode 100644 index 00000000..4053175c --- /dev/null +++ b/sparse_autoencoder/metrics/utils/find_metric_result.py @@ -0,0 +1,74 @@ +"""Find metric result.""" +from sparse_autoencoder.metrics.abstract_metric import MetricLocation, MetricResult + + +def find_metric_result( + metrics: list[MetricResult], + *, + location: MetricLocation | None = None, + name: str | None = None, + postfix: str | None = None, +) -> MetricResult: + """Find exactly one metric result from a list of results. + + Motivation: + For automated testing, it's useful to search for a specific result and check it is as + expected. + + Example: + >>> import torch + >>> metric_results = [ + ... MetricResult( + ... component_wise_values=torch.tensor([1.0, 2.0, 3.0]), + ... location=MetricLocation.TRAIN, + ... name="loss", + ... postfix="baseline_loss", + ... ), + ... MetricResult( + ... component_wise_values=torch.tensor([4.0, 5.0, 6.0]), + ... location=MetricLocation.TRAIN, + ... name="loss", + ... postfix="loss_with_reconstruction", + ... ) + ... ] + >>> find_metric_result( + ... metric_results, name="loss", postfix="baseline_loss" + ... ).component_wise_values + tensor([1., 2., 3.]) + + Args: + metrics: List of metric results. + location: Location of the metric to find. None means all locations. + name: Name of the metric to find. None means all names. + postfix: Postfix of the metric to find. None means **no postfix**. + + Returns: + Metric result. + + Raises: + ValueError: If the metric is not found. + """ + if name is None and postfix is None and location is None: + error_message = "At least one of name, postfix or location must be provided." + raise ValueError(error_message) + + results: list[MetricResult] = [] + + for metric in metrics: + if ( + (metric.location == location or location is None) + and (metric.name == name or name is None) + and (metric.postfix == postfix) + ): + results.append(metric) # noqa: PERF401 + + if len(results) == 0: + metric_names = ",\n ".join([f"{metric.name} {metric.postfix or ''}" for metric in metrics]) + error_message = f"Metric not found. 
The only metrics found were:\n {metric_names}" + raise ValueError(error_message) + + if len(results) == 1: + return results[0] + + error_message = f"Multiple metrics found: name={name}, postfix={postfix}, location={location}" + raise ValueError(error_message) diff --git a/sparse_autoencoder/metrics/validate/abstract_validate_metric.py b/sparse_autoencoder/metrics/validate/abstract_validate_metric.py index 778ee53d..68bfc3c2 100644 --- a/sparse_autoencoder/metrics/validate/abstract_validate_metric.py +++ b/sparse_autoencoder/metrics/validate/abstract_validate_metric.py @@ -1,28 +1,68 @@ """Abstract metric classes.""" from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any +from typing import final from jaxtyping import Float from torch import Tensor +from sparse_autoencoder.metrics.abstract_metric import ( + AbstractMetric, + MetricLocation, + MetricResult, +) +from sparse_autoencoder.metrics.utils.add_component_axis_if_missing import ( + add_component_axis_if_missing, +) from sparse_autoencoder.tensor_types import Axis +@final @dataclass class ValidationMetricData: - """Validation metric data.""" + """Validation metric data. - source_model_loss: Float[Tensor, Axis.ITEMS] + Dataclass that always has a component axis. + """ - source_model_loss_with_reconstruction: Float[Tensor, Axis.ITEMS] + source_model_loss: Float[Tensor, Axis.names(Axis.ITEMS, Axis.COMPONENT)] + """Source model loss (without the SAE).""" - source_model_loss_with_zero_ablation: Float[Tensor, Axis.ITEMS] + source_model_loss_with_reconstruction: Float[Tensor, Axis.names(Axis.ITEMS, Axis.COMPONENT)] + """Source model loss with SAE reconstruction.""" + source_model_loss_with_zero_ablation: Float[Tensor, Axis.names(Axis.ITEMS, Axis.COMPONENT)] + """Source model loss with zero ablation.""" -class AbstractValidationMetric(ABC): + def __init__( + self, + source_model_loss: Float[Tensor, Axis.names(Axis.ITEMS, Axis.COMPONENT_OPTIONAL)], + source_model_loss_with_reconstruction: Float[ + Tensor, Axis.names(Axis.ITEMS, Axis.COMPONENT_OPTIONAL) + ], + source_model_loss_with_zero_ablation: Float[ + Tensor, Axis.names(Axis.ITEMS, Axis.COMPONENT_OPTIONAL) + ], + ) -> None: + """Initialize the validation metric data.""" + self.source_model_loss = add_component_axis_if_missing(source_model_loss).detach() + self.source_model_loss_with_reconstruction = add_component_axis_if_missing( + source_model_loss_with_reconstruction + ).detach() + self.source_model_loss_with_zero_ablation = add_component_axis_if_missing( + source_model_loss_with_zero_ablation + ).detach() + + +class AbstractValidationMetric(AbstractMetric, ABC): """Abstract validation metric.""" + @final + @property + def location(self) -> MetricLocation: + """Metric type name.""" + return MetricLocation.VALIDATE + @abstractmethod - def calculate(self, data: ValidationMetricData) -> dict[str, Any]: + def calculate(self, data: ValidationMetricData) -> list[MetricResult]: """Calculate any metrics.""" diff --git a/sparse_autoencoder/metrics/validate/model_reconstruction_score.py b/sparse_autoencoder/metrics/validate/model_reconstruction_score.py index 0c0cadda..315d127e 100644 --- a/sparse_autoencoder/metrics/validate/model_reconstruction_score.py +++ b/sparse_autoencoder/metrics/validate/model_reconstruction_score.py @@ -1,6 +1,7 @@ """Model reconstruction score.""" -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING +from sparse_autoencoder.metrics.abstract_metric import MetricResult from 
sparse_autoencoder.metrics.validate.abstract_validate_metric import ( AbstractValidationMetric, ValidationMetricData, @@ -32,7 +33,7 @@ class ModelReconstructionScore(AbstractValidationMetric): $$ """ - def calculate(self, data: ValidationMetricData) -> dict[str, Any]: + def calculate(self, data: ValidationMetricData) -> list[MetricResult]: """Calculate the model reconstruction score. Example: @@ -44,7 +45,7 @@ def calculate(self, data: ValidationMetricData) -> dict[str, Any]: ... ) >>> metric = ModelReconstructionScore() >>> result = metric.calculate(data) - >>> round(result['validate/model_reconstruction_score'], 3) + >>> round(result[3].aggregate_value.item(), 3) 0.667 Args: @@ -55,32 +56,47 @@ def calculate(self, data: ValidationMetricData) -> dict[str, Any]: """ # Return no statistics if the data is empty (e.g. if we're at the very end of training) if data.source_model_loss.numel() == 0: - return {} + return [] # Calculate the reconstruction score - zero_ablate_loss_minus_default_loss: Float[Tensor, Axis.ITEMS] = ( - data.source_model_loss_with_zero_ablation - data.source_model_loss - ) - zero_ablate_loss_minus_reconstruction_loss: Float[Tensor, Axis.ITEMS] = ( - data.source_model_loss_with_zero_ablation - data.source_model_loss_with_reconstruction - ) - model_reconstruction_score: float = ( - zero_ablate_loss_minus_reconstruction_loss.mean().item() - / zero_ablate_loss_minus_default_loss.mean().item() - ) + zero_ablate_loss_minus_default_loss: Float[ + Tensor, Axis.names(Axis.ITEMS, Axis.COMPONENT_OPTIONAL) + ] = data.source_model_loss_with_zero_ablation - data.source_model_loss + zero_ablate_loss_minus_reconstruction_loss: Float[ + Tensor, Axis.names(Axis.ITEMS, Axis.COMPONENT_OPTIONAL) + ] = data.source_model_loss_with_zero_ablation - data.source_model_loss_with_reconstruction + + model_reconstruction_score = zero_ablate_loss_minus_reconstruction_loss.mean( + 0 + ) / zero_ablate_loss_minus_default_loss.mean(0) # Get the other metrics - validation_baseline_loss: float = data.source_model_loss.mean().item() - validation_loss_with_reconstruction: float = ( - data.source_model_loss_with_reconstruction.mean().item() - ) - validation_loss_with_zero_ablation: float = ( - data.source_model_loss_with_zero_ablation.mean().item() - ) - - return { - "validate/baseline_loss": validation_baseline_loss, - "validate/loss_with_reconstruction": validation_loss_with_reconstruction, - "validate/loss_with_zero_ablation": validation_loss_with_zero_ablation, - "validate/model_reconstruction_score": model_reconstruction_score, - } + validation_baseline_loss = data.source_model_loss.mean(0) + validation_loss_with_reconstruction = data.source_model_loss_with_reconstruction.mean(0) + validation_loss_with_zero_ablation = data.source_model_loss_with_zero_ablation.mean(0) + + return [ + MetricResult( + component_wise_values=validation_baseline_loss, + location=self.location, + name="reconstruction_score", + postfix="baseline_loss", + ), + MetricResult( + component_wise_values=validation_loss_with_reconstruction, + location=self.location, + name="reconstruction_score", + postfix="loss_with_reconstruction", + ), + MetricResult( + component_wise_values=validation_loss_with_zero_ablation, + location=self.location, + name="reconstruction_score", + postfix="loss_with_zero_ablation", + ), + MetricResult( + component_wise_values=model_reconstruction_score, + location=self.location, + name="reconstruction_score", + ), + ] diff --git 
a/sparse_autoencoder/metrics/validate/tests/__snapshots__/test_model_reconstruction_score.ambr b/sparse_autoencoder/metrics/validate/tests/__snapshots__/test_model_reconstruction_score.ambr new file mode 100644 index 00000000..6284e1e2 --- /dev/null +++ b/sparse_autoencoder/metrics/validate/tests/__snapshots__/test_model_reconstruction_score.ambr @@ -0,0 +1,42 @@ +# serializer version: 1 +# name: test_weights_biases_log_matches_snapshot + list([ + dict({ + 'component_0/validate/reconstruction_score/baseline_loss': tensor(0.3800), + 'component_1/validate/reconstruction_score/baseline_loss': tensor(0.5251), + 'component_2/validate/reconstruction_score/baseline_loss': tensor(0.4923), + 'component_3/validate/reconstruction_score/baseline_loss': tensor(0.4598), + 'component_4/validate/reconstruction_score/baseline_loss': tensor(0.4281), + 'component_5/validate/reconstruction_score/baseline_loss': tensor(0.4961), + 'validate/reconstruction_score/baseline_loss': tensor([0.3800, 0.5251, 0.4923, 0.4598, 0.4281, 0.4961]), + }), + dict({ + 'component_0/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6111), + 'component_1/validate/reconstruction_score/loss_with_reconstruction': tensor(0.5219), + 'component_2/validate/reconstruction_score/loss_with_reconstruction': tensor(0.4063), + 'component_3/validate/reconstruction_score/loss_with_reconstruction': tensor(0.6497), + 'component_4/validate/reconstruction_score/loss_with_reconstruction': tensor(0.4929), + 'component_5/validate/reconstruction_score/loss_with_reconstruction': tensor(0.3723), + 'validate/reconstruction_score/loss_with_reconstruction': tensor([0.6111, 0.5219, 0.4063, 0.6497, 0.4929, 0.3723]), + }), + dict({ + 'component_0/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.2891), + 'component_1/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.3879), + 'component_2/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.5850), + 'component_3/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.4740), + 'component_4/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.5452), + 'component_5/validate/reconstruction_score/loss_with_zero_ablation': tensor(0.3733), + 'validate/reconstruction_score/loss_with_zero_ablation': tensor([0.2891, 0.3879, 0.5850, 0.4740, 0.5452, 0.3733]), + }), + dict({ + 'component_0/validate/reconstruction_score': tensor(3.5422), + 'component_1/validate/reconstruction_score': tensor(0.9767), + 'component_2/validate/reconstruction_score': tensor(1.9278), + 'component_3/validate/reconstruction_score': tensor(-12.3338), + 'component_4/validate/reconstruction_score': tensor(0.4468), + 'component_5/validate/reconstruction_score': tensor(-0.0081), + 'validate/reconstruction_score': tensor([ 3.5422e+00, 9.7672e-01, 1.9278e+00, -1.2334e+01, 4.4681e-01, + -8.1113e-03]), + }), + ]) +# --- diff --git a/sparse_autoencoder/metrics/validate/tests/test_model_reconstruction_score.py b/sparse_autoencoder/metrics/validate/tests/test_model_reconstruction_score.py index a5cf996d..1cf8739c 100644 --- a/sparse_autoencoder/metrics/validate/tests/test_model_reconstruction_score.py +++ b/sparse_autoencoder/metrics/validate/tests/test_model_reconstruction_score.py @@ -2,8 +2,11 @@ from jaxtyping import Float import pytest +from syrupy.session import SnapshotSession +import torch from torch import Tensor +from sparse_autoencoder.metrics.utils.find_metric_result import find_metric_result from sparse_autoencoder.metrics.validate.abstract_validate_metric import 
ValidationMetricData from sparse_autoencoder.metrics.validate.model_reconstruction_score import ModelReconstructionScore from sparse_autoencoder.tensor_types import Axis @@ -22,7 +25,7 @@ def test_model_reconstruction_score_empty_data() -> None: ) metric = ModelReconstructionScore() result = metric.calculate(data) - assert result == {} + assert result == [] @pytest.mark.parametrize( @@ -30,17 +33,21 @@ def test_model_reconstruction_score_empty_data() -> None: [ ( ValidationMetricData( - source_model_loss=Float[Tensor, Axis.ITEMS]([3.0, 3.0, 3.0]), - source_model_loss_with_reconstruction=Float[Tensor, Axis.ITEMS]([3.0, 3.0, 3.0]), - source_model_loss_with_zero_ablation=Float[Tensor, Axis.ITEMS]([4.0, 4.0, 4.0]), + source_model_loss=torch.tensor([[3.0], [3.0], [3.0]]), + source_model_loss_with_reconstruction=torch.tensor([[3.0], [3.0], [3.0]]), + source_model_loss_with_zero_ablation=torch.tensor([[4.0], [4.0], [4.0]]), ), 1.0, ), ( ValidationMetricData( - source_model_loss=Float[Tensor, Axis.ITEMS]([0.5, 1.5, 2.5]), - source_model_loss_with_reconstruction=Float[Tensor, Axis.ITEMS]([1.5, 2.5, 3.5]), - source_model_loss_with_zero_ablation=Float[Tensor, Axis.ITEMS]([8.0, 7.0, 4.0]), + source_model_loss=Float[Tensor, Axis.ITEMS]([[0.5], [1.5], [2.5]]), + source_model_loss_with_reconstruction=Float[Tensor, Axis.ITEMS]( + [[1.5], [2.5], [3.5]] + ), + source_model_loss_with_zero_ablation=Float[Tensor, Axis.ITEMS]( + [[8.0], [7.0], [4.0]] + ), ), 0.79, ), @@ -55,5 +62,36 @@ def test_model_reconstruction_score_various_data( calculation for different sets of input data. """ metric = ModelReconstructionScore() - result = metric.calculate(data) - assert round(result["validate/model_reconstruction_score"], 2) == expected_score + calculated = metric.calculate(data) + + reconstruction_score = find_metric_result(calculated, name="reconstruction_score", postfix=None) + + result = reconstruction_score.component_wise_values + assert isinstance(result, Tensor) + assert round(result[0].item(), 2) == expected_score + + +def test_weights_biases_log_matches_snapshot(snapshot: SnapshotSession) -> None: + """Test the log function for Weights & Biases.""" + n_items = 10 + n_components = 6 + + # Create some data + torch.manual_seed(0) + data = ValidationMetricData( + source_model_loss=torch.rand((n_items, n_components)), + source_model_loss_with_reconstruction=torch.rand((n_items, n_components)), + source_model_loss_with_zero_ablation=torch.rand((n_items, n_components)), + ) + + # Get the wandb log + metric = ModelReconstructionScore() + results = metric.calculate(data) + weights_biases_logs = [result.wandb_log for result in results] + + for result in results: + assert ( + len(result.component_wise_values) == n_components + ), """Should be one histogram per component.""" + + assert weights_biases_logs == snapshot diff --git a/sparse_autoencoder/train/pipeline.py b/sparse_autoencoder/train/pipeline.py index 0dfee8a8..aa1fbdc8 100644 --- a/sparse_autoencoder/train/pipeline.py +++ b/sparse_autoencoder/train/pipeline.py @@ -3,7 +3,7 @@ from functools import partial from pathlib import Path import tempfile -from typing import final +from typing import TYPE_CHECKING, final from urllib.parse import quote_plus from jaxtyping import Int, Int64 @@ -33,6 +33,9 @@ from sparse_autoencoder.train.utils import get_model_device +if TYPE_CHECKING: + from sparse_autoencoder.metrics.abstract_metric import MetricResult + DEFAULT_CHECKPOINT_DIRECTORY: Path = Path(tempfile.gettempdir()) / "sparse_autoencoder" @@ -219,21 +222,21 @@ def 
train_autoencoder( learned_activations, reconstructed_activations = self.autoencoder(batch) # Get loss & metrics - metrics = {} + metrics: list[MetricResult] = [] total_loss, loss_metrics = self.loss.scalar_loss_with_log( batch, learned_activations, reconstructed_activations, component_reduction=LossReductionType.MEAN, ) - metrics.update(loss_metrics) + metrics.extend(loss_metrics) with torch.no_grad(): for metric in self.metrics.train_metrics: calculated = metric.calculate( TrainMetricData(batch, learned_activations, reconstructed_activations) ) - metrics.update(calculated) + metrics.extend(calculated) # Store count of how many neurons have fired with torch.no_grad(): @@ -252,8 +255,11 @@ def train_autoencoder( and int(self.total_activations_trained_on / train_batch_size) % self.log_frequency == 0 ): + log = {} + for metric_result in metrics: + log.update(metric_result.wandb_log) wandb.log( - data={**metrics, **loss_metrics}, + log, step=self.total_activations_trained_on, commit=True, ) @@ -341,9 +347,11 @@ def validate_sae(self, validation_number_activations: int) -> None: source_model_loss_with_zero_ablation=torch.tensor(losses_with_zero_ablation), ) for metric in self.metrics.validation_metrics: - calculated = metric.calculate(validation_data) + log = {} + for metric_result in metric.calculate(validation_data): + log.update(metric_result.wandb_log) if wandb.run is not None: - wandb.log(data=calculated, commit=False) + wandb.log(log, commit=False) @final def save_checkpoint(self, *, is_final: bool = False) -> Path: diff --git a/sparse_autoencoder/train/sweep.py b/sparse_autoencoder/train/sweep.py index df32d3fb..f61d71b3 100644 --- a/sparse_autoencoder/train/sweep.py +++ b/sparse_autoencoder/train/sweep.py @@ -184,7 +184,7 @@ def setup_source_data(hyperparameters: RuntimeHyperparameters) -> SourceDataset: def setup_wandb() -> RuntimeHyperparameters: """Initialise wandb for experiment tracking.""" - wandb.init(project="sparse-autoencoder") + wandb.init() return dict(wandb.config) # type: ignore diff --git a/sparse_autoencoder/train/sweep_config.py b/sparse_autoencoder/train/sweep_config.py index 0125a4da..8db112a2 100644 --- a/sparse_autoencoder/train/sweep_config.py +++ b/sparse_autoencoder/train/sweep_config.py @@ -281,8 +281,7 @@ class PipelineHyperparameters(NestedParameter): """Validation frequency.""" validation_number_activations: Parameter[int] = field( - # Default to a single batch of source data prompts - default=Parameter(DEFAULT_BATCH_SIZE * DEFAULT_SOURCE_CONTEXT_SIZE * 16) + default=Parameter(DEFAULT_SOURCE_BATCH_SIZE * DEFAULT_SOURCE_CONTEXT_SIZE * 2) ) """Number of activations to use for validation.""" diff --git a/sparse_autoencoder/train/tests/test_pipeline.py b/sparse_autoencoder/train/tests/test_pipeline.py index 0f2246ed..e2c27e44 100644 --- a/sparse_autoencoder/train/tests/test_pipeline.py +++ b/sparse_autoencoder/train/tests/test_pipeline.py @@ -1,5 +1,5 @@ """Test the pipeline module.""" -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING from unittest.mock import MagicMock import pytest @@ -19,6 +19,7 @@ ) from sparse_autoencoder.activation_resampler.activation_resampler import ActivationResampler from sparse_autoencoder.activation_store.tensor_store import TensorActivationStore +from sparse_autoencoder.metrics.abstract_metric import MetricResult from sparse_autoencoder.metrics.validate.abstract_validate_metric import ( AbstractValidationMetric, ValidationMetricData, @@ -265,10 +266,10 @@ class StoreValidationMetric(AbstractValidationMetric): 
data: ValidationMetricData | None - def calculate(self, data: ValidationMetricData) -> dict[str, Any]: + def calculate(self, data: ValidationMetricData) -> list[MetricResult]: """Store the data.""" self.data = data - return {} + return [] dummy_metric = StoreValidationMetric() pipeline_fixture.metrics.validation_metrics.append(dummy_metric)
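
A minimal sketch of the metric API this patch moves to: calculate now returns a list of MetricResult objects, and callers look a single entry up by name or postfix with the new find_metric_result helper. The values below simply mirror the "Some dead" case in test_neuron_activity_metric.py; nothing beyond that test is assumed.

import torch

from sparse_autoencoder.metrics.train.abstract_train_metric import TrainMetricData
from sparse_autoencoder.metrics.train.neuron_activity_metric import NeuronActivityMetric
from sparse_autoencoder.metrics.utils.find_metric_result import find_metric_result

# One batch item, five learned features, two of them firing.
learned_activations = torch.tensor([[0.0, 1.0, 0.0, 1.0, 0.0]])
data = TrainMetricData(
    input_activations=torch.zeros_like(learned_activations),
    learned_activations=learned_activations,
    # Input and decoded activations are not used by this metric.
    decoded_activations=torch.zeros_like(learned_activations),
)

# Horizon of one batch, so results are emitted on the first step.
metrics = NeuronActivityMetric(approximate_horizons=[1]).calculate(data)
dead = find_metric_result(metrics, postfix="dead_over_1_activations")
assert dead.component_wise_values[0] == 3  # three neurons never fired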
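
ValidationMetricData now normalises its inputs so they always carry a component axis. A minimal sketch of that behaviour, using only the constructor shown above (the concrete loss values are illustrative):

import torch

from sparse_autoencoder.metrics.validate.abstract_validate_metric import ValidationMetricData

# 1-D per-item losses (no component axis) are unsqueezed to (items, 1) and detached.
data = ValidationMetricData(
    source_model_loss=torch.tensor([3.0, 3.0, 3.0]),
    source_model_loss_with_reconstruction=torch.tensor([3.0, 3.0, 3.0]),
    source_model_loss_with_zero_ablation=torch.tensor([4.0, 4.0, 4.0]),
)
assert data.source_model_loss.shape == (3, 1)

# 2-D inputs that already have a component axis are left unchanged.
multi_component = ValidationMetricData(
    source_model_loss=torch.rand(10, 6),
    source_model_loss_with_reconstruction=torch.rand(10, 6),
    source_model_loss_with_zero_ablation=torch.rand(10, 6),
)
assert multi_component.source_model_loss.shape == (10, 6)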
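
As a worked example of what ModelReconstructionScore.calculate computes per component, using the numbers from the second parametrised case in test_model_reconstruction_score.py:

import torch

baseline = torch.tensor([0.5, 1.5, 2.5])             # source model loss
with_reconstruction = torch.tensor([1.5, 2.5, 3.5])  # loss with SAE reconstruction
with_zero_ablation = torch.tensor([8.0, 7.0, 4.0])   # loss with zero ablation

# score = mean(zero_ablation - reconstruction) / mean(zero_ablation - baseline)
score = (with_zero_ablation - with_reconstruction).mean() / (with_zero_ablation - baseline).mean()
assert round(score.item(), 2) == 0.79  # matches the expected_score in the test above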
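
Finally, a minimal sketch of how the pipeline change above flattens a list of MetricResult objects into a single Weights & Biases log dict. The loss values and the key layout (component_N/location/name/postfix) are illustrative, taken from the snapshots in this diff.

import torch

from sparse_autoencoder.metrics.abstract_metric import MetricLocation, MetricResult

metrics: list[MetricResult] = [
    MetricResult(
        component_wise_values=torch.tensor([5.5, 2.5]),  # one value per component (illustrative)
        location=MetricLocation.TRAIN,
        name="loss",
        postfix="l2_reconstruction_loss",
    ),
]

# Each result exposes a wandb_log dict (keys such as
# "component_0/train/loss/l2_reconstruction_loss"), so the pipeline just merges them.
log: dict = {}
for metric_result in metrics:
    log.update(metric_result.wandb_log)
# wandb.log(log, step=total_activations_trained_on, commit=True)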