Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support lsq new #501

Merged
merged 13 commits into from
Apr 11, 2023
Prev Previous commit
Next Next commit
add _load_from_state_dict to lsq fakequant
  • Loading branch information
HIT-cwh committed Apr 11, 2023
commit bb221e9beffd35cefdaeba852f76b00a2ff4af72
40 changes: 40 additions & 0 deletions mmrazor/models/fake_quants/lsq.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,46 @@ def forward(self, X):

return X

def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Removing this function throws an error that the the size of the
loaded tensor does not match the original size i.e., These buffers
start out with numel 0 and become numel 1 once they have their first
forward pass.

Modified from /~https://github.com/pytorch/pytorch/blob/master/torch/ao/quantization/fake_quantize.py # noqa:E501
"""
local_state = ['scale', 'zero_point']
for name in local_state:
key = prefix + name
if key in state_dict:
val = state_dict[key]
# Custom handling to allow loading scale and zero_point
# of size N into uninitialized buffers of size 0. The
# buffers are resized here, and the values are copied in
# the default state_dict loading code of the parent.
if name == 'scale':
self.scale.data = self.scale.data.resize_(val.shape)
else:
assert name == 'zero_point'
self.zero_point.data = self.zero_point.data.resize_(
val.shape)
# For torchscript module we need to update the attributes here
# since we do not call the `_load_from_state_dict` function
# defined module.py
if torch.jit.is_scripting():
if name == 'scale':
self.scale.copy_(val)
else:
assert name == 'zero_point'
self.zero_point.copy_(val)
elif strict:
missing_keys.append(key)
super(LearnableFakeQuantize,
self)._load_from_state_dict(state_dict, prefix, local_metadata,
strict, missing_keys,
unexpected_keys, error_msgs)

@torch.jit.export
def extra_repr(self):
"""The printable representational string."""
Expand Down