Make FakeQuantizer expose useful config details
**Summary:** Expose useful config details when printing `FakeQuantizer`, which appears when printing QAT-prepared models containing linear layers.

Before:
```
>>> print(prepared_model.layers[0].attn.qproj)
FakeQuantizedLinear(
  in_features=4096, out_features=4096, bias=False
  (activation_fake_quantizer): FakeQuantizer()
  (weight_fake_quantizer): FakeQuantizer()
)
```

After:

```
>>> print(prepared_model.layers[0].attn.qproj)
FakeQuantizedLinear(
  in_features=4096, out_features=4096, bias=False
  (activation_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int8, granularity=PerToken(), mapping_type=<MappingType.ASYMMETRIC: 3>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
  (weight_fake_quantizer): FakeQuantizer(FakeQuantizeConfig(dtype=torch.int4, granularity=PerGroup(group_size=32), mapping_type=<MappingType.SYMMETRIC: 1>, scale_precision=torch.float32, zero_point_precision=torch.int32, zero_point_domain=<ZeroPointDomain.INT: 1>, is_dynamic=True, range_learning=False))
)
```
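
For context, here is a minimal sketch of how a prepared model like the one above might be produced. It assumes torchao's `quantize_` + `intx_quantization_aware_training` QAT flow and a toy single-layer model, neither of which is part of this commit:

```
import torch

from torchao.quantization import quantize_
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    intx_quantization_aware_training,
)

# Toy stand-in for a real transformer; any module containing nn.Linear works
model = torch.nn.Sequential(torch.nn.Linear(4096, 4096, bias=False))

# int8 dynamic per-token activations and int4 symmetric per-group weights,
# matching the configs shown in the "After" output above
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(model, intx_quantization_aware_training(activation_config, weight_config))

# Each swapped-in FakeQuantizedLinear now prints its full fake-quantize configs
print(model[0])
```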

**Test Plan:**
python test/quantization/test_qat.py -k test_fake_quantizer_repr
andrewor14 committed Feb 14, 2025
1 parent 12e830b commit b304c63
Showing 2 changed files with 24 additions and 0 deletions.
18 changes: 18 additions & 0 deletions test/quantization/test_qat.py
@@ -31,6 +31,9 @@
from torchao.quantization.qat.embedding import (
FakeQuantizedEmbedding,
)
from torchao.quantization.qat.fake_quantizer import (
FakeQuantizer,
)
from torchao.quantization.qat.linear import (
FakeQuantizedLinear,
Int4WeightOnlyQATLinear,
@@ -1348,6 +1351,21 @@ def test_fake_quantize_config_torch_intx(self):
        out2 = linear2(*x2)
        torch.testing.assert_close(out1, out2, atol=0, rtol=0)

    @unittest.skipIf(
        not TORCH_VERSION_AT_LEAST_2_6, "skipping when torch version is below 2.6"
    )
    def test_fake_quantizer_repr(self):
        """
        Test that `repr(FakeQuantizer(config))` exposes useful config details.
        """
        config = FakeQuantizeConfig(torch.int4, group_size=128)
        fake_quantizer = FakeQuantizer(config)
        fake_quantizer_repr = repr(fake_quantizer)
        self.assertTrue("dtype=torch.int4" in fake_quantizer_repr)
        self.assertTrue("group_size=128" in fake_quantizer_repr)
        self.assertTrue("PerGroup" in fake_quantizer_repr)
        self.assertTrue("MappingType.SYMMETRIC" in fake_quantizer_repr)


if __name__ == "__main__":
    unittest.main()
6 changes: 6 additions & 0 deletions torchao/quantization/qat/fake_quantizer.py
@@ -134,3 +134,9 @@ def _should_compute_qparams(self) -> bool:
        Return whether we need to compute new scales and zero points.
        """
        return self.config.is_dynamic or self.scale is None or self.zero_point is None

    def __repr__(self) -> str:
        """
        Return a human-readable representation of this `FakeQuantizer` with config details.
        """
        return "FakeQuantizer(%s)" % self.config
