diff --git a/hubert/model.py b/hubert/model.py
index f275b1c..97d9b66 100644
--- a/hubert/model.py
+++ b/hubert/model.py
@@ -10,9 +10,9 @@ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
 URLS = {
-    "hubert-discrete": "/~https://github.com/bshall/hubert/releases/download/v0.1/hubert-discrete-e9416457.pt",
-    "hubert-soft": "/~https://github.com/bshall/hubert/releases/download/v0.1/hubert-soft-0d54a1f4.pt",
-    "kmeans100": "/~https://github.com/bshall/hubert/releases/download/v0.1/kmeans100-50f36a95.pt",
+    "hubert-discrete": "/~https://github.com/bshall/hubert/releases/download/v0.2/hubert-discrete-96b248c5.pt",
+    "hubert-soft": "/~https://github.com/bshall/hubert/releases/download/v0.2/hubert-soft-35d9f29f.pt",
+    "kmeans100": "/~https://github.com/bshall/hubert/releases/download/v0.2/kmeans100-50f36a95.pt",
 }
diff --git a/hubert/utils.py b/hubert/utils.py
index 8519e7a..d42ba3a 100644
--- a/hubert/utils.py
+++ b/hubert/utils.py
@@ -53,6 +53,9 @@ def load_checkpoint(
     logger.info(f"Loading checkpoint from {load_path}")
     checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"})
     hubert.load_state_dict(checkpoint["hubert"])
-    scaler.load_state_dict(checkpoint["scaler"])
-    optimizer.load_state_dict(checkpoint["optimizer"])
-    return checkpoint["step"], checkpoint["loss"]
+    if "scaler" in checkpoint:
+        scaler.load_state_dict(checkpoint["scaler"])
+    if "optimizer" in checkpoint:
+        optimizer.load_state_dict(checkpoint["optimizer"])
+    step, loss = checkpoint.get("step", 0), checkpoint.get("loss", float("inf"))
+    return step, loss
diff --git a/train.py b/train.py
index fc01575..ff5ca9d 100644
--- a/train.py
+++ b/train.py
@@ -26,7 +26,7 @@
 # Define hyperparameters for training:
 ########################################################################################
-BATCH_SIZE = 64
+BATCH_SIZE = 32
 LEARNING_RATE = 2e-5
 BETAS = (0.9, 0.98)
 EPS = 1e-06
@@ -35,7 +35,7 @@
 STEPS = 25000
 LOG_INTERVAL = 5
 VALIDATION_INTERVAL = 1000
-CHECKPOINT_INTERVAL = 1000
+CHECKPOINT_INTERVAL = 5000
 BACKEND = "nccl"
 INIT_METHOD = "tcp://localhost:54321"
@@ -79,8 +79,14 @@ def train(rank, world_size, args):
     checkpoint = torch.hub.load_state_dict_from_url(
         URLS["hubert-discrete"], map_location={"cuda:0": f"cuda:{rank}"}
     )
-    consume_prefix_in_state_dict_if_present(checkpoint, "module.")
-    hubert.load_state_dict(checkpoint, strict=False)
+    consume_prefix_in_state_dict_if_present(checkpoint["hubert"], "module.")
+
+    # don't use warmstart weights for label embeddings and proj layer
+    del checkpoint["hubert"]["label_embedding.weight"]
+    del checkpoint["hubert"]["proj.weight"]
+    del checkpoint["hubert"]["proj.bias"]
+
+    hubert.load_state_dict(checkpoint["hubert"], strict=False)
     hubert = DDP(hubert, device_ids=[rank])
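
Note on the hubert/utils.py hunk: load_checkpoint previously indexed the
"scaler", "optimizer", "step", and "loss" keys unconditionally, so loading
one of the weights-only release checkpoints above raised a KeyError. The
patched version restores optimizer and scaler state only when present and
falls back to fresh-run defaults otherwise. A minimal standalone sketch of
the same pattern (the function name and the hubert/optimizer/scaler objects
are illustrative stand-ins, not the repo's exact API):

import torch

def load_checkpoint_tolerant(load_path, hubert, optimizer, scaler, rank=0):
    # hubert: nn.Module; optimizer: torch.optim.Optimizer;
    # scaler: torch.cuda.amp.GradScaler -- all assumed to exist already.
    checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"})
    hubert.load_state_dict(checkpoint["hubert"])
    # Optimizer/scaler state exists only in full training checkpoints,
    # so restore it conditionally rather than indexing unconditionally.
    if "scaler" in checkpoint:
        scaler.load_state_dict(checkpoint["scaler"])
    if "optimizer" in checkpoint:
        optimizer.load_state_dict(checkpoint["optimizer"])
    # Weights-only checkpoints resume from step 0 with an "infinite" best loss.
    return checkpoint.get("step", 0), checkpoint.get("loss", float("inf"))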
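
Note on the train.py warm-start hunk: the release checkpoints store the model
state dict under a "hubert" key, and the label embedding and projection head
are now dropped before loading so they re-initialize for the new label set;
strict=False then leaves those freshly initialized layers untouched. A sketch
of the same pattern outside the DDP training loop (the Hubert class and its
default constructor are assumed; URLS is the dict from hubert/model.py):

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
from hubert.model import Hubert, URLS  # Hubert constructor assumed

hubert = Hubert()  # assumed: fresh model with a randomly initialized head
checkpoint = torch.hub.load_state_dict_from_url(
    URLS["hubert-discrete"], map_location="cpu"
)
consume_prefix_in_state_dict_if_present(checkpoint["hubert"], "module.")

# Drop the task-specific head so the new label set starts from scratch.
for key in ("label_embedding.weight", "proj.weight", "proj.bias"):
    checkpoint["hubert"].pop(key)

# strict=False tolerates the deleted keys and reports them as missing.
missing, unexpected = hubert.load_state_dict(checkpoint["hubert"], strict=False)
print("re-initialized:", missing)  # expected: the three deleted keys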