Skip to content

Commit

Permalink
make unit tests pass, add make file
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Feb 7, 2024
1 parent ce526df commit 08b2c92
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 174 deletions.
14 changes: 14 additions & 0 deletions makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
format:

check-format:


test:
make unit-test
make acceptance-test

unit-test:
pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/unit

acceptance-test:
pytest -v --cov=sae_training/ --cov-report=term-missing --cov-branch tests/acceptance
5 changes: 5 additions & 0 deletions tests/unit/test_activations_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ def cfg():
mock_config.lr = 2e-4
mock_config.train_batch_size = 32
mock_config.context_size = 16
mock_config.use_cached_activations = False
mock_config.hook_point_head_index = None

mock_config.feature_sampling_method = None
mock_config.feature_sampling_window = 50
Expand Down Expand Up @@ -74,6 +76,9 @@ def cfg_head_hook():
mock_config.lr = 2e-4
mock_config.train_batch_size = 32
mock_config.context_size = 128
mock_config.use_cached_activations = False
mock_config.hook_point_head_index = 0


mock_config.feature_sampling_method = None
mock_config.feature_sampling_window = 50
Expand Down
184 changes: 10 additions & 174 deletions tests/unit/test_sparse_autoencoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def cfg():
mock_config.is_dataset_tokenized = False
mock_config.use_cached_activations = False
mock_config.d_in = 64
mock_config.use_ghost_grads = False
mock_config.expansion_factor = 2
mock_config.d_sae = mock_config.d_in * mock_config.expansion_factor
mock_config.l1_coefficient = 2e-3
Expand All @@ -45,11 +46,11 @@ def cfg():
mock_config.wandb_project = "test_project"
mock_config.wandb_entity = "test_entity"
mock_config.wandb_log_frequency = 10
mock_config.device = "cuda"
mock_config.device = "cpu"
mock_config.seed = 24
mock_config.checkpoint_path = "test/checkpoints"
mock_config.dtype = torch.bfloat16
# mock_config.dtype = torch.float32
# mock_config.dtype = torch.bfloat16
mock_config.dtype = torch.float32

return mock_config

Expand Down Expand Up @@ -175,7 +176,7 @@ def test_sparse_autoencoder_forward(sparse_autoencoder):
d_sae = sparse_autoencoder.d_sae

