diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index 82324394a8..9aeaa53664 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -31,6 +31,9 @@ from torchao.quantization.qat.embedding import (
     FakeQuantizedEmbedding,
 )
+from torchao.quantization.qat.fake_quantizer import (
+    FakeQuantizer,
+)
 from torchao.quantization.qat.linear import (
     FakeQuantizedLinear,
     Int4WeightOnlyQATLinear,
@@ -1348,6 +1351,21 @@ def test_fake_quantize_config_torch_intx(self):
         out2 = linear2(*x2)
         torch.testing.assert_close(out1, out2, atol=0, rtol=0)
 
+    @unittest.skipIf(
+        not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is 2.6 or lower"
+    )
+    def test_fake_quantizer_repr(self):
+        """
+        Test that `repr(FakeQuantizer(config))` exposes useful config details.
+        """
+        config = FakeQuantizeConfig(torch.int4, group_size=128)
+        fake_quantizer = FakeQuantizer(config)
+        fake_quantizer_repr = repr(fake_quantizer)
+        self.assertTrue("dtype=torch.int4" in fake_quantizer_repr)
+        self.assertTrue("group_size=128" in fake_quantizer_repr)
+        self.assertTrue("PerGroup" in fake_quantizer_repr)
+        self.assertTrue("MappingType.SYMMETRIC" in fake_quantizer_repr)
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/quantization/qat/fake_quantizer.py b/torchao/quantization/qat/fake_quantizer.py
index 15cd3aaca4..de747366a6 100644
--- a/torchao/quantization/qat/fake_quantizer.py
+++ b/torchao/quantization/qat/fake_quantizer.py
@@ -134,3 +134,9 @@ def _should_compute_qparams(self) -> bool:
         Return whether we need to compute new scales and zero points.
         """
         return self.config.is_dynamic or self.scale is None or self.zero_point is None
+
+    def __repr__(self) -> str:
+        """
+        Return a human readable representation of this `FakeQuantizer` with config details.
+        """
+        return "FakeQuantizer(%s)" % self.config
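
Usage sketch for the change above: once FakeQuantizer.__repr__ delegates to its config, printing the module surfaces the quantization settings directly. This is only an illustration, not part of the diff; the import path for FakeQuantizeConfig is assumed (the test file's existing import of it is outside this hunk), and the exact printed layout depends on FakeQuantizeConfig's own repr.

# Illustrative sketch; assumes torch >= 2.6 and that FakeQuantizeConfig is
# importable from torchao.quantization.qat (not shown in this diff).
import torch

from torchao.quantization.qat import FakeQuantizeConfig
from torchao.quantization.qat.fake_quantizer import FakeQuantizer

# Same config the new test constructs: int4 weights, per-group quantization.
config = FakeQuantizeConfig(torch.int4, group_size=128)
fake_quantizer = FakeQuantizer(config)

# With the new __repr__, printing the module yields roughly:
#   FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4,
#       granularity=PerGroup(group_size=128),
#       mapping_type=MappingType.SYMMETRIC, ...))
# which is what the assertions in test_fake_quantizer_repr check for.
print(fake_quantizer)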