Skip to content

Commit

Permalink
fix tsea typing
Browse files Browse the repository at this point in the history
  • Loading branch information
jbloom-md committed Mar 16, 2024
1 parent ed0b0ea commit 449d90f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions sae_analysis/tsea.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import re
import string
from typing import List, Optional
from typing import Optional

import nltk
import numpy as np
Expand All @@ -14,12 +14,12 @@

def get_enrichment_df(
projections: torch.Tensor,
features: List[int],
features: list[int],
gene_sets_selected: dict[str, set[int]],
):

gene_sets_token_ids_padded = pad_gene_sets(gene_sets_selected)
gene_sets_token_ids_tensor = torch.tensor(gene_sets_token_ids_padded.values())
gene_sets_token_ids_tensor = torch.tensor(list(gene_sets_token_ids_padded.values()))
enrichment_scores = calculate_batch_enrichment_scores(
projections[features], gene_sets_token_ids_tensor
)
Expand Down Expand Up @@ -163,7 +163,7 @@ def plot_top_k_feature_projections_by_token_and_category(
dec_projection_onto_W_U: torch.Tensor,
k: int = 5,
projection_onto: str = "W_U",
features: Optional[List[int]] = None,
features: Optional[list[int]] = None,
log_y: bool = True,
histnorm: Optional[str] = None,
):
Expand Down

0 comments on commit 449d90f

Please sign in to comment.