Highlights
We are excited to announce the 0.9.0 release of torchao! This release moves a number of sparsity techniques out of prototype, includes a significant overhaul of the quantize_ API, adds a new CUTLASS kernel for 4-bit dynamic quantization, and more!
Block Sparsity promoted out of prototype
We’ve promoted block sparsity out of torchao.prototype and made several performance improvements.
You can accelerate your models with block sparsity as follows:
from torchao.sparsity import sparsify_, block_sparse_weight
sparsify_(model, block_sparse_weight(blocksize=64))
Blocksparse Benchmarks
Technique | Decode (tok/s) | Model Size (GB) |
---|---|---|
baseline | 134.40 | 15.01 |
2:4 sparse | 163.13 | 10.08 |
bsr-0.8-32 | 210.91 | 6.01 |
bsr-0.8-64 | 222.43 | 6.00 |
bsr-0.9-32 | 255.19 | 4.88 |
bsr-0.9-64 | 262.94 | 4.88 |
2:4 sparse + int4wo (marlin) | 255.21 | 3.89 |
Block sparsity technique names (bsr-&lt;sparsity&gt;-&lt;blocksize&gt;) indicate the sparsity fraction and block size.
These numbers were generated on H100 using torchao/_models/llama/generate.py on the Meta-Llama-3.1-8B model. You can reproduce these numbers using this script
BC Breaking
TorchAO M1 Binaries currently not working
We've identified that the binaries are broken on M1 and have been since v0.8.0, though they were working in v0.7.0. We're working on a fix; details and discussion can be found here.
quantize_ configuration callables -> configs (#1595, #1694, #1696, #1697)
We are migrating the way quantize_ workflows are configured from callables (tensor subclass inserters) to direct configuration (config objects). Motivation: align with the rest of the ecosystem, enable inspection of configs after instantiation, and remove a common source of confusion.
What is changing:
Specifically, here is how the signature of quantize_'s second argument will change:
#
# torchao v0.8.0 and before
#
def quantize_(
model: torch.nn.Module,
apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
...,
): ...
#
# torchao v0.9.0
#
def quantize_(
model: torch.nn.Module,
config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
...,
): ...
#
# torchao v0.10.0 or later (exact version TBD)
#
def quantize_(
model: torch.nn.Module,
config: AOBaseConfig,
...,
): ...
- The name of the second argument to quantize_ changed from apply_tensor_subclass to config. Since the vast majority of callsites today pass the configuration as a positional argument, this change should not affect most people.
- The type of the second argument to quantize_ will change from Callable[[torch.nn.Module], torch.nn.Module] to AOBaseConfig, following a deprecation process detailed below.
- For individual workflows, the user-facing API name changed from snake case (int8_weight_only) to camel case (Int8WeightOnlyConfig). All argument names for each config are kept as-is. We will keep the old snake case names (int8_weight_only) around and alias them to the new names (int8_weight_only = Int8WeightOnlyConfig) to avoid breaking callsites. We plan to keep the old names forever. Here are all the workflow config name changes:
old name (will keep working) | new name (recommended) |
---|---|
int4_weight_only | Int4WeightOnlyConfig |
float8_dynamic_activation_float8_weight | Float8DynamicActivationFloat8WeightConfig |
float8_static_activation_float8_weight | Float8StaticActivationFloat8WeightConfig |
float8_weight_only | Float8WeightOnlyConfig |
fpx_weight_only | FPXWeightOnlyConfig |
gemlite_uintx_weight_only | GemliteUIntXWeightOnlyConfig |
int4_dynamic_activation_int4_weight | Int4DynamicActivationInt4WeightConfig |
int8_dynamic_activation_int4_weight | Int8DynamicActivationInt4WeightConfig |
int8_dynamic_activation_int8_semi_sparse_weight | n/a (deprecated) |
int8_dynamic_activation_int8_weight | Int8DynamicActivationInt8WeightConfig |
int8_weight_only | Int8WeightOnlyConfig |
uintx_weight_only | UIntXWeightOnlyConfig |
Configuration for prototype workflows using quantize_ will be migrated at a later time.
How these changes can affect you:
- If you are a user of existing quantize_ API workflows and are passing in the config as a positional argument (quantize_(model, int8_weight_only(group_size=128))), you are not affected. This positional syntax will keep working going forward. You are encouraged to migrate your callsite to the new config name (quantize_(model, Int8WeightOnlyConfig(group_size=128))), though the old names will continue to work indefinitely; see the example after this list.
- If you are a user of existing quantize_ API workflows and are passing in the config as a keyword argument (quantize_(model, tensor_subclass_inserter=int8_weight_only(group_size=128))), your callsite will break. You will need to change your callsite to quantize_(model, config=int8_weight_only(group_size=128)). We don't expect many people to be in this bucket.
- If you are a developer writing new workflows for the quantize_ API, you will need to use the new configuration system. Please see #1690 for details.
- If you are a user of sparsify_, you are not affected for now; a similar change will happen in a future version of torchao.
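As a concrete illustration, here is a minimal migration sketch for a typical callsite. The toy model is an illustrative assumption; the config names and the group_size argument come from the table and examples above.

import torch
from torchao.quantization import quantize_, int8_weight_only, Int8WeightOnlyConfig

# old syntax: callable-style configuration, passed positionally (keeps working)
model_old = torch.nn.Sequential(torch.nn.Linear(128, 128))
quantize_(model_old, int8_weight_only(group_size=128))

# new syntax (recommended): pass a config object directly
model_new = torch.nn.Sequential(torch.nn.Linear(128, 128))
quantize_(model_new, Int8WeightOnlyConfig(group_size=128))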
This migration will be a two step process:
- in torchao v0.9.0, we will enable the new syntax while starting the deprecation process for the old syntax.
- in torchao v0.10.0 or later, we will remove the old syntax.
Please see #1690 for more details.
Block sparsity imports after moving out of prototype (#1734)
Before:
from torchao.prototype.sparsity.superblock.blocksparse import block_sparse_weight
After:
from torchao.sparsity import block_sparse_weight
Deprecations
Deprecation of the set_inductor_config argument of quantize_ (#1716)
We are migrating the set_inductor_config argument of quantize_ to individual workflows. Motivation:
- This functionality was intended for inference, and we don't want to expose it to the future training workflows that we plan to add to quantize_.
- At a higher level, this flag couples torchao workflows with torch.compile, which is not ideal. We would rather keep these systems decoupled at the quantize_ API level, with individual workflows opting in as needed.
Impact on users
- For torchao v0.9.0: if you are passing set_inductor_config to quantize_, your callsite will keep working with a deprecation warning. We recommend that you migrate this option to your individual workflow.
- For a future version of torchao: the set_inductor_config argument will be removed from quantize_.
API changes
# torchao v0.8.x
def quantize_(
...,
set_inductor_config: bool = True,
...,
): ...
# torchao v0.9.0
def quantize_(
...,
set_inductor_config: Optional[bool] = None,
...,
):
# if set_inductor_config is not None, emit a deprecation warning
# if set_inductor_config is None, set it to True to stay consistent with old behavior
# torchao v TBD (a future release)
def quantize_(
...,
):
# set_inductor_config is removed from quantize_ and moved to relevant individual workflows
Please see #1715 for more details.
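As a rough sketch of what the migration can look like at a callsite (the toy model, the config, and the specific inductor option shown here are illustrative assumptions, not the exact settings the old flag applied):

import torch
import torch._inductor.config as inductor_config
from torchao.quantization import quantize_, Int8WeightOnlyConfig

model = torch.nn.Sequential(torch.nn.Linear(128, 128))

# torchao v0.8.x style: inductor tuning toggled through quantize_
# quantize_(model, Int8WeightOnlyConfig(), set_inductor_config=True)

# torchao v0.9.0+ style: drop the argument from quantize_; if you want the
# compile-time tuning, set the relevant inductor options yourself
inductor_config.coordinate_descent_tuning = True  # assumption: one example inductor knob
quantize_(model, Int8WeightOnlyConfig())
model = torch.compile(model)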
Deprecation warning for float8 training delayed and static scaling (#1681, #1680)
We plan to deprecate delayed and static scaling in the torchao.float8 training codebase due to the lack of real-world use cases (dynamic scaling is required for higher accuracy) and the complexity tax of supporting these features. Dynamic scaling remains the recommended path; a minimal example follows the list below.
- for torchao v0.9.0: add deprecation warning for delayed and static scaling
- for torchao v0.10.0: deprecate delayed and static scaling
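For reference, a minimal sketch of float8 training with the default dynamic scaling, which is unaffected by this deprecation (the toy model is an illustrative assumption):

import torch
from torchao.float8 import convert_to_float8_training

# dynamic scaling is the default and remains the recommended configuration
m = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64)).cuda()
convert_to_float8_training(m)
# training loop (not shown)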
New Features
Supermask for improving accuracy for sparse models (#1729)
Supermask (https://pytorch.org/blog/speeding-up-vits/) is a technique for improving the accuracy of block sparsified models by learning a block-sparse mask during a training phase.
from torchao.sparsity import sparsify_, SupermaskLinear, block_sparse_weight

sparsify_(model, lambda x: SupermaskLinear.from_linear(x, block_size=64, sparsity_level=0.9))

# training here

# collapse supermask into a normal linear layer (with many weights set to 0) and then convert to block sparse format for inference speedup
sparsify_(model, lambda x: SupermaskLinear.to_linear(x, sparsity_level=0.9))
sparsify_(model, block_sparse_weight(blocksize=64))
Dynamic quantization W4A4 CUTLASS-based kernel (#1515)
This kernel adds support for 4-bit dynamic activation + 4-bit weight quantization and can be used as follows:
from torchao.quantization import quantize_, int4_dynamic_activation_int4_weight
quantize_(model, int4_dynamic_activation_int4_weight())
Improvements
Early prototype MXFP8 and MXFP4 training and inference support for NVIDIA Blackwell GPUs
In torchao v0.9.0, we include very early support for MX training and inference on NVIDIA Blackwell GPUs, following the microscaling recipes from https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf and backed by real MX gemms.
Here is how to use the current prototype APIs.
MX training
import torch

from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_linear
from torchao.prototype.mx_formats.config import MXLinearConfig, MXGemmKernelChoice
from torchao.utils import is_sm_at_least_100
# early prototype: on MX-enabled hardware, you can use the real MX gemm backed by
# torchao's CUTLASS kernels. In the future, we will also add cuBLAS kernel support.
gemm_kernel_choice = MXGemmKernelChoice.EMULATED
if is_sm_at_least_100():
gemm_kernel_choice = MXGemmKernelChoice.CUTLASS
m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXLinearConfig(
elem_dtype=torch.float8_e4m3fn,
block_size=32,
gemm_kernel_choice=gemm_kernel_choice,
)
swap_linear_with_mx_linear(m, config=config)
# training loop (not shown)
MX inference (weights are stored in MX and the matmul runs in high precision):
import torch

from torchao.prototype.mx_formats.mx_linear import swap_linear_with_mx_inference_linear
from torchao.prototype.mx_formats.config import MXLinearConfig
m = torch.nn.Sequential(torch.nn.Linear(32, 32)).cuda()
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
swap_linear_with_mx_inference_linear(m, config=config)
# do inference (not shown)
The additional features for MX support in v0.9.0 were enabled by:
- Add mx_fp8_bf16 kernel (#1637)
- Support mixed MX element dtype in the mx_mm function and MXLinear (#1667)
- Move block_size and elem_dtype into MXLinearConfig (#1689)
- hook up mxfp8 and mxfp4 CUTLASS kernels to MXLinear (#1713)
- add ceil and RNE rounding modes to the cast from fp32 to e8m0 (#1643)
Experimental
SAM2
- Add modal script extensions (#1500)
- Increase export usage, small perf improvements (#1673)
- Model experiments QoL improvements (#1683)
- Collect p90 latency statistics (#1703)
Training
- Support power of 2 scaling factors in float8 training with rowwise scaling and use e4m3 in fwd and bwd pass (#1670)
- clean up recipe names in Float8 training (#1730)
- make the "config from recipe" API polished in Float8 training (#1731)
- Add workaround to reduce FSDP memory usage for float8 rowwise training (#1629)
- Make FakeQuantizer expose useful config details when printed (#1717)
Sparsity
- Promote blocksparse from prototype, make it faster (#1734)
Other
- Relax dtype requirements for int4 and float8 quants in autoquant (#1571)
- Update init.py to load experimental ops even if other C++ ops are not found (#1565)
Bug Fixes
- Fix torch.intx support in FakeQuantizeConfig (#1544)
- Fix float related autoquant options (#1562)
- Fix #1559, sparsity instead of sparstiy (#1560)
- Fix .item() issue in running parallel evaluation for BO mixed precision (#1630)
- Add more stringent test for CPUOffloadOptimizer (#1650)
- Fix LR scheduler issue with CPU offload optimizer (#1649)
- Add int8 dynamic activation + int8 weight only test to TensorParallel (#1657)
- Fix compile issue for Marlin qqq on sm<8.0 (#1651)
- Fix use_hqq for int4_weight_only quantize (#1707)
- Unbreak float8 static quant tutorial (/~https://github.com/pytorch/ao/pull/1709)
- Fix DDP with nf4 (#1684)
- Fix tensor parallelism for float8 training with rowwise scaling (#1718)
Documentation
- Update supported dtypes for fp8 (#1573)
- Sparsity docs update (#1590)
- Sparsity getting started docs (#1592)
- Fix broken link on doc page (#1582)
- Add quick start guide for first time users (#1611)
- Update api_ref_dtypes docs (#1610)
- Add module swap -> tensor subclass migration tutorial (#1596)
- Update docs to refer to version.html (#1631)
- Split contributor guide into quantization overview (#1618)
- Update api_ref_quantization docs (#1619)
- Migrate static quant tutorials to direct configuration (#1710)
- Update torchao READMEs with new configuration APIs (#1711)
- Update SAM2 README.md (#1735)
- Add rowwise scaling README.md entry for float8 training (#1733)
Developers
- Consolidate ZeroPointDomain.NONE & None zero point domains (#1556)
- Only run docs build in CI if docs have changed (#1589)
- Add separate quantization primitives for float8 (#1597)
- Add boiler plate code to Tensor subclass (#1663)
- Change TORCH_LIBRARY to TORCH_LIBRARY_FRAGMENT (#1645)
- Reformat C++ kernels (#1723)
- Add torchao/experimental CI test (#1586)
- Clean up linear_int8_dynamic_activation_intx_weight_subclass (#1553)
New Contributors
- @jaewoosong made their first contribution in #1560
- @haodongucsb made their first contribution in #1630
- @nikhil-arm made their first contribution in #1447
- @ngc92 made their first contribution in #1650
- @balancap made their first contribution in #1667
Full Changelog: v0.8.0...v0.9.0-rc1