x = torch.randn(batch_size, d_in)
sae_out, feature_acts, loss, mse_loss, l1_loss = sparse_autoencoder.forward(
sae_out, feature_acts, loss, mse_loss, l1_loss, ghost_grad_loss = sparse_autoencoder.forward(
x,
)

Expand All @@ -186,8 +187,11 @@ def test_sparse_autoencoder_forward(sparse_autoencoder):
assert l1_loss.shape == ()
assert torch.allclose(loss, mse_loss + l1_loss)

assert torch.allclose(mse_loss, (sae_out.float() - x.float()).pow(2).sum(-1).mean(0))
assert torch.allclose(l1_loss, sparse_autoencoder.l1_coefficient * torch.abs(feature_acts).sum())

expected_mse_loss = (torch.pow((sae_out-x.float()), 2) / (x**2).sum(dim=-1, keepdim=True).sqrt()).mean()
assert torch.allclose(mse_loss, expected_mse_loss)
expected_l1_loss = torch.abs(feature_acts).sum(dim=1).mean(dim=(0,))
assert torch.allclose(l1_loss, sparse_autoencoder.l1_coefficient * expected_l1_loss)

# check everything has the right dtype
assert sae_out.dtype == sparse_autoencoder.dtype
Expand All @@ -196,171 +200,3 @@ def test_sparse_autoencoder_forward(sparse_autoencoder):
assert mse_loss.dtype == sparse_autoencoder.dtype
assert l1_loss.dtype == sparse_autoencoder.dtype

def test_sparse_autoencoder_resample_neurons_l2(sparse_autoencoder):

batch_size = 32
d_in =sparse_autoencoder.d_in
d_sae = sparse_autoencoder.d_sae

x = torch.randn(batch_size, d_in)
feature_sparsity = torch.exp((torch.randn(d_sae) - 17))
neuron_resample_scale = 0.2
optimizer = torch.optim.Adam(sparse_autoencoder.parameters(), lr=1e-4)

# set weight of the sparse autoencoder to be non-zero (and not have unit norm)
sparse_autoencoder.W_enc.data = torch.randn(d_in, d_sae)*10
sparse_autoencoder.W_dec.data = torch.randn(d_sae, d_in)*10
sparse_autoencoder.b_enc.data = torch.randn(d_sae)*10
sparse_autoencoder.b_dec.data = torch.randn(d_in)*10

# Set optimizer state so we can tell when it is reset:
dummy_value = 5.0
for dict_idx, (k, v) in enumerate(optimizer.state.items()):
for v_key in ["exp_avg", "exp_avg_sq"]:
if dict_idx == 0: # W_enc
assert k.data.shape == (d_in, d_sae)
v[v_key] = dummy_value
elif dict_idx == 1: # b_enc
assert k.data.shape == (d_sae,)
v[v_key] = dummy_value
elif dict_idx == 2: # W_dec
assert k.data.shape == (d_sae, d_in)
v[v_key]= dummy_value
elif dict_idx == 3: # b_dec
assert k.data.shape == (d_in,)
is_dead = feature_sparsity < sparse_autoencoder.cfg.dead_feature_threshold
alive_neurons = feature_sparsity >= sparse_autoencoder.cfg.dead_feature_threshold


n_resampled_neurons = sparse_autoencoder.resample_neurons_l2(x, feature_sparsity, neuron_resample_scale, optimizer)

# want to check the following:
# 1. that the number of neurons reset is equal to the number of neurons that should be reset
assert n_resampled_neurons == is_dead.sum().item()

# 2. for each neuron we reset:
# a. the bias is zero
assert torch.allclose(
sparse_autoencoder.b_enc.data[is_dead],
torch.zeros_like(sparse_autoencoder.b_enc.data[is_dead]))
# b. the encoder weights have norm 0.2 * average of other weights.
mean_decoder_norm = sparse_autoencoder.W_enc[:, alive_neurons].norm(dim=0).mean().item()
assert torch.allclose(
sparse_autoencoder.W_enc[:, is_dead].norm(dim=0),
torch.ones(n_resampled_neurons) * 0.2 * mean_decoder_norm
)
# c. the decoder weights have unit norm
assert torch.allclose(
sparse_autoencoder.W_dec[is_dead, :].norm(dim=1),
torch.ones(n_resampled_neurons)
)

# d. the Adam parameters are reset
for dict_idx, (k, v) in enumerate(optimizer.state.items()):
for v_key in ["exp_avg", "exp_avg_sq"]:
if dict_idx == 0:
if k.data.shape != (d_in, d_sae):
print(
"Warning: it does not seem as if resetting the Adam parameters worked, there are shapes mismatches"
)
if v[v_key][:, is_dead].abs().max().item() > 1e-6:
print(
"Warning: it does not seem as if resetting the Adam parameters worked"
)

# e. check that the decoder weights for reset neurons match the encoder weights for reset neurons
# (given both are normalized)
assert torch.allclose(
(sparse_autoencoder.W_enc[:, is_dead] / sparse_autoencoder.W_enc[:, is_dead].norm(dim=0)).T,
sparse_autoencoder.W_dec[is_dead, :] / sparse_autoencoder.W_dec[is_dead, :].norm(dim=1).unsqueeze(1)
)

def test_sparse_autoencoder_resample_neurons_anthropic(sparse_autoencoder, model, activation_store):
'''
Not sure how to test this properly so for now
we'll just check that it runs without error.
'''

batch_size = sparse_autoencoder.cfg.store_batch_size
d_in =sparse_autoencoder.d_in
d_sae = sparse_autoencoder.d_sae
neuron_resample_scale = sparse_autoencoder.cfg.feature_reinit_scale
feature_sparsity = torch.exp((torch.randn(d_sae) - 17))
optimizer = torch.optim.Adam(sparse_autoencoder.parameters(), lr=1e-4)

# Set optimizer state so we can tell when it is reset:
dummy_value = 5.0
for dict_idx, (k, v) in enumerate(optimizer.state.items()):
for v_key in ["exp_avg", "exp_avg_sq"]:
if dict_idx == 0: # W_enc
assert k.data.shape == (d_in, d_sae)
v[v_key] = dummy_value
elif dict_idx == 1: # b_enc
assert k.data.shape == (d_sae,)
v[v_key] = dummy_value
elif dict_idx == 2: # W_dec
assert k.data.shape == (d_sae, d_in)
v[v_key]= dummy_value
elif dict_idx == 3: # b_dec
assert k.data.shape == (d_in,)

dead_neuron_indices = (feature_sparsity < sparse_autoencoder.cfg.dead_feature_threshold).nonzero(as_tuple=False)[:, 0]
alive_neurons = (feature_sparsity >= sparse_autoencoder.cfg.dead_feature_threshold).nonzero(as_tuple=False)[:, 0]

sparse_autoencoder.resample_neurons_anthropic(
dead_neuron_indices,
model,
optimizer,
activation_store
)

# # want to check the following:
# # 1. that the number of neurons reset is equal to the number of neurons that should be reset
# assert n_resampled_neurons == is_dead.sum().item()

# # 2. for each neuron we reset:
# # a. the bias is zero
# assert torch.allclose(
# sparse_autoencoder.b_enc.data[is_dead],
# torch.zeros_like(sparse_autoencoder.b_enc.data[is_dead]))
# # b. the encoder weights have norm 0.2 * average of other weights.
# mean_decoder_norm = sparse_autoencoder.W_enc[:, alive_neurons].norm(dim=0).mean().item()
# assert torch.allclose(
# sparse_autoencoder.W_enc[:, is_dead].norm(dim=0),
# torch.ones(n_resampled_neurons) * 0.2 * mean_decoder_norm
# )
# # c. the decoder weights have unit norm
# assert torch.allclose(
# sparse_autoencoder.W_dec[is_dead, :].norm(dim=1),
# torch.ones(n_resampled_neurons)
# )

# # d. the Adam parameters are reset
# for dict_idx, (k, v) in enumerate(optimizer.state.items()):
# for v_key in ["exp_avg", "exp_avg_sq"]:
# if dict_idx == 0:
# if k.data.shape != (d_in, d_sae):
# print(
# "Warning: it does not seem as if resetting the Adam parameters worked, there are shapes mismatches"
# )
# if v[v_key][:, is_dead].abs().max().item() > 1e-6:
# print(
# "Warning: it does not seem as if resetting the Adam parameters worked"
# )

# # e. check that the decoder weights for reset neurons match the encoder weights for reset neurons
# # (given both are normalized)
# assert torch.allclose(
# (sparse_autoencoder.W_enc[:, is_dead] / sparse_autoencoder.W_enc[:, is_dead].norm(dim=0)).T,
# sparse_autoencoder.W_dec[is_dead, :] / sparse_autoencoder.W_dec[is_dead, :].norm(dim=1).unsqueeze(1)
# )



@pytest.mark.skip("TODO")
def test_sparse_eautoencoder_remove_gradient_parallel_to_decoder_directions(cfg):
pass # TODO



2 changes: 2 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ def cfg():
mock_config.seed = 24
mock_config.checkpoint_path = "test/checkpoints"
mock_config.dtype = torch.float32
mock_config.use_cached_activations = False
mock_config.hook_point_head_index = None

return mock_config

Expand Down

0 comments on commit 08b2c92

Please sign in to comment.