migrate static quant tutorials to direct configuration #1710

Merged Feb 14, 2025 · 50 commits
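For context: before this PR, the static quant tutorials configured quantize_ with a closure returned by a factory function (e.g. apply_awq(target_dtype)); after this PR they pass a dataclass deriving from AOBaseConfig, and the actual module transform is registered for that config type with register_quantize_module_handler. The following is a minimal sketch of that pattern, not code from this PR; MyStaticQuantConfig and _my_transform are hypothetical names used only for illustration.

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig
from torchao.quantization import quantize_
from torchao.quantization.transform_module import (
    register_quantize_module_handler,
)


@dataclass
class MyStaticQuantConfig(AOBaseConfig):
    # configuration is plain data, so it can be inspected and serialized
    target_dtype: torch.dtype = torch.uint8


@register_quantize_module_handler(MyStaticQuantConfig)
def _my_transform(module: torch.nn.Module, config: MyStaticQuantConfig) -> torch.nn.Module:
    # receives each matched module plus the config; a real handler would
    # swap in quantized weights here and return the converted module
    return module


model = torch.nn.Sequential(torch.nn.Linear(64, 64))
quantize_(model, MyStaticQuantConfig(), lambda mod, fqn: isinstance(mod, torch.nn.Linear))

The sketch mirrors the shape of the changes below: configuration becomes plain data that quantize_ can dispatch on, while the transform logic lives in a separately registered handler.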
1 change: 1 addition & 0 deletions torchao/dtypes/floatx/float8_layout.py
@@ -253,6 +253,7 @@ def _linear_fp8_act_fp8_weight_impl(
):
"""Implements matmul between FP8 input and FP8 weight with compute using _scaled_mm"""
scaled_mm_config = weight_tensor._layout.mm_config
assert scaled_mm_config is not None
out_shape = get_out_shape(input_tensor.shape, weight_tensor.shape)

# Weight tensor preprocessing
114 changes: 65 additions & 49 deletions tutorials/calibration_flow/awq_like.py
@@ -8,11 +8,13 @@
"""

import copy
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import Tensor

from torchao.core.config import AOBaseConfig
from torchao.dtypes import (
Float8Layout,
to_affine_quantized_floatx_static,
@@ -33,6 +35,9 @@
from torchao.quantization.quant_primitives import (
MappingType,
)
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.quantization.utils import compute_error


@@ -83,61 +88,72 @@ def replacement_fn(m):
_replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)


@dataclass
class ApplyAWQConfig(AOBaseConfig):
target_dtype: torch.dtype


# converting observed linear module to linear module with quantized weights (and quantized activations)
# with tensor subclasses
def apply_awq(target_dtype: torch.dtype):
# target_dtype = torch.uint8
def _apply_awq_to_linear(observed_linear):
# weight quantization
weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()

def weight_quant_func(weight):
block_size = (1, weight.shape[1])
if target_dtype == torch.uint8:
return to_affine_quantized_intx_static(
weight, weight_scale, weight_zero_point, block_size, target_dtype
)
elif target_dtype == torch.float8_e4m3fn:
return to_affine_quantized_floatx_static(
weight,
weight_scale,
block_size,
target_dtype,
Float8Layout(mm_config=None),
)
else:
raise ValueError(f"Unsupported target dtype {target_dtype}")

linear = torch.nn.Linear(
observed_linear.in_features,
observed_linear.out_features,
False,
device=observed_linear.weight.device,
dtype=observed_linear.weight.dtype,
)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias

# activation quantization
# pretend this is the equalization scale; in reality the `act_obs` should
# be an observer that can calculate the equalization scale
equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
equalization_scale = torch.ones_like(equalization_scale)

linear.weight = torch.nn.Parameter(
weight_quant_func(linear.weight * equalization_scale), requires_grad=False
)
@register_quantize_module_handler(ApplyAWQConfig)
def _apply_awq_transform(
module: torch.nn.Module,
config: ApplyAWQConfig,
):
target_dtype = config.target_dtype
observed_linear = module

linear.weight = torch.nn.Parameter(
to_weight_tensor_with_linear_activation_scale_metadata(
linear.weight, equalization_scale
),
requires_grad=False,
)
# target_dtype = torch.uint8
# weight quantization
weight_scale, weight_zero_point = observed_linear.weight_obs.calculate_qparams()

def weight_quant_func(weight):
block_size = (1, weight.shape[1])
if target_dtype == torch.uint8:
return to_affine_quantized_intx_static(
weight, weight_scale, weight_zero_point, block_size, target_dtype
)
elif target_dtype == torch.float8_e4m3fn:
return to_affine_quantized_floatx_static(
weight,
weight_scale,
block_size,
target_dtype,
Float8Layout(mm_config=None),
)
else:
raise ValueError(f"Unsupported target dtype {target_dtype}")

linear = torch.nn.Linear(
observed_linear.in_features,
observed_linear.out_features,
False,
device=observed_linear.weight.device,
dtype=observed_linear.weight.dtype,
)
linear.weight = observed_linear.weight
linear.bias = observed_linear.bias

# activation quantization
# pretend this is the equalization scale; in reality the `act_obs` should
# be an observer that can calculate the equalization scale
equalization_scale, _ = observed_linear.act_obs.calculate_qparams()
equalization_scale = torch.ones_like(equalization_scale)

return linear
linear.weight = torch.nn.Parameter(
weight_quant_func(linear.weight * equalization_scale), requires_grad=False
)

linear.weight = torch.nn.Parameter(
to_weight_tensor_with_linear_activation_scale_metadata(
linear.weight, equalization_scale
),
requires_grad=False,
)

return _apply_awq_to_linear
return linear


######## Test ##########
@@ -201,7 +217,7 @@ def test_awq(target_dtype: torch.dtype, mapping_type: MappingType):

# quantized linear represented as an nn.Linear with modified tensor subclass weights
# for both activation and weight quantization
quantize_(m, apply_awq(target_dtype), is_observed_linear)
quantize_(m, ApplyAWQConfig(target_dtype), is_observed_linear)
print("quantized model (applying tensor subclass to weight):", m)
after_quant = m(*example_inputs)
assert compute_error(before_quant, after_quant) > 25
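To make the dispatch in the new API explicit: quantize_ looks up the transform registered for type(config) and applies it to every module accepted by the filter, which is why ApplyAWQConfig(target_dtype) plus the registered _apply_awq_transform replaces the old apply_awq(target_dtype) closure. The snippet below is a simplified stand-in written for this write-up, not the torchao implementation; register_handler, quantize_sketch, and _HANDLERS are invented names.

from typing import Callable, Dict, Type

import torch

# toy registry: maps a config class to the module transform registered for it
_HANDLERS: Dict[Type, Callable[..., torch.nn.Module]] = {}


def register_handler(config_cls: Type):
    def decorator(fn):
        _HANDLERS[config_cls] = fn
        return fn

    return decorator


def quantize_sketch(model: torch.nn.Module, config, filter_fn) -> None:
    # walk the module tree; each child that passes the filter is replaced by
    # whatever the handler registered for type(config) returns
    # (simplified: the filter sees the child's local name, not its full path)
    for name, child in list(model.named_children()):
        if filter_fn(child, name):
            setattr(model, name, _HANDLERS[type(config)](child, config))
        else:
            quantize_sketch(child, config, filter_fn)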
66 changes: 38 additions & 28 deletions tutorials/calibration_flow/gptq_like.py
@@ -33,6 +33,7 @@
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

from torchao.core.config import AOBaseConfig
from torchao.dtypes import (
to_affine_quantized_intx,
to_affine_quantized_intx_static,
@@ -47,6 +48,9 @@
to_linear_activation_quantized,
)
from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
from torchao.quantization.transform_module import (
register_quantize_module_handler,
)
from torchao.quantization.utils import compute_error

torch.manual_seed(0)
@@ -252,36 +256,42 @@ def _register_forward_pre_hook(module: torch.nn.Module):
)


# using a function to align with the API in quant_api
def apply_activation_static_weight_quant():
def _apply_activation_static_weight_quant(observed_linear):
target_dtype = torch.uint8

# we can quantize the weight here as well
class ApplyActivationStaticWeightQuantConfig(AOBaseConfig):
pass

# activation quantization
act_scale, act_zero_point = (
observed_linear.input_scale,
observed_linear.input_zp,
)
input_quant_func = lambda x: to_affine_quantized_intx_static(
x, act_scale, act_zero_point, x.shape, target_dtype
)
# for demo purpose only, we quantize the weight here
weight = observed_linear.weight
weight = to_affine_quantized_intx(
weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
)
observed_linear.weight = torch.nn.Parameter(
to_linear_activation_quantized(weight, input_quant_func),
requires_grad=False,
)

del observed_linear.input_scale
del observed_linear.input_zp
return observed_linear
# using a function to align with the API in quant_api
@register_quantize_module_handler(ApplyActivationStaticWeightQuantConfig)
def _apply_activation_static_weight_quant_transform(
module: torch.nn.Module,
config: ApplyActivationStaticWeightQuantConfig,
):
observed_linear = module
target_dtype = torch.uint8

# we can quantize the weight here as well

# activation quantization
act_scale, act_zero_point = (
observed_linear.input_scale,
observed_linear.input_zp,
)
input_quant_func = lambda x: to_affine_quantized_intx_static(
x, act_scale, act_zero_point, x.shape, target_dtype
)
# for demo purpose only, we quantize the weight here
weight = observed_linear.weight
weight = to_affine_quantized_intx(
weight, MappingType.SYMMETRIC, (1, weight.shape[-1]), torch.int8
)
observed_linear.weight = torch.nn.Parameter(
to_linear_activation_quantized(weight, input_quant_func),
requires_grad=False,
)

return _apply_activation_static_weight_quant
del observed_linear.input_scale
del observed_linear.input_zp
return observed_linear


example_inputs = (torch.randn(32, 64),)
@@ -298,7 +308,7 @@ def _apply_activation_static_weight_quant(observed_linear):

# just quantizing the activation since we only observed the activation; this could be extended to support
# quantizing the weight as well
quantize_(m, apply_activation_static_weight_quant(), _is_linear)
quantize_(m, ApplyActivationStaticWeightQuantConfig(), _is_linear)
for l in m.modules():
if isinstance(l, torch.nn.Linear):
assert isinstance(l.weight, LinearActivationQuantizedTensor)
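One design note on the gptq_like change: ApplyActivationStaticWeightQuantConfig carries no fields (its body is just pass) because the transform hard-codes target_dtype = torch.uint8. If that were made configurable, the natural extension would be to add a field to the config, as in the hedged sketch below; the field is an assumption for illustration and is not part of this PR.

from dataclasses import dataclass

import torch

from torchao.core.config import AOBaseConfig


@dataclass
class ApplyActivationStaticWeightQuantConfig(AOBaseConfig):
    # hypothetical field, not in the PR: make the activation dtype configurable
    target_dtype: torch.dtype = torch.uint8

The registered transform would then read config.target_dtype instead of the hard-coded constant.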