Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
hijohnnylin committed Apr 16, 2024
1 parent f769e7a commit 1e3d53e
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 38 deletions.
115 changes: 79 additions & 36 deletions sae_lens/analysis/neuronpedia_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,14 @@ def init_sae_session(self):
self.sae_path, device=self.device
)
loader = LMSparseAutoencoderSessionloader(self.sparse_autoencoder.cfg)
self.model, _, self.activation_store = loader.load_sae_training_group_session()
self.model, _, self.activation_store = (
loader.load_sae_training_group_session()
)

def get_tokens(
self, n_batches_to_sample_from: int = 2**12, n_prompts_to_select: int = 4096 * 6
self,
n_batches_to_sample_from: int = 2**12,
n_prompts_to_select: int = 4096 * 6,
):
all_tokens_list = []
pbar = tqdm(range(n_batches_to_sample_from))
Expand All @@ -132,7 +136,9 @@ def round_list(self, to_round: list[float]):
return list(np.round(to_round, 3))

def to_str_tokens_safe(
self, vocab_dict: Dict[int, str], tokens: Union[int, List[int], torch.Tensor]
self,
vocab_dict: Dict[int, str],
tokens: Union[int, List[int], torch.Tensor],
):
"""
does to_str_tokens, except handles out of range
Expand Down Expand Up @@ -173,12 +179,16 @@ def run(self):
sparsity = load_sparsity(self.sae_path)
sparsity = sparsity.to(self.device)
self.target_feature_indexes = (
(sparsity > self.sparsity_threshold).nonzero(as_tuple=True)[0].tolist()
(sparsity > self.sparsity_threshold)
.nonzero(as_tuple=True)[0]
.tolist()
)

# divide into batches
feature_idx = torch.tensor(self.target_feature_indexes)
n_subarrays = np.ceil(len(feature_idx) / self.n_features_at_a_time).astype(int)
n_subarrays = np.ceil(
len(feature_idx) / self.n_features_at_a_time
).astype(int)
feature_idx = np.array_split(feature_idx, n_subarrays)
feature_idx = [x.tolist() for x in feature_idx]

Expand All @@ -193,7 +203,9 @@ def run(self):
exit()

# write dead into file so we can create them as dead in Neuronpedia
skipped_indexes = set(range(self.n_features)) - set(self.target_feature_indexes)
skipped_indexes = set(range(self.n_features)) - set(
self.target_feature_indexes
)
skipped_indexes_json = json.dumps(
{
"model_id": self.model_id,
Expand All @@ -209,7 +221,9 @@ def run(self):
print(f"Total skipped: {len(skipped_indexes)}")
print(f"Total batches: {len(feature_idx)}")

print(f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}")
print(
f"Hook Point Layer: {self.sparse_autoencoder.cfg.hook_point_layer}"
)
print(f"Hook Point: {self.sparse_autoencoder.cfg.hook_point}")
print(f"Writing files to: {self.outputs_folder}")

Expand Down Expand Up @@ -237,7 +251,9 @@ def run(self):
for k, v in vocab_dict.items():
modified_key = k
for anomaly in HTML_ANOMALIES:
modified_key = modified_key.replace(anomaly, HTML_ANOMALIES[anomaly])
modified_key = modified_key.replace(
anomaly, HTML_ANOMALIES[anomaly]
)
new_vocab_dict[v] = modified_key
vocab_dict = new_vocab_dict

Expand All @@ -253,7 +269,10 @@ def run(self):
if feature_batch_count < self.start_batch:
# print(f"Skipping batch - it's after start_batch: {feature_batch_count}")
continue
if self.end_batch is not None and feature_batch_count > self.end_batch:
if (
self.end_batch is not None
and feature_batch_count > self.end_batch
):
# print(f"Skipping batch - it's after end_batch: {feature_batch_count}")
continue

Expand Down Expand Up @@ -294,13 +313,17 @@ def run(self):
)

features_outputs = []
for _, feat_index in enumerate(feature_data.feature_data_dict.keys()):
for _, feat_index in enumerate(
feature_data.feature_data_dict.keys()
):
feature = feature_data.feature_data_dict[feat_index]

feature_output = {}
feature_output["featureIndex"] = feat_index

top10_logits = self.round_list(feature.logits_table_data.top_logits)
top10_logits = self.round_list(
feature.logits_table_data.top_logits
)
bottom10_logits = self.round_list(
feature.logits_table_data.bottom_logits
)
Expand All @@ -309,29 +332,41 @@ def run(self):
feature_output["neuron_alignment_indices"] = (
feature.feature_tables_data.neuron_alignment_indices
)
feature_output["neuron_alignment_values"] = self.round_list(
feature.feature_tables_data.neuron_alignment_values
feature_output["neuron_alignment_values"] = (
self.round_list(
feature.feature_tables_data.neuron_alignment_values
)
)
feature_output["neuron_alignment_l1"] = self.round_list(
feature.feature_tables_data.neuron_alignment_l1
feature_output["neuron_alignment_l1"] = (
self.round_list(
feature.feature_tables_data.neuron_alignment_l1
)
)
feature_output["correlated_neurons_indices"] = (
feature.feature_tables_data.correlated_neurons_indices
)
feature_output["correlated_neurons_l1"] = self.round_list(
feature.feature_tables_data.correlated_neurons_cossim
feature_output["correlated_neurons_l1"] = (
self.round_list(
feature.feature_tables_data.correlated_neurons_cossim
)
)
feature_output["correlated_neurons_pearson"] = self.round_list(
feature.feature_tables_data.correlated_neurons_pearson
feature_output["correlated_neurons_pearson"] = (
self.round_list(
feature.feature_tables_data.correlated_neurons_pearson
)
)
feature_output["correlated_features_indices"] = (
feature.feature_tables_data.correlated_features_indices
)
feature_output["correlated_features_l1"] = self.round_list(
feature.feature_tables_data.correlated_features_cossim
feature_output["correlated_features_l1"] = (
self.round_list(
feature.feature_tables_data.correlated_features_cossim
)
)
feature_output["correlated_features_pearson"] = self.round_list(
feature.feature_tables_data.correlated_features_pearson
feature_output["correlated_features_pearson"] = (
self.round_list(
feature.feature_tables_data.correlated_features_pearson
)
)

feature_output["neg_str"] = self.to_str_tokens_safe(
Expand All @@ -345,28 +380,32 @@ def run(self):

feature_output["frac_nonzero"] = (
float(
feature.acts_histogram_data.title.split(" = ")[1].split(
"%"
)[0]
feature.acts_histogram_data.title.split(" = ")[
1
].split("%")[0]
)
/ 100
if feature.acts_histogram_data.title is not None
else 0
)

freq_hist_data = feature.acts_histogram_data
freq_bar_values = self.round_list(freq_hist_data.bar_values)
feature_output["freq_hist_data_bar_values"] = freq_bar_values
feature_output["freq_hist_data_bar_heights"] = self.round_list(
freq_hist_data.bar_heights
freq_bar_values = self.round_list(
freq_hist_data.bar_values
)
feature_output["freq_hist_data_bar_values"] = (
freq_bar_values
)
feature_output["freq_hist_data_bar_heights"] = (
self.round_list(freq_hist_data.bar_heights)
)

logits_hist_data = feature.logits_histogram_data
feature_output["logits_hist_data_bar_heights"] = self.round_list(
logits_hist_data.bar_heights
feature_output["logits_hist_data_bar_heights"] = (
self.round_list(logits_hist_data.bar_heights)
)
feature_output["logits_hist_data_bar_values"] = self.round_list(
logits_hist_data.bar_values
feature_output["logits_hist_data_bar_values"] = (
self.round_list(logits_hist_data.bar_values)
)

feature_output["num_tokens_for_dashboard"] = (
Expand Down Expand Up @@ -420,8 +459,12 @@ def run(self):
{"pos": posContribs, "neg": negContribs}
)
activation["tokens"] = strs
activation["values"] = self.round_list(sd.feat_acts)
activation["maxValue"] = max(activation["values"])
activation["values"] = self.round_list(
sd.feat_acts
)
activation["maxValue"] = max(
activation["values"]
)
activation["lossValues"] = self.round_list(
sd.loss_contribution
)
Expand Down
6 changes: 4 additions & 2 deletions tutorials/neuronpedia/upload_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ def nanToNeg999(obj: Any) -> Any:
return {k: nanToNeg999(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [nanToNeg999(v) for v in obj]
elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan(obj):
elif (isinstance(obj, float) or isinstance(obj, Decimal)) and math.isnan(
obj
):
return -999
return obj

Expand All @@ -29,7 +31,7 @@ def encode(self, o: Any, *args: Any, **kwargs: Any):
host = "http://localhost:3000"

# Upload alive features
for file_name in os.listdir(FEATURE_OUTPUTS_FOLDER):
for file_name in sorted(os.listdir(FEATURE_OUTPUTS_FOLDER)):
if file_name.startswith("batch-") and file_name.endswith(".json"):
print("Uploading file: " + file_name)
file_path = os.path.join(FEATURE_OUTPUTS_FOLDER, file_name)
Expand Down

0 comments on commit 1e3d53e

Please sign in to comment.