Fixed warmstarting and fine-tuning. Fixes #2
bshall committed Jun 28, 2023
1 parent 1896269 commit 4f87749
Showing 3 changed files with 19 additions and 10 deletions.
6 changes: 3 additions & 3 deletions hubert/model.py
@@ -10,9 +10,9 @@
 from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

 URLS = {
-    "hubert-discrete": "/~https://github.com/bshall/hubert/releases/download/v0.1/hubert-discrete-e9416457.pt",
-    "hubert-soft": "/~https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt",
-    "kmeans100": "/~https://github.com/bshall/hubert/releases/download/v0.1/kmeans100-50f36a95.pt",
+    "hubert-discrete": "/~https://github.com/bshall/hubert/releases/download/v0.2/hubert-discrete-96b248c5.pt",
+    "hubert-soft": "/~https://github.com/bshall/hubert/releases/download/v0.2/hubert-soft-35d9f29f.pt",
+    "kmeans100": "/~https://github.com/bshall/hubert/releases/download/v0.2/kmeans100-50f36a95.pt",
 }
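Note: the v0.2 checkpoints referenced above can be fetched the same way the repo does it internally, via torch.hub. A minimal sketch (assumes only torch and an internet connection; the file is cached under the torch.hub directory):

import torch

# Download and cache the v0.2 soft checkpoint listed in URLS
url = "/~https://github.com/bshall/hubert/releases/download/v0.2/hubert-soft-35d9f29f.pt"
checkpoint = torch.hub.load_state_dict_from_url(url, map_location="cpu")
print(list(checkpoint.keys()))  # inspect what the release file actually contains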
9 changes: 6 additions & 3 deletions hubert/utils.py
@@ -53,6 +53,9 @@ def load_checkpoint(
     logger.info(f"Loading checkpoint from {load_path}")
     checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"})
     hubert.load_state_dict(checkpoint["hubert"])
-    scaler.load_state_dict(checkpoint["scaler"])
-    optimizer.load_state_dict(checkpoint["optimizer"])
-    return checkpoint["step"], checkpoint["loss"]
+    if "scaler" in checkpoint:
+        scaler.load_state_dict(checkpoint["scaler"])
+    if "optimizer" in checkpoint:
+        optimizer.load_state_dict(checkpoint["optimizer"])
+    step, loss = checkpoint.get("step", 0), checkpoint.get("loss", float("inf"))
+    return step, loss
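Note: the guards matter when fine-tuning from a released checkpoint that stores only model weights; the old code raised a KeyError on the missing "scaler"/"optimizer" entries, whereas the new code falls back to a fresh step counter and loss. A self-contained sketch of the same pattern, with toy stand-ins for the real model, scaler, and optimizer:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters())
scaler = torch.cuda.amp.GradScaler(enabled=False)

# A released checkpoint typically carries only the model weights ...
released = {"hubert": model.state_dict()}
# ... while a checkpoint written during training also carries resume state.
resumable = {
    "hubert": model.state_dict(),
    "scaler": scaler.state_dict(),
    "optimizer": optimizer.state_dict(),
    "step": 1234,
    "loss": 0.42,
}

for checkpoint in (released, resumable):
    model.load_state_dict(checkpoint["hubert"])
    if "scaler" in checkpoint:            # guarded: no KeyError on weights-only files
        scaler.load_state_dict(checkpoint["scaler"])
    if "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])
    step = checkpoint.get("step", 0)
    loss = checkpoint.get("loss", float("inf"))
    print(step, loss)                     # 0 inf for the released file, 1234 0.42 otherwise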
14 changes: 10 additions & 4 deletions train.py
@@ -26,7 +26,7 @@
 # Define hyperparameters for training:
 ########################################################################################

-BATCH_SIZE = 64
+BATCH_SIZE = 32
 LEARNING_RATE = 2e-5
 BETAS = (0.9, 0.98)
 EPS = 1e-06
@@ -35,7 +35,7 @@
 STEPS = 25000
 LOG_INTERVAL = 5
 VALIDATION_INTERVAL = 1000
-CHECKPOINT_INTERVAL = 1000
+CHECKPOINT_INTERVAL = 5000
 BACKEND = "nccl"
 INIT_METHOD = "tcp://localhost:54321"
@@ -79,8 +79,14 @@ def train(rank, world_size, args):
         checkpoint = torch.hub.load_state_dict_from_url(
             URLS["hubert-discrete"], map_location={"cuda:0": f"cuda:{rank}"}
         )
-        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
-        hubert.load_state_dict(checkpoint, strict=False)
+        consume_prefix_in_state_dict_if_present(checkpoint["hubert"], "module.")
+
+        # don't use warmstart weights for label embeddings and proj layer
+        del checkpoint["hubert"]["label_embedding.weight"]
+        del checkpoint["hubert"]["proj.weight"]
+        del checkpoint["hubert"]["proj.bias"]
+
+        hubert.load_state_dict(checkpoint["hubert"], strict=False)

     hubert = DDP(hubert, device_ids=[rank])
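Note: the released hubert-discrete checkpoint stores its weights under a "hubert" key, and its final projection and label-embedding layers are sized for the original discrete pretraining targets, so they are deleted before loading and left at their fresh initialization for the new labels; that is why load_state_dict runs with strict=False. A toy sketch of the same trick (the class and label counts are hypothetical, only the key names match the commit):

import torch.nn as nn

class Toy(nn.Module):
    def __init__(self, n_labels):
        super().__init__()
        self.encoder = nn.Linear(8, 8)                    # backbone we want to warm-start
        self.proj = nn.Linear(8, n_labels)                # head whose shape depends on n_labels
        self.label_embedding = nn.Embedding(n_labels, 8)  # likewise label-dependent

state = Toy(n_labels=100).state_dict()       # pretrained weights, e.g. 100 k-means targets

# Drop the label-dependent weights, exactly as the commit does
for key in ("label_embedding.weight", "proj.weight", "proj.bias"):
    del state[key]

finetune = Toy(n_labels=500)                 # hypothetical new number of targets
missing, unexpected = finetune.load_state_dict(state, strict=False)
print(missing)                               # the three deleted head parameters
print(unexpected)                            # []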
