Skip to content

Commit

Permalink
address #135 again
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 16, 2024
1 parent 78b36f4 commit ccf4068
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "vector-quantize-pytorch"
version = "1.14.25"
version = "1.14.26"
description = "Vector Quantization - Pytorch"
authors = [
{ name = "Phil Wang", email = "lucidrains@gmail.com" }
Expand Down
10 changes: 8 additions & 2 deletions vector_quantize_pytorch/residual_vq.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import random
from math import ceil
from functools import partial
from itertools import zip_longest
from typing import List

import torch
from torch import nn
from torch import nn, Tensor
import torch.nn.functional as F
from vector_quantize_pytorch.vector_quantize_pytorch import VectorQuantize

Expand Down Expand Up @@ -122,7 +125,7 @@ def forward(
self,
x,
mask = None,
indices = None,
indices: Tensor | List[Tensor] | None = None,
return_all_codes = False,
sample_codebook_temp = None,
freeze_codebook = False,
Expand All @@ -140,6 +143,9 @@ def forward(
all_losses = []
all_indices = []

if isinstance(indices, list):
indices = torch.stack(indices)

if return_loss:
assert not torch.any(indices == -1), 'some of the residual vq indices were dropped out. please use indices derived when the module is in eval mode to derive cross entropy loss'
ce_losses = []
Expand Down

0 comments on commit ccf4068

Please sign in to comment.