flake8 and black
jbloom-md committed Feb 28, 2024
1 parent 8e41e59 commit ed8345a
Showing 16 changed files with 282 additions and 319 deletions.
2 changes: 1 addition & 1 deletion .flake8
@@ -5,4 +5,4 @@ max-complexity = 25
 extend-select = E9, F63, F7, F82
 show-source = true
 statistics = true
-exclude = ./sae_training/geom_median/
+exclude = ./sae_training/geom_median/, ./wandb/*, ./research/wandb/*
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
@@ -7,19 +7,19 @@ repos:
       - id: check-added-large-files
         args: [--maxkb=250000]
   - repo: /~https://github.com/psf/black
-    rev: 23.3.0
+    rev: 24.2.0
     hooks:
       - id: black
   - repo: /~https://github.com/PyCQA/flake8
     rev: 6.0.0
     hooks:
       - id: flake8
         args: ['--config=.flake8']
         additional_dependencies: [
             'flake8-blind-except',
-            'flake8-docstrings',
+            # 'flake8-docstrings',
             'flake8-bugbear',
             'flake8-comprehensions',
             'flake8-docstrings',
             'flake8-implicit-str-concat',
             'pydocstyle>=5.0.0',
         ]
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -18,6 +18,7 @@ ipykernel = "^6.29.2"
 matplotlib = "^3.8.3"
 matplotlib-inline = "^0.1.6"
 eindex = {git = "/~https://github.com/callummcdougall/eindex.git"}
+datasets = "^2.17.1"
 
 
 [tool.poetry.group.dev.dependencies]
@@ -33,4 +34,4 @@ profile = "black"
 
 [build-system]
 requires = ["poetry-core"]
-build-backend = "poetry.core.masonry.api"
\ No newline at end of file
+build-backend = "poetry.core.masonry.api"
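
The new datasets entry pins the Hugging Face datasets library. A minimal usage sketch follows; the dataset chosen here is just a small public example, not necessarily one this repo loads:

    # Minimal sketch of the newly added datasets dependency (illustrative dataset).
    from datasets import load_dataset

    dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    print(dataset[0]["text"][:100])  # peek at the first example's raw text
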
2 changes: 1 addition & 1 deletion sae_analysis/dashboard_runner.py
@@ -17,10 +17,10 @@
 import plotly
 import plotly.express as px
 import torch
-import wandb
 from torch.nn.functional import cosine_similarity
 from tqdm import tqdm
 
+import wandb
 from sae_analysis.visualizer.data_fns import get_feature_data
 from sae_training.utils import LMSparseAutoencoderSessionloader

11 changes: 7 additions & 4 deletions sae_training/config.py
@@ -3,6 +3,7 @@
 from typing import Optional
 
 import torch
+
 import wandb
 
 
@@ -21,9 +22,9 @@ class RunnerConfig(ABC):
     is_dataset_tokenized: bool = True
     context_size: int = 128
     use_cached_activations: bool = False
-    cached_activations_path: Optional[
-        str
-    ] = None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"
+    cached_activations_path: Optional[str] = (
+        None  # Defaults to "activations/{dataset}/{model}/{full_hook_name}_{hook_point_head_index}"
+    )
 
     # SAE Parameters
     d_in: int = 512
@@ -61,7 +62,9 @@ class LanguageModelSAERunnerConfig(RunnerConfig):
     l1_coefficient: float = 1e-3
     lp_norm: float = 1
     lr: float = 3e-4
-    lr_scheduler_name: str = "constantwithwarmup"  # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
+    lr_scheduler_name: str = (
+        "constantwithwarmup"  # constant, constantwithwarmup, linearwarmupdecay, cosineannealing, cosineannealingwarmup
+    )
     lr_warm_up_steps: int = 500
     train_batch_size: int = 4096

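
For reference, the fields touched in this file live on the runner config dataclasses. A hypothetical sketch of overriding a few of them (the field names come from the hunks above; other settings such as the model name and hook point are omitted here and may be required in practice):

    # Hypothetical sketch: overriding config fields shown in this diff.
    from sae_training.config import LanguageModelSAERunnerConfig

    cfg = LanguageModelSAERunnerConfig(
        d_in=512,
        context_size=128,
        l1_coefficient=1e-3,
        lr=3e-4,
        lr_scheduler_name="constantwithwarmup",  # or linearwarmupdecay, cosineannealing, ...
        train_batch_size=4096,
    )
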
2 changes: 1 addition & 1 deletion sae_training/evals.py
@@ -2,11 +2,11 @@
 
 import pandas as pd
 import torch
-import wandb
 from tqdm import tqdm
 from transformer_lens import HookedTransformer
 from transformer_lens.utils import get_act_name
 
+import wandb
 from sae_training.activations_store import ActivationsStore
 from sae_training.sparse_autoencoder import SparseAutoencoder

6 changes: 4 additions & 2 deletions sae_training/sae_group.py
@@ -1,10 +1,12 @@
+import dataclasses
 import gzip
 import os
 import pickle
-import dataclasses
-from sae_training.sparse_autoencoder import SparseAutoencoder
 
 import torch
 
+from sae_training.sparse_autoencoder import SparseAutoencoder
+
+
 class SAEGroup:
     def __init__(self, cfg):
2 changes: 1 addition & 1 deletion sae_training/toy_model_runner.py
@@ -2,8 +2,8 @@
 
 import einops
 import torch
-import wandb
 
+import wandb
 from sae_training.sparse_autoencoder import SparseAutoencoder
 from sae_training.toy_models import Config as ToyConfig
 from sae_training.toy_models import Model as ToyModel
4 changes: 2 additions & 2 deletions sae_training/train_sae_on_language_model.py
@@ -1,14 +1,14 @@
 import torch
-import wandb
 from torch.optim import Adam
 from tqdm import tqdm
 from transformer_lens import HookedTransformer
 
+import wandb
 from sae_training.activations_store import ActivationsStore
 from sae_training.evals import run_evals
+from sae_training.geom_median.src.geom_median.torch import compute_geometric_median
 from sae_training.optim import get_scheduler
 from sae_training.sae_group import SAEGroup
-from sae_training.geom_median.src.geom_median.torch import compute_geometric_median
 
 
 def train_sae_on_language_model(
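
Several files in this commit move import wandb into its own import group; the training code logs to wandb using the standard pattern, roughly as sketched below (a generic example with placeholder project and metric names, not this repo's actual calls):

    # Generic wandb logging sketch (placeholder project/metric names).
    import wandb

    wandb.init(project="sae-training-example", config={"lr": 3e-4})
    for step in range(3):
        wandb.log({"loss": 1.0 / (step + 1)}, step=step)
    wandb.finish()
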
2 changes: 1 addition & 1 deletion sae_training/train_sae_on_toy_model.py
@@ -1,8 +1,8 @@
 import torch
-import wandb
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
+import wandb
 from sae_training.sparse_autoencoder import SparseAutoencoder


2 changes: 1 addition & 1 deletion scripts/generate_dashboards.py
@@ -16,10 +16,10 @@
 import plotly
 import plotly.express as px
 import torch
-import wandb
 from torch.nn.functional import cosine_similarity
 from tqdm import tqdm
 
+import wandb
 from sae_analysis.visualizer.data_fns import get_feature_data
 from sae_training.utils import LMSparseAutoencoderSessionloader

