"""
Sparse Autoencoder Experiment Runner
===================================
Main entry point for running SAE experiments on transformer activations.
Handles training, analysis, visualization and checkpointing.
Key Components:
- Model training with multiple activation functions
- Neuron frequency analysis
- Concept emergence tracking
- Experiment checkpointing
- W&B and ASCII visualization
Usage:
------
Basic training:
python run_experiments.py --hidden-dim 256 --epochs 100
Transformer analysis:
python run_experiments.py --model-name gpt2-small --layer 0 --n-samples 1000 --use-wandb
Functions:
----------
parse_args(): Configure experiment parameters
get_dataset(): Load and cache transformer activations
train_model(): Train SAE with specified config
compute_sparsity(): Calculate activation sparsity
run_activation_study(): Compare activation functions
run_full_analysis(): Execute complete analysis suite
"""
import argparse
import torch
from torch.utils.data import DataLoader
from config.config import SAEConfig
import wandb
from tqdm import tqdm
import transformer_lens
from visualization.ascii_viz import ASCIIVisualizer
from visualization.wandb_viz import WandBVisualizer
from models.autoencoder import SparseAutoencoder
from experiments.frequency_analysis import FrequencyAnalyzer
from experiments.concept_emergence import ConceptAnalyzer
from experiments.transformer_data import TransformerActivationDataset
from experiments.checkpointing import CheckpointManager, ExperimentState
# Module-level caches so the transformer model and its activation dataset are only loaded once per run
_cached_model = None
_cached_dataset = None
def parse_args():
"""Parse command line arguments for experiment configuration."""
parser = argparse.ArgumentParser(description='Run SAE Experiments')
parser.add_argument('--hidden-dim', type=int, default=256)
parser.add_argument('--lr', type=float, default=0.001)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--use-wandb', action='store_true')
parser.add_argument('--activation', type=str, choices=['relu', 'jump_relu', 'topk'],
default='relu', help='Activation function type')
parser.add_argument('--model-name', type=str, default='gpt2-small')
parser.add_argument('--layer', type=int, default=0)
parser.add_argument('--n-samples', type=int, default=1000)
return parser.parse_args()
def get_dataset(config):
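    """Return the transformer activation dataset, loading and caching it on first use."""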
global _cached_dataset
    if _cached_dataset is None:
_cached_dataset = TransformerActivationDataset(
model_name=config.model_name,
layer=config.layer,
n_samples=config.n_samples
)
return _cached_dataset
def train_model(config, track_frequency=True, visualizer=None):
"""
Train Sparse Autoencoder model.
Args:
config: SAEConfig object with model/training parameters
track_frequency: Enable neuron firing rate tracking
visualizer: Optional W&B visualization handler
Returns:
model: Trained SAE model
freq_analyzer: Frequency analysis results
losses: Training loss history
"""
# Get dataset first to determine input dimension
dataset = get_dataset(config)
sample = dataset[0]["pixel_values"]
input_dim = sample.shape[0] # Get actual dimension from transformer
model = SparseAutoencoder(
input_dim=input_dim, # Use transformer's hidden dimension
hidden_dim=config.hidden_dim,
activation_type=config.activation_type
)
optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)
# Setup tracking
frequency_analyzer = FrequencyAnalyzer(model) if track_frequency else None
losses = []
dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True)
# Training loop
progress_bar = tqdm(range(config.epochs), desc=f"Training {config.activation_type}")
for epoch in progress_bar:
epoch_loss = 0
for batch in dataloader:
optimizer.zero_grad()
inputs = batch["pixel_values"]
reconstructed, encoded = model(inputs)
if frequency_analyzer:
frequency_analyzer.update(encoded)
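            # Reconstruction objective only; no explicit sparsity penalty term is added here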
loss = torch.nn.functional.mse_loss(reconstructed, inputs)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
if visualizer:
visualizer.log_training({
'epoch': epoch,
'loss': loss.item(),
'encoded': encoded.detach(),
'weights': model.encoder.weight.data
})
avg_loss = epoch_loss / len(dataloader)
losses.append(avg_loss)
if config.use_wandb:
wandb.log({
'epoch': epoch,
'loss': avg_loss,
'activation': config.activation_type
})
progress_bar.set_postfix({'loss': f'{avg_loss:.4f}'})
return model, frequency_analyzer, losses
def compute_sparsity(model, config):
"""Compute activation sparsity using cached dataset"""
with torch.no_grad():
dataset = get_dataset(config)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)
batch = next(iter(dataloader))
_, encoded = model(batch["pixel_values"])
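        # Treat activations with magnitude below 1e-5 as inactive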
zeros = (encoded.abs() < 1e-5).float().mean().item()
return zeros
def run_activation_study(config, visualizer=None):
"""
Compare different activation functions.
Trains models with ReLU, JumpReLU and TopK activations,
tracking performance metrics and neuron behavior.
Args:
config: Experiment configuration
visualizer: Optional visualization handler
Returns:
Dictionary containing results for each activation
"""
checkpoint_manager = CheckpointManager()
state = checkpoint_manager.load_checkpoint()
results = {}
activations = ['relu', 'jump_relu', 'topk']
    # Resume from a saved checkpoint if one exists: previously completed activations
    # get placeholder entries and training restarts at the checkpointed activation
if state:
results = {act: {} for act in state.completed_activations}
start_idx = activations.index(state.current_activation)
else:
start_idx = 0
for activation in activations[start_idx:]:
try:
config.activation_type = activation
print(f"\nStudying {activation} activation:")
model, freq_analyzer, losses = train_model(config, visualizer=visualizer)
results[activation] = {
'final_loss': losses[-1],
'loss_trend': losses,
'frequency_stats': freq_analyzer.analyze() if freq_analyzer else None,
'sparsity': compute_sparsity(model, config),
'feature_weights': model.encoder.weight.data,
'model': model
}
# Save checkpoint after each activation
state = ExperimentState(
completed_activations=list(results.keys()),
current_activation=activation,
epoch=config.epochs,
model_state=model.state_dict(),
optimizer_state=None, # Add if needed
frequency_stats=results[activation]['frequency_stats'],
losses=losses
)
checkpoint_manager.save_checkpoint(state, config)
except Exception as e:
print(f"\nError during {activation} training: {str(e)}")
print("You can resume from this point using the checkpoint")
            raise
return results
def run_full_analysis(config):
"""
Execute comprehensive analysis suite.
Performs:
- Activation function comparison
- Frequency pattern analysis
- Concept emergence tracking
- Result visualization
Args:
config: Experiment configuration
Returns:
Complete analysis results dictionary
"""
visualizer = WandBVisualizer(
model_name=config.model_name,
run_name=f"sae_{config.hidden_dim}_{config.activation_type}"
)
# Get activation study results and final model
print("\n=== Running Activation Function Study ===")
activation_results = run_activation_study(config, visualizer)
    # Reuse the last model trained in the study; run_activation_study leaves
    # config.activation_type set to the final activation it trained
    print("\n=== Analyzing Final Model ===")
    model = activation_results[config.activation_type]['model']
freq_analyzer = FrequencyAnalyzer(model)
freq_stats = freq_analyzer.analyze()
print("\n=== Analyzing Concept Emergence ===")
concept_analyzer = ConceptAnalyzer(model, get_dataset(config))
concept_stats = concept_analyzer.analyze_concepts()
results = {
'activation_comparison': activation_results,
'frequency_analysis': freq_stats,
'concept_analysis': concept_stats
}
visualizer.log_results(results)
ASCIIVisualizer.print_results(results)
return results
if __name__ == "__main__":
args = parse_args()
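    # Build the experiment configuration from command-line arguments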
config = SAEConfig(
        input_dim=784,  # Placeholder; train_model() overrides this with the transformer's hidden size
hidden_dim=args.hidden_dim,
learning_rate=args.lr,
epochs=args.epochs,
batch_size=args.batch_size,
activation_type=args.activation,
use_wandb=args.use_wandb,
model_name=args.model_name,
layer=args.layer,
n_samples=args.n_samples
)
results = run_full_analysis(config)