Skip to content

Commit

Permalink
Add unit/int test separation (#164)
Browse files Browse the repository at this point in the history
  • Loading branch information
alan-cooney authored Dec 31, 2023
1 parent c446b89 commit 75c9b7e
Show file tree
Hide file tree
Showing 8 changed files with 493 additions and 446 deletions.
1 change: 1 addition & 0 deletions .vscode/cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
"neox",
"nonlinerity",
"numel",
"openwebtext",
"optim",
"penality",
"perp",
Expand Down
892 changes: 455 additions & 437 deletions poetry.lock

Large diffs are not rendered by default.

25 changes: 17 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
[tool.poetry.dependencies]
datasets=">=2.15.0"
einops=">=0.6"
-pydantic="^2.5.2"
+pydantic=">=2.5.2"
python=">=3.10, <3.12"
strenum=">=0.4.15"
tokenizers=">=0.15.0"
Expand All @@ -24,32 +24,33 @@
jupyter=">=1"
plotly=">=5"
poethepoet=">=0.24.2"
-pydoclint="^0.3.8"
+pydoclint=">=0.3.8"
pyright=">=1.1.340"
pytest=">=7"
pytest-cov=">=4"
pytest-integration=">=0.2.3"
pytest-timeout=">=2.2.0"
ruff=">=0.1.4"
syrupy=">=4.6.0"

[tool.poetry.group.demos.dependencies]
-huggingface-hub="^0.19.4"
-ipywidgets="^8.1.1"
+huggingface-hub=">=0.19.4"
+ipywidgets=">=8.1.1"
jupyterlab=">=3"
transformer-lens=">=1.9.0"

[tool.poetry.group.docs.dependencies]
mkdocs=">=1.5.3"
mkdocs-gen-files=">=0.5.0"
-mkdocs-htmlproofer-plugin="^1.0.0"
+mkdocs-htmlproofer-plugin=">=1.0.0"
mkdocs-literate-nav=">=0.6.1"
mkdocs-material=">=9.4.10"
mkdocs-section-index=">=0.3.8"
mkdocstrings={extras=["python"], version=">=0.24.0"}
mkdocstrings-python=">=1.7.3"
-mknotebooks="^0.8.0"
-pygments="^2.17.2"
-pymdown-extensions="^10.5"
+mknotebooks=">=0.8.0"
+pygments=">=2.17.2"
+pymdown-extensions=">=10.5"
pytkdocs-tweaks=">=0.0.7"

[tool.poe.tasks]
Expand Down Expand Up @@ -91,6 +92,14 @@
help=" [alias for test]"
sequence=["test"]

[tool.poe.tasks.unit-test]
cmd="pytest --without-integration"
help="Run unit tests"

[tool.poe.tasks.unit]
help=" [alias for unit-test]"
sequence=["unit-test"]

[tool.poe.tasks.typecheck]
cmd="pyright"
help="Typecheck"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest


-TEST_DATASET = "NeelNanda/c4-tokenized-2b"
+TEST_DATASET = "alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"


# Mock class for PreTokenizedDataset
Expand Down Expand Up @@ -39,6 +39,7 @@ def _generate_mock_data(self) -> list[dict]:
return mock_data


@pytest.mark.integration_test()
@pytest.mark.parametrize("context_size", [50, 250])
def test_tokenized_prompts_correct_size(context_size: int) -> None:
"""Test that the tokenized prompts have the correct context size."""
Expand Down
2 changes: 2 additions & 0 deletions sparse_autoencoder/source_data/tests/test_text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from sparse_autoencoder.source_data.text_dataset import TextDataset


@pytest.mark.integration_test()
@pytest.mark.parametrize("context_size", [50, 250])
def test_tokenized_prompts_correct_size(context_size: int) -> None:
"""Test that the tokenized prompts have the correct context size."""
Expand All @@ -25,6 +26,7 @@ def test_tokenized_prompts_correct_size(context_size: int) -> None:
assert isinstance(token, int)


@pytest.mark.integration_test()
def test_dataloader_correct_size_items() -> None:
"""Test the dataloader returns the correct number & sized items."""
batch_size = 10
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial

from jaxtyping import Int
import pytest
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
Expand All @@ -11,6 +12,7 @@
from sparse_autoencoder.tensor_types import Axis


@pytest.mark.integration_test()
def test_hook_replaces_activations() -> None:
"""Test that the hook replaces activations."""
torch.random.manual_seed(0)
Expand All @@ -36,6 +38,7 @@ def test_hook_replaces_activations() -> None:
assert torch.all(torch.gt(loss_with_hook, loss_without_hook))


@pytest.mark.integration_test()
def test_hook_replaces_activations_2_components() -> None:
"""Test that the hook replaces activations."""
torch.random.manual_seed(0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial

from jaxtyping import Int
import pytest
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
Expand All @@ -11,6 +12,7 @@
from sparse_autoencoder.tensor_types import Axis


@pytest.mark.integration_test()
def test_hook_stores_activations() -> None:
"""Test that the hook stores activations correctly."""
store = TensorActivationStore(max_items=100, n_neurons=256)
Expand Down
11 changes: 11 additions & 0 deletions sparse_autoencoder/train/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def pipeline_fixture() -> Pipeline:
class TestGenerateActivations:
"""Test the generate_activations method."""

@pytest.mark.integration_test()
def test_generates_store(self, pipeline_fixture: Pipeline) -> None:
"""Test that generate_activations generates a store."""
store_size: int = 1000
Expand All @@ -78,6 +79,7 @@ def test_generates_store(self, pipeline_fixture: Pipeline) -> None:
), "Store must be a TensorActivationStore instance"
assert len(store) == store_size, "Store size should match the specified size"

@pytest.mark.integration_test()
def test_store_has_unique_items(self, pipeline_fixture: Pipeline) -> None:
"""Test that each item from the store iterable is unique."""
store_size: int = 1000
Expand All @@ -93,6 +95,7 @@ def test_store_has_unique_items(self, pipeline_fixture: Pipeline) -> None:

assert len(unique_activations) >= expected_min_length, "Store items should be unique"

@pytest.mark.integration_test()
def test_two_runs_generate_different_activations(self, pipeline_fixture: Pipeline) -> None:
"""Test that two runs of generate_activations generate different activations."""
store_size: int = 1000
Expand All @@ -108,6 +111,7 @@ def test_two_runs_generate_different_activations(self, pipeline_fixture: Pipelin
class TestTrainAutoencoder:
"""Test the train_autoencoder method."""

@pytest.mark.integration_test()
def test_learned_activations_fired_count(self, pipeline_fixture: Pipeline) -> None:
"""Test that the learned activations fired count is updated correctly."""
store_size: int = 1000
Expand All @@ -122,6 +126,7 @@ def test_learned_activations_fired_count(self, pipeline_fixture: Pipeline) -> No

assert fired_count.sum().item() > 0, "Some neurons should have fired."

@pytest.mark.integration_test()
def test_learns_with_backwards_pass(self, pipeline_fixture: Pipeline) -> None:
"""Test that the autoencoder learns with a backwards pass."""
store_size: int = 1000
Expand All @@ -144,6 +149,7 @@ def test_learns_with_backwards_pass(self, pipeline_fixture: Pipeline) -> None:
class TestUpdateParameters:
"""Test the update_parameters method."""

@pytest.mark.integration_test()
def test_weights_biases_changed(self, pipeline_fixture: Pipeline) -> None:
"""Test that the weights and biases have changed after training."""
store_size: int = 1000
Expand Down Expand Up @@ -209,6 +215,7 @@ def test_weights_biases_changed(self, pipeline_fixture: Pipeline) -> None:
pipeline_fixture.autoencoder.decoder.weight[0, :, ~dead_neuron_indices],
), "Decoder weights should not have changed after training."

@pytest.mark.integration_test()
def test_optimizer_state_changed(self, pipeline_fixture: Pipeline) -> None:
"""Test that the optimizer state has changed after training."""
store_size: int = 1000
Expand Down Expand Up @@ -268,6 +275,7 @@ def test_optimizer_state_changed(self, pipeline_fixture: Pipeline) -> None:
class TestValidateSAE:
"""Test the validate_sae method."""

@pytest.mark.integration_test()
def test_validation_loss_calculated(self, pipeline_fixture: Pipeline) -> None:
"""Test that the validation loss numbers are calculated."""

Expand Down Expand Up @@ -308,11 +316,13 @@ def calculate(self, data: ValidationMetricData) -> list[MetricResult]:
class TestSaveCheckpoint:
"""Test the save_checkpoint method."""

@pytest.mark.integration_test()
def test_saves_locally(self, pipeline_fixture: Pipeline) -> None:
"""Test that the save_checkpoint method saves the checkpoint locally."""
saved_checkpoint: Path = pipeline_fixture.save_checkpoint()
assert saved_checkpoint.exists(), "Checkpoint file should exist."

@pytest.mark.integration_test()
def test_saves_final(self, pipeline_fixture: Pipeline) -> None:
"""Test that the save_checkpoint method saves the final checkpoint."""
saved_checkpoint: Path = pipeline_fixture.save_checkpoint(is_final=True)
Expand All @@ -324,6 +334,7 @@ def test_saves_final(self, pipeline_fixture: Pipeline) -> None:
class TestRunPipeline:
"""Test the run_pipeline method."""

@pytest.mark.integration_test()
def test_run_pipeline_calls_all_methods(self, pipeline_fixture: Pipeline) -> None:
"""Test that the run_pipeline method calls all the other methods."""
pipeline_fixture.validate_sae = MagicMock(spec=Pipeline.validate_sae) # type: ignore
Expand Down

0 comments on commit 75c9b7e

Please sign in to comment.