Add new observer for KVCache and FP8 quantization (#1902)

lixcli authored Nov 20, 2024
1 parent fe6a87f commit b770c8b
Showing 6 changed files with 304 additions and 26 deletions.
72 changes: 72 additions & 0 deletions paddleslim/quant/layers/custom_attention.py
@@ -0,0 +1,72 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Custom Attention Layer for quantization.
"""
import paddle.tensor as tensor
from paddle.nn import Layer
from paddle.nn.quant.format import ConvertibleQuantedLayer


class QuantizedCustomAttentionLayer(ConvertibleQuantedLayer):
"""
Quantized Custom Attention Layer.
"""

def __init__(self, layer: Layer, q_config=None):
"""
        Initialize the QuantizedCustomAttentionLayer wrapper.
Args:
layer (Layer): The layer to be quantized.
q_config (QuantConfig, optional): The quantization configuration. Defaults to None.
"""
super().__init__()
        # Hard-coded: the K quanter is built from the weight slot of the config and the
        # V quanter from the activation slot, so both caches get their own observer.
        self.activation_quanter_k = q_config.weight._instance(layer)
        self.activation_quanter_v = q_config.activation._instance(layer)
self.layer = layer
self.quant_info = None
layer_name = self.layer.full_name()
self.layer_id = int(layer_name.split("_")[-1])
self.kv_losses = {}

def forward(self, q, config, k, v, attention_mask, output_attentions, **kwargs):
"""forward"""
perm = [0, 2, 1, 3] # [1, 2, 0, 3] if self.sequence_parallel else [0, 2, 1, 3]
tmp_k = tensor.transpose(x=k, perm=perm)
tmp_v = tensor.transpose(x=v, perm=perm)
if self.activation_quanter_k is not None:
tmp_k = self.activation_quanter_k(tmp_k)
if self.activation_quanter_v is not None:
tmp_v = self.activation_quanter_v(tmp_v)
k = tensor.transpose(x=tmp_k, perm=perm)
v = tensor.transpose(x=tmp_v, perm=perm)
return self.layer(
q,
config,
k,
v,
attention_mask,
output_attentions,
**kwargs,
)

    def weights_to_quanters(self):
        """Return (weight_name, quanter_name) pairs; this wrapper quantizes no weights."""
        return []

    def activation_quanters(self):
        """Return the attribute names of the K/V activation quanters."""
        return ["activation_quanter_k", "activation_quanter_v"]
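For orientation, here is a minimal calibration sketch (not part of the commit) showing how this wrapper can be combined with the new head-wise observers. DummyCoreAttention, the SimpleNamespace config, quant_axis=1, and the FP8 quant_bits=(4, 3) setting are illustrative assumptions about how a PTQ pipeline would wire things up:

import paddle
from types import SimpleNamespace
from paddleslim.quant.observers.avg_headwise import AvgHeadwiseObserver
from paddleslim.quant.layers.custom_attention import QuantizedCustomAttentionLayer


class DummyCoreAttention(paddle.nn.Layer):
    """Stand-in for the fused core-attention layer this wrapper normally wraps."""

    def forward(self, q, config, k, v, attention_mask, output_attentions, **kwargs):
        # A real layer would compute softmax(q @ k^T) @ v; returning v keeps the sketch short.
        return v


# The wrapper only needs a config object exposing .weight and .activation observer
# factories (see __init__ above); SimpleNamespace stands in for the per-layer quant
# config that a full pipeline would supply.
k_observer = AvgHeadwiseObserver(quant_bits=(4, 3), quant_axis=1, moving_avg=True)  # FP8 e4m3 range
v_observer = AvgHeadwiseObserver(quant_bits=(4, 3), quant_axis=1, moving_avg=True)
q_config = SimpleNamespace(weight=k_observer, activation=v_observer)

quant_attn = QuantizedCustomAttentionLayer(DummyCoreAttention(), q_config=q_config)
# Calibration forward passes through quant_attn make the K/V observers record
# per-head abs-max statistics; their scales() then yield one value per head.

Note that the weight slot is repurposed for the K cache (matching the hard-coded lookup in __init__), so no actual weight quantization happens in this wrapper.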
92 changes: 92 additions & 0 deletions paddleslim/quant/observers/abs_max_headwise.py
@@ -0,0 +1,92 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from .channel_wise import ChannelWiseObserver
from paddle.quantization.factory import ObserverFactory


class AbsMaxHeadwiseObserver(ObserverFactory):
    r"""
    It collects head-wise maximum absolute values of the observed tensor (e.g. the K/V cache).
    Args:
        quant_bits (int, optional): Number of bits used to represent a quantized value. Default is 8.
        quant_axis (int, optional): The axis along which per-head statistics are collected.
            If None, the axis is looked up from the layer type. Default is None.
    Examples:
        .. code-block:: python
            from paddle.quantization import QuantConfig
            from paddleslim.quant.observers.abs_max_headwise import AbsMaxHeadwiseObserver
            quanter = AbsMaxHeadwiseObserver()
            q_config = QuantConfig(activation=None, weight=quanter)
    """

def __init__(self, quant_bits=8, quant_axis=None):
super(AbsMaxHeadwiseObserver, self).__init__(quant_bits=quant_bits, quant_axis=quant_axis)

def _get_class(self):
return AbsMaxHeadwiseObserverLayer


class AbsMaxHeadwiseObserverLayer(ChannelWiseObserver):
def __init__(self, layer, quant_bits=8, quant_axis=None):
super(AbsMaxHeadwiseObserverLayer, self).__init__(
layer, quant_bits=quant_bits, sign=True, symmetric=True, quant_axis=quant_axis
)
self.quant_bits = quant_bits
self.calibration_loss = float("inf")
self.qmin, self.qmax = self.qmin_qmax
self._layer = layer
self._max = None
self._scale = None
self._zero_point = None

def forward(self, inputs):
self._max = self._cal_abs_max(inputs)
return inputs

def _cal_abs_max(self, inputs):
reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != self.quant_axis()])
abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis).cast("float32")
abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values)

if self._max is not None:
abs_max_values = paddle.maximum(abs_max_values, self._max)

return abs_max_values

def min_value(self) -> float:
return 0.0

def max_value(self) -> float:
return self._max

def cal_thresholds(self):
"""Compute thresholds for MAX function."""
self._scale = self._max
self._zero_point = paddle.zeros_like(self._scale)

def scales(self):
"""Return output scales."""
if self._scale is None:
self.cal_thresholds()
return self._scale

def zero_points(self):
"""Return output zero points."""
if self._zero_point is None:
self.cal_thresholds()
return self._zero_point
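For reference, the head-wise reduction performed by _cal_abs_max, restated as a standalone snippet (the [batch, heads, seq_len, head_dim] layout and quant_axis=1 are illustrative assumptions, not part of the commit):

import paddle

x = paddle.randn([2, 8, 16, 64])          # assumed layout: [batch, heads, seq_len, head_dim]
quant_axis = 1                            # the attention-head axis
reduce_axis = tuple(i for i in range(x.ndim) if i != quant_axis)
abs_max = paddle.max(paddle.abs(x), axis=reduce_axis).cast("float32")
abs_max = paddle.clip(abs_max, min=1e-8)  # avoid zero scales, like the where(...) guard above
print(abs_max.shape)                      # [8] -> one scale per attention head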
2 changes: 1 addition & 1 deletion paddleslim/quant/observers/avg.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
103 changes: 103 additions & 0 deletions paddleslim/quant/observers/avg_headwise.py
@@ -0,0 +1,103 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import paddle
from paddle.quantization.factory import ObserverFactory

from .abs_max_headwise import AbsMaxHeadwiseObserverLayer


class AvgHeadwiseObserver(ObserverFactory):
    r"""
    It collects head-wise maximum absolute values of the observed tensor (e.g. the K/V cache)
    and averages them across calibration steps, either as a running mean or as a moving average.
    Args:
        quant_bits (int, optional): Number of bits used to represent a quantized value. Default is 8.
        quant_axis (int, optional): The axis along which per-head statistics are collected.
            If None, the axis is looked up from the layer type. Default is None.
        moving_avg (bool, optional): Whether to use a moving average instead of a plain
            running mean. Default is False.
    Examples:
        .. code-block:: python
            from paddle.quantization import QuantConfig
            from paddleslim.quant.observers.avg_headwise import AvgHeadwiseObserver
            quanter = AvgHeadwiseObserver()
            q_config = QuantConfig(activation=None, weight=quanter)
    """

def __init__(self, quant_bits=8, quant_axis=None, moving_avg=False):
super(AvgHeadwiseObserver, self).__init__(quant_bits=quant_bits, quant_axis=quant_axis, moving_avg=moving_avg)

def _get_class(self):
return AvgHeadwiseObserverLayer


class AvgHeadwiseObserverLayer(AbsMaxHeadwiseObserverLayer):
def __init__(self, layer, quant_bits=8, quant_axis=None, moving_avg=True):
super(AvgHeadwiseObserverLayer, self).__init__(layer, quant_bits=quant_bits, quant_axis=quant_axis)
self.quant_bits = quant_bits
self._qmin, self._qmax = self.qmin_qmax
self._max = None
self._scale = None
self._zero_point = None
if quant_axis is not None:
self._channel_axis = quant_axis
self._current_iters = 0
self._range_update_factor_min = 0.001
self._moving_avg = moving_avg

def forward(self, inputs, quant_axis=None):
if quant_axis is not None:
self._channel_axis = quant_axis
self._max = self._cal_abs_max(inputs)
return inputs

def _cal_abs_max(self, inputs):
self._current_iters += 1
reduce_axis = tuple([i for i in range(len(inputs.shape)) if i != self.quant_axis()])
abs_max_values = paddle.max(paddle.abs(inputs), axis=reduce_axis).cast("float32")
abs_max_values = paddle.where(abs_max_values == np.float32(0.0), np.float32(1e-8), abs_max_values)
if self._max is not None:
if self._moving_avg:
# exponential moving average update
update_factor = 1.0 / self._current_iters
update_factor = max(update_factor, self._range_update_factor_min)
abs_max_values = self._max * (1 - update_factor) + abs_max_values * update_factor
else:
# normal average
abs_max_values = (self._max * (self._current_iters - 1) + abs_max_values) / self._current_iters
return abs_max_values

def min_value(self) -> float:
return 0.0

def max_value(self) -> float:
return self._max

def cal_thresholds(self):
"""Compute thresholds for MAX function."""
if self._scale is not None:
self._zero_point = paddle.zeros_like(self._scale)
return
self._scale = self._max
self._zero_point = paddle.zeros_like(self._scale)

def scales(self):
"""Return output scales."""
self.cal_thresholds()
return self._scale

def zero_points(self):
"""Return output zero points."""
self.cal_thresholds()
return self._zero_point
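Restated outside the class for clarity, the two accumulation modes implemented in _cal_abs_max behave like this standalone helper (a sketch, not part of the commit):

def update_running_max(prev_max, new_max, iters, moving_avg, min_factor=0.001):
    """Mirrors the update performed in AvgHeadwiseObserverLayer._cal_abs_max."""
    if prev_max is None:
        return new_max
    if moving_avg:
        # Moving average with a floor on the update factor, so late batches still contribute.
        factor = max(1.0 / iters, min_factor)
        return prev_max * (1 - factor) + new_max * factor
    # Plain running mean over all calibration batches seen so far.
    return (prev_max * (iters - 1) + new_max) / iters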
22 changes: 10 additions & 12 deletions paddleslim/quant/observers/channel_wise.py
@@ -28,25 +28,23 @@


class ChannelWiseObserver(UniformObserver):
-    def __init__(
-            self,
-            layer,
-            quant_bits=8,
-            sign=True,
-            symmetric=True, ):
+    def __init__(self, layer, quant_bits=8, sign=True, symmetric=True, quant_axis=None):
super(ChannelWiseObserver, self).__init__(
quant_bits=quant_bits,
sign=sign,
-            symmetric=symmetric, )
-        self._channel_axis = CHANNEL_AXIS[type(layer)]
+            symmetric=symmetric,
+        )
+        if quant_axis is not None:
+            self._channel_axis = quant_axis
+        else:
+            assert type(layer) in CHANNEL_AXIS, "Unsupported layer type: {}".format(type(layer))
+            self._channel_axis = CHANNEL_AXIS[type(layer)]
self._quant_bits = quant_bits

    def quant_axis(self):
-        """ Return quantization axis.
-        """
+        """Return quantization axis."""
return self._channel_axis

    def bit_length(self):
-        """ Return the bit length of quantized data.
-        """
+        """Return the bit length of quantized data."""
return self._quant_bits
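A small illustration (not part of the commit) of the new quant_axis escape hatch; the Silu stand-in layer is an arbitrary choice, since the lookup table is bypassed whenever the axis is given explicitly:

import paddle
from paddleslim.quant.observers.abs_max_headwise import AbsMaxHeadwiseObserverLayer

dummy = paddle.nn.Silu()                  # arbitrary stand-in layer
obs = AbsMaxHeadwiseObserverLayer(dummy, quant_bits=8, quant_axis=1)
assert obs.quant_axis() == 1              # explicit axis, so CHANNEL_AXIS is never consulted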
39 changes: 26 additions & 13 deletions paddleslim/quant/observers/uniform.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -19,7 +19,7 @@


class UniformObserver(BaseObserver):
-    """ This is the base class for a uniform quantization observer, which provides
+    """This is the base class for a uniform quantization observer, which provides
common functions for calculating the scale and zero-point used in uniform quantization.
Uniform quantization maps floating point values to integers, where the scale determines
the step size of the quantizer and the floating point zero is mapped to the zero-point,
@@ -31,14 +31,15 @@ class UniformObserver(BaseObserver):
        symmetric (bool): Whether to use symmetric quantization.
In symmetric quantization, the range of floating point values is relaxed to be symmetric
around zero and the zero-point is always 0.
"""

def __init__(
-            self,
-            quant_bits=8,
-            sign=True,
-            symmetric=True, ):
+        self,
+        quant_bits=8,
+        sign=True,
+        symmetric=True,
+    ):
super(UniformObserver, self).__init__()
self._quant_bits = quant_bits
self._sign = sign
@@ -54,14 +55,26 @@ def __init__(

@property
    def qmin_qmax(self):
-        """ Calculate the range of the quantized integer based on the specified
+        """Calculate the range of the quantized integer based on the specified
        quant_bits, sign, and symmetric properties."""
-        if self._sign:
-            self._qmin = -2**(self.bit_length() - 1)
-            self._qmax = 2**(self.bit_length() - 1) - 1
+        if isinstance(self._quant_bits, tuple):
+            if self._quant_bits[0] == 4 and self._quant_bits[1] == 3 and len(self._quant_bits) == 2:
+                self._qmin = -448.0
+                self._qmax = 448.0
+            elif self._quant_bits[0] == 5 and self._quant_bits[1] == 2 and len(self._quant_bits) == 2:
+                self._qmin = -57344.0
+                self._qmax = 57344.0
+            else:
+                raise NotImplementedError(
+                    "Currently, only float8_e4m3 and float8_e5m2 formats are supported. Please set quant_bits to (4,3) or (5,2) for the corresponding format."
+                )
-        else:
-            self._qmin = 0
-            self._qmax = 2**self.bit_length()
+        else:
+            if self._sign:
+                self._qmin = -(2 ** (self.bit_length() - 1))
+                self._qmax = 2 ** (self.bit_length() - 1) - 1
+            else:
+                self._qmin = 0
+                self._qmax = 2 ** self.bit_length()
return self._qmin, self._qmax

@abc.abstractmethod
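A quick sanity check on the hard-coded FP8 limits introduced above (not part of the commit): e4m3's largest finite value is (1 + 6/8) * 2^8 = 448 and e5m2's is (1 + 3/4) * 2^15 = 57344, which is exactly what qmin_qmax returns for quant_bits of (4, 3) and (5, 2):

def fp8_finite_max(exp_bits, man_bits):
    """Largest finite magnitude of the two OCP FP8 formats used above."""
    if (exp_bits, man_bits) == (4, 3):
        # e4m3 reserves only exponent=1111 with mantissa=111 for NaN, so the largest
        # finite value is 1.110_2 * 2**8 = 1.75 * 256.
        return (1 + 6 / 8) * 2 ** 8
    if (exp_bits, man_bits) == (5, 2):
        # e5m2 follows IEEE-style encoding; the largest finite value is 1.11_2 * 2**15.
        return (1 + 3 / 4) * 2 ** 15
    raise NotImplementedError("only (4, 3) and (5, 2) are handled, mirroring qmin_qmax")


assert fp8_finite_max(4, 3) == 448.0
assert fp8_finite_max(5, 2) == 57344.0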