Support training with DataParallel #178

Merged: 3 commits, Jan 6, 2024
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -125,6 +125,10 @@
cmd="mkdocs serve"
help="Hot reload the docs site (so changes appear instantly)"

[tool.poe.tasks.run]
cmd="poetry run python"
help="Run a python file (append with file name)"

[build-system]
build-backend="poetry.core.masonry.api"
requires=["poetry-core"]
Expand Down
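poethepoet appends any extra command-line arguments to a `cmd` task, which is what the help text's "append with file name" refers to. A usage sketch, with a hypothetical script path:

    poe run scripts/train.py  # runs: poetry run python scripts/train.py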
@@ -10,6 +10,7 @@
 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
 from sparse_autoencoder.loss.abstract_loss import AbstractLoss
 from sparse_autoencoder.tensor_types import Axis
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes
 
 
 @dataclass
@@ -58,7 +59,7 @@ def step_resampler(
         self,
         batch_neuron_activity: Int[Tensor, Axis.LEARNT_FEATURE],
         activation_store: TensorActivationStore,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults] | None:
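This widened annotation is the only change the resampler needs: `DataParallelWithModelAttributes` (added below in sparse_autoencoder/utils/data_parallel.py) forwards unknown attribute lookups to the wrapped module, so method bodies read the autoencoder the same way for both types. A minimal sketch of the pattern, assuming a hypothetical `encoder` attribute on the model:

from sparse_autoencoder.autoencoder.model import SparseAutoencoder
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes


def get_encoder(
    autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
):
    # Identical access path for the bare model and the wrapper: the wrapper's
    # __getattr__ falls back to getattr(self.module, "encoder").
    return autoencoder.encoder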
@@ -20,6 +20,7 @@
 from sparse_autoencoder.loss.abstract_loss import AbstractLoss
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.train.utils.get_model_device import get_model_device
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes
 
 
 class LossInputActivationsTuple(NamedTuple):
@@ -188,7 +189,7 @@ def _get_dead_neuron_indices(
     def compute_loss_and_get_activations(
         self,
         store: ActivationStore,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> LossInputActivationsTuple:
@@ -421,7 +422,7 @@ def renormalize_and_scale(
     def resample_dead_neurons(
         self,
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults]:
@@ -513,7 +514,7 @@ def step_resampler(
             Tensor, Axis.names(Axis.COMPONENT_OPTIONAL, Axis.LEARNT_FEATURE)
         ],
         activation_store: ActivationStore,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         loss_fn: AbstractLoss,
         train_batch_size: int,
     ) -> list[ParameterUpdateResults] | None:
6 changes: 3 additions & 3 deletions sparse_autoencoder/source_data/tests/test_text_dataset.py
@@ -1,6 +1,6 @@
"""Pile Uncopyrighted Dataset Tests."""
import pytest
from transformers import PreTrainedTokenizerFast
from transformers import GPT2Tokenizer

from sparse_autoencoder.source_data.text_dataset import TextDataset

@@ -9,7 +9,7 @@
 @pytest.mark.parametrize("context_size", [50, 250])
 def test_tokenized_prompts_correct_size(context_size: int) -> None:
     """Test that the tokenized prompts have the correct context size."""
-    tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
+    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
 
     data = TextDataset(
         tokenizer=tokenizer, context_size=context_size, dataset_path="monology/pile-uncopyrighted"
@@ -31,7 +31,7 @@ def test_dataloader_correct_size_items() -> None:
"""Test the dataloader returns the correct number & sized items."""
batch_size = 10
context_size = 250
tokenizer = PreTrainedTokenizerFast.from_pretrained("gpt2")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
data = TextDataset(
tokenizer=tokenizer, context_size=context_size, dataset_path="monology/pile-uncopyrighted"
)
Expand Down
3 changes: 2 additions & 1 deletion sparse_autoencoder/source_model/replace_activations_hook.py
@@ -5,6 +5,7 @@
 from transformer_lens.hook_points import HookPoint
 
 from sparse_autoencoder.autoencoder.model import SparseAutoencoder
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes
 
 
 if TYPE_CHECKING:
@@ -15,7 +16,7 @@
 def replace_activations_hook(
     value: Tensor,
     hook: HookPoint,  # noqa: ARG001
-    sparse_autoencoder: SparseAutoencoder,
+    sparse_autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
     component_idx: int | None = None,
 ) -> Tensor:
     """Replace activations hook.
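For context on how the widened parameter is consumed: hooks like this are registered via transformer_lens's `run_with_hooks`, with `functools.partial` binding the autoencoder argument. A sketch, where the hook-point name, `tokens`, `source_model`, and `distributed_autoencoder` are illustrative assumptions, not names from this PR:

from functools import partial

# Assumes source_model, tokens, and distributed_autoencoder are already defined.
loss = source_model.run_with_hooks(
    tokens,
    return_type="loss",
    fwd_hooks=[
        (
            "blocks.0.hook_mlp_out",  # assumed hook point name
            partial(replace_activations_hook, sparse_autoencoder=distributed_autoencoder),
        )
    ],
)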
9 changes: 5 additions & 4 deletions sparse_autoencoder/train/pipeline.py
@@ -31,6 +31,7 @@
 from sparse_autoencoder.source_model.zero_ablate_hook import zero_ablate_hook
 from sparse_autoencoder.tensor_types import Axis
 from sparse_autoencoder.train.utils.get_model_device import get_model_device
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes
 
 
 if TYPE_CHECKING:
@@ -49,7 +50,7 @@ class Pipeline:
     activation_resampler: AbstractActivationResampler | None
     """Activation resampler to use."""
 
-    autoencoder: SparseAutoencoder
+    autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder]
     """Sparse autoencoder to train."""
 
     cache_names: list[str]
@@ -79,7 +80,7 @@ class Pipeline:
     source_dataset: SourceDataset
     """Source dataset to generate activation data from (tokenized prompts)."""
 
-    source_model: HookedTransformer
+    source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer]
     """Source model to get activations from."""
 
     total_activations_trained_on: int = 0
@@ -95,13 +96,13 @@ def n_components(self) -> int:
     def __init__(
         self,
         activation_resampler: AbstractActivationResampler | None,
-        autoencoder: SparseAutoencoder,
+        autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
         cache_names: list[str],
         layer: NonNegativeInt,
         loss: AbstractLoss,
         optimizer: AbstractOptimizerWithReset,
         source_dataset: SourceDataset,
-        source_model: HookedTransformer,
+        source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer],
         run_name: str = "sparse_autoencoder",
         checkpoint_directory: Path = DEFAULT_CHECKPOINT_DIRECTORY,
         log_frequency: PositiveInt = 100,
7 changes: 4 additions & 3 deletions sparse_autoencoder/train/sweep.py
@@ -25,6 +25,7 @@
     RuntimeHyperparameters,
     SweepConfig,
 )
+from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes
 
 
 def setup_activation_resampler(hyperparameters: RuntimeHyperparameters) -> ActivationResampler:
@@ -239,8 +240,8 @@ def stop_layer_from_cache_names(cache_names: list[str]) -> int:

 def run_training_pipeline(
     hyperparameters: RuntimeHyperparameters,
-    source_model: HookedTransformer,
-    autoencoder: SparseAutoencoder,
+    source_model: HookedTransformer | DataParallelWithModelAttributes[HookedTransformer],
+    autoencoder: SparseAutoencoder | DataParallelWithModelAttributes[SparseAutoencoder],
     loss: LossReducer,
     optimizer: AdamWithReset,
     activation_resampler: ActivationResampler,
@@ -324,7 +325,7 @@ def train() -> None:
     run_training_pipeline(
         hyperparameters=hyperparameters,
         source_model=source_model,
-        autoencoder=autoencoder,
+        autoencoder=DataParallelWithModelAttributes(autoencoder),
         loss=loss_function,
         optimizer=optimizer,
         activation_resampler=activation_resampler,
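The wrapper is constructed once here, at the outermost call site, so everything downstream stays wrapper-agnostic; only the autoencoder is wrapped in this diff, while the source model is passed through bare. The widened `source_model` annotation would permit the same treatment, e.g. (a sketch of what the annotation allows, not a change in this diff):

source_model = DataParallelWithModelAttributes(source_model)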
40 changes: 40 additions & 0 deletions sparse_autoencoder/utils/data_parallel.py
@@ -0,0 +1,40 @@
"""Data parallel utils."""
from typing import Any, Generic, TypeVar

from torch.nn import DataParallel, Module


T = TypeVar("T", bound=Module)


class DataParallelWithModelAttributes(DataParallel[T], Generic[T]):
"""Data parallel with access to underlying model attributes/methods.

Allows access to underlying model attributes/methods, which is not possible with the default
`DataParallel` class. Based on:
https://pytorch.org/tutorials/beginner/former_torchies/parallelism_tutorial.html

Example:
>>> from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig
>>> model = SparseAutoencoder(SparseAutoencoderConfig(
... n_input_features=2,
... n_learned_features=4,
... ))
>>> distributed_model = DataParallelWithModelAttributes(model)
>>> distributed_model.config.n_learned_features
4
"""

def __getattr__(self, name: str) -> Any: # noqa: ANN401
"""Allow access to underlying model attributes/methods.

Args:
name: Attribute/method name.

Returns:
Attribute value/method.
"""
try:
return super().__getattr__(name)
except AttributeError:
return getattr(self.module, name)
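Beyond the doctest above, a quick usage sketch: `forward` is still inherited from `DataParallel`, so with multiple GPUs the input batch is split along dim 0 across model replicas, while with no GPU available the call falls through to the bare module. The batch shape and device handling here are illustrative:

import torch

from sparse_autoencoder import SparseAutoencoder, SparseAutoencoderConfig
from sparse_autoencoder.utils.data_parallel import DataParallelWithModelAttributes

model = SparseAutoencoder(SparseAutoencoderConfig(n_input_features=2, n_learned_features=4))
# On a multi-GPU machine, move the model (and batch) to the first device before
# wrapping, e.g. model = model.to("cuda:0"); on CPU, DataParallel calls the
# module directly.
distributed_model = DataParallelWithModelAttributes(model)

batch = torch.randn(8, 2)  # 8 activation vectors with 2 input features each
output = distributed_model(batch)  # dispatched through DataParallel.forward

# Attribute access that plain DataParallel would hide still works:
assert distributed_model.config.n_input_features == 2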