Fix deepspeed for single GPU (#187)
alan-cooney authored Jan 17, 2024
1 parent 6f67efb commit f55652e
Showing 4 changed files with 33 additions and 21 deletions.
39 changes: 21 additions & 18 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions sparse_autoencoder/train/pipeline.py
@@ -241,7 +241,7 @@ def train_autoencoder(
] = torch.zeros(
(self.n_components, self.n_learned_features),
dtype=torch.int64,
- device=autoencoder_device,
+ device=torch.device("cpu"),
)

for store_batch in activations_dataloader:
@@ -274,7 +274,7 @@ def train_autoencoder(
# Store count of how many neurons have fired
with torch.no_grad():
fired = learned_activations > 0
- learned_activations_fired_count.add_(fired.sum(dim=0))
+ learned_activations_fired_count.add_(fired.sum(dim=0).cpu())

# Backwards pass
total_loss.backward()
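The two pipeline.py hunks above move the neuron fired-count bookkeeping onto the CPU, so it no longer depends on which device DeepSpeed (or plain PyTorch) placed the autoencoder on. A minimal sketch of that pattern, with made-up sizes standing in for self.n_components and self.n_learned_features:

import torch

# Hypothetical sizes standing in for self.n_components / self.n_learned_features.
n_components, n_learned_features = 2, 8

# Keep the running fired-count on the CPU, independent of the model's device.
learned_activations_fired_count = torch.zeros(
    (n_components, n_learned_features),
    dtype=torch.int64,
    device=torch.device("cpu"),
)

# A stand-in batch of learned activations; in the pipeline this comes from the
# autoencoder forward pass and may live on a GPU.
learned_activations = torch.randn(4, n_components, n_learned_features)

with torch.no_grad():
    fired = learned_activations > 0
    # Move the per-batch sum to the CPU before accumulating, as in the diff.
    learned_activations_fired_count.add_(fired.sum(dim=0).cpu())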
3 changes: 3 additions & 0 deletions sparse_autoencoder/train/sweep.py
@@ -128,6 +128,9 @@ def setup_autoencoder_optimizer_scheduler(
model=model,
optimizer=optim,
lr_scheduler=lr_scheduler, # type: ignore
+ config={
+     "train_batch_size": hyperparameters["pipeline"]["train_batch_size"],
+ },
)

return (model_engine, optimizer_engine, scheduler) # type: ignore
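The sweep.py addition passes an inline config dict to deepspeed.initialize so that train_batch_size is always supplied, which DeepSpeed requires even in the single-GPU case. A self-contained sketch of the same call shape, using a toy model and a placeholder batch size rather than the repository's objects:

import deepspeed
import torch
from torch.nn import Linear
from torch.optim import Adam

model = Linear(16, 16)
optim = Adam(model.parameters(), lr=1e-3)

# Passing the config inline avoids needing a separate DeepSpeed JSON file just
# to satisfy the mandatory batch-size setting. Typically run via the
# `deepspeed` launcher so the distributed environment is set up.
model_engine, optimizer_engine, _dataloader, scheduler = deepspeed.initialize(
    model=model,
    optimizer=optim,
    config={"train_batch_size": 8},
)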
8 changes: 7 additions & 1 deletion sparse_autoencoder/train/utils/get_model_device.py
@@ -1,9 +1,11 @@
"""Get the device that the model is on."""
+ from deepspeed import DeepSpeedEngine
import torch
from torch.nn import Module
from torch.nn.parallel import DataParallel


- def get_model_device(model: Module) -> torch.device:
+ def get_model_device(model: Module | DataParallel | DeepSpeedEngine) -> torch.device:
"""Get the device on which a PyTorch model is on.
Args:
@@ -15,6 +17,10 @@ def get_model_device(model: Module) -> torch.device:
Raises:
ValueError: If the model has no parameters.
"""
+ # Deepspeed models already have a device property, so just return that
+ if hasattr(model, "device"):
+     return model.device

# Check if the model has parameters
if len(list(model.parameters())) == 0:
exception_message = "The model has no parameters."
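Read together, the get_model_device.py hunks make the helper short-circuit on the device attribute that DeepSpeed engines expose before falling back to inspecting parameters. A sketch of how the full function might look; the docstring body and the final fallback line are not shown in the diff and are assumptions:

"""Get the device that the model is on."""
from deepspeed import DeepSpeedEngine
import torch
from torch.nn import Module
from torch.nn.parallel import DataParallel


def get_model_device(model: Module | DataParallel | DeepSpeedEngine) -> torch.device:
    """Get the device that a PyTorch model is on."""
    # DeepSpeed engines already expose a `device` attribute, so just return it.
    if hasattr(model, "device"):
        return model.device

    # Check if the model has parameters.
    if len(list(model.parameters())) == 0:
        exception_message = "The model has no parameters."
        raise ValueError(exception_message)

    # Assumed fallback: report the device of the first parameter.
    return next(iter(model.parameters())).device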
