diff --git a/tests/test_lfq.py b/tests/test_lfq.py
new file mode 100644
index 0000000..0eb5903
--- /dev/null
+++ b/tests/test_lfq.py
@@ -0,0 +1,77 @@
+import torch
+import pytest
+from vector_quantize_pytorch import LFQ
+import math
+"""
+testing_strategy:
+subdivisions: using masks, using frac_per_sample_entropy < 1
+"""
+
+torch.manual_seed(0)
+
+@pytest.mark.parametrize('frac_per_sample_entropy', (1., 0.5))
+@pytest.mark.parametrize('mask', (torch.tensor([False, False]),
+                                  torch.tensor([True, False]),
+                                  torch.tensor([True, True])))
+def test_masked_lfq(
+    frac_per_sample_entropy,
+    mask
+):
+    # you can specify either dim or codebook_size
+    # if both specified, will be validated against each other
+
+    quantizer = LFQ(
+        codebook_size = 65536,      # codebook size, must be a power of 2
+        dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
+        entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
+        diversity_gamma = 1.,       # within entropy loss, how much weight to give to diversity
+        frac_per_sample_entropy = frac_per_sample_entropy
+    )
+
+    image_feats = torch.randn(2, 16, 32, 32)
+
+    ret, loss_breakdown = quantizer(image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature
+
+    quantized, indices, _ = ret
+    assert (quantized == quantizer.indices_to_codes(indices)).all()
+
+@pytest.mark.parametrize('frac_per_sample_entropy', (0.1,))
+@pytest.mark.parametrize('iters', (10,))
+@pytest.mark.parametrize('mask', (None, torch.tensor([True, False])))
+def test_lfq_bruteforce_frac_per_sample_entropy(frac_per_sample_entropy, iters, mask):
+    image_feats = torch.randn(2, 16, 32, 32)
+
+    full_per_sample_entropy_quantizer = LFQ(
+        codebook_size = 65536,      # codebook size, must be a power of 2
+        dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
+        entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
+        diversity_gamma = 1.,       # within entropy loss, how much weight to give to diversity
+        frac_per_sample_entropy = 1
+    )
+
+    partial_per_sample_entropy_quantizer = LFQ(
+        codebook_size = 65536,      # codebook size, must be a power of 2
+        dim = 16,                   # this is the input feature dimension, defaults to log2(codebook_size) if not defined
+        entropy_loss_weight = 0.1,  # how much weight to place on entropy loss
+        diversity_gamma = 1.,       # within entropy loss, how much weight to give to diversity
+        frac_per_sample_entropy = frac_per_sample_entropy
+    )
+
+    ret, loss_breakdown = full_per_sample_entropy_quantizer(
+        image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask)
+    true_per_sample_entropy = loss_breakdown.per_sample_entropy
+
+    per_sample_losses = torch.zeros(iters)
+    for iter in range(iters):
+        ret, loss_breakdown = partial_per_sample_entropy_quantizer(
+            image_feats, inv_temperature=100., return_loss_breakdown=True, mask=mask) # you may want to experiment with temperature
+
+        quantized, indices, _ = ret
+        assert (quantized == partial_per_sample_entropy_quantizer.indices_to_codes(indices)).all()
+        per_sample_losses[iter] = loss_breakdown.per_sample_entropy
+    # 95% confidence interval
+    assert abs(per_sample_losses.mean() - true_per_sample_entropy) \
+        < (1.96*(per_sample_losses.std() / math.sqrt(iters)))
+
+    print("difference: ", abs(per_sample_losses.mean() - true_per_sample_entropy))
+    print("std error:", (1.96*(per_sample_losses.std() / math.sqrt(iters))))
\ No newline at end of file
diff --git a/vector_quantize_pytorch/lookup_free_quantization.py b/vector_quantize_pytorch/lookup_free_quantization.py
index 0df700a..7c2d694 100644
--- a/vector_quantize_pytorch/lookup_free_quantization.py
+++ b/vector_quantize_pytorch/lookup_free_quantization.py
@@ -335,26 +335,34 @@ def forward(
 
         codebook = self.maybe_l2norm(codebook)
 
-        # the same as euclidean distance up to a constant
-        distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
-
-        prob = (-distance * inv_temperature).softmax(dim = -1)
-
-        # account for mask
-
-        if exists(mask):
-            prob = prob[mask]
-        else:
-            prob = rearrange(prob, 'b n ... -> (b n) ...')
-
         # whether to only use a fraction of probs, for reducing memory

         if self.frac_per_sample_entropy < 1.:
-            num_tokens = prob.shape[0]
+            # account for mask
+            if exists(mask):
+                original_input = original_input[mask]
+            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+
+            num_tokens = original_input.size(0)
             num_sampled_tokens = int(num_tokens * self.frac_per_sample_entropy)
             rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled_tokens
-            per_sample_probs = prob[rand_mask]
+
+            sampled_input = original_input[rand_mask]
+
+            sampled_distance = -2 * einsum('... i d, j d -> ... i j', sampled_input, codebook)
+
+            sampled_prob = (-sampled_distance * inv_temperature).softmax(dim = -1)
+
+            per_sample_probs = sampled_prob
         else:
+            if exists(mask):
+                original_input = original_input[mask]
+            original_input = rearrange(original_input, 'b n ... -> (b n) ...')
+            # the same as euclidean distance up to a constant
+            distance = -2 * einsum('... i d, j d -> ... i j', original_input, codebook)
+
+            prob = (-distance * inv_temperature).softmax(dim = -1)
+
             per_sample_probs = prob

         # calculate per sample entropy
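
A minimal standalone sketch (not part of the patch) of the idea the library change implements: instead of materializing the full (num_tokens, codebook_size) probability matrix and then subsampling it, distances and softmax probabilities are computed only for a randomly sampled fraction of tokens. The names below (subsampled_probs, tokens, codebook, frac) are illustrative assumptions, not the library's API.

import torch

def subsampled_probs(tokens, codebook, frac = 0.25, inv_temperature = 100.):
    # tokens: (num_tokens, dim), codebook: (codebook_size, dim)
    num_tokens = tokens.size(0)
    num_sampled = max(1, int(num_tokens * frac))

    # random subset of token rows, using the same argsort trick as the patch
    rand_mask = torch.randn(num_tokens).argsort(dim = -1) < num_sampled
    sampled = tokens[rand_mask]

    # same as squared euclidean distance up to a constant, then softmax over the codebook
    distance = -2 * sampled @ codebook.t()
    return (-distance * inv_temperature).softmax(dim = -1)

# e.g. probabilities for ~10% of 1024 tokens against a 256-entry codebook
probs = subsampled_probs(torch.randn(1024, 16), torch.randn(256, 16), frac = 0.1)

The test added above then checks that, averaged over several iterations, this subsampled estimate of the per-sample entropy stays within a 95% confidence interval of the value computed over all tokens.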