[Feature] Add GPTQ and uniform interfaces #538

Merged · 26 commits · May 24, 2023
Changes from 1 commit
add docstring for gptq and sparse_gpt
humu789 committed May 19, 2023
commit e4203eecb456191a436aeb3a305bc81beade72e1
14 changes: 12 additions & 2 deletions mmrazor/implementations/pruning/sparse_gpt/compressor.py
@@ -8,6 +8,7 @@


def to_static_model(model: nn.Module):
"""Replace dynamicops with torch modules."""
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
fix_subnet = export_fix_subnet(model)[0]
@@ -16,8 +17,7 @@ def to_static_model(model: nn.Module):


class SparseGptCompressor():

# init
"""The compressor with SparseGPT."""

def __init__(self) -> None:
self.model: nn.Module = None
@@ -26,6 +26,7 @@ def prepare(self,
model: nn.Module,
prune_conv=True,
prune_linear=True) -> None:
"""Prepare for compressing model."""
self.model = model
prune_modules: dict = {}
if prune_conv:
@@ -36,19 +37,23 @@ def prepare(self,

@classmethod
def to_static_model(cls, model):
"""Convert replaced op with the original torch model."""
return to_static_model(model)

# hessian

def register_hessian_hooks(self):
"""Register updating hessian hooks for specified ops."""
for module in self.sparse_ops:
module.register_hessian_hook()

def remove_hessian_hooks(self):
"""Remove updating hessian hooks for specified ops."""
for module in self.sparse_ops:
module.remove_hessian_hook()

def init_hessian(self, device=None):
"""Init hessian."""
for op in self.sparse_ops:
op.init_hessian(device=device)

@@ -60,6 +65,7 @@ def prune(self,
blocksize=128,
percdamp=.01,
device=torch.device('cuda')):
"""Apply the compression algorithm to the model."""
for name, module in self.named_sparse_ops:
try:
original_device = next(module.parameters()).device
@@ -78,19 +84,23 @@ def prune(self,
print_log(f'prune {name} failed as {e}')

def prune_24(self, device=torch.device('cuda:0')):
"""Apply the compression algorithm to the model with the specified
setting."""
self.prune(0.5, prunen=2, prunem=4, device=device)

# ops

@property
def sparse_ops(self):
"""The ops to be applied the algorithm."""
assert self.model is not None
for module in self.model.modules():
if isinstance(module, SparseGptMixIn):
yield module

@property
def named_sparse_ops(self):
"""The named ops to be applied the algorithm."""
for name, module in self.model.named_modules():
if isinstance(module, SparseGptMixIn):
yield name, module
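
Taken together, the methods above form a prepare, calibrate, prune workflow. A minimal usage sketch follows; the resnet18 model and the random calibration batches are illustrative placeholders, not part of this diff.

import torch
import torchvision

from mmrazor.implementations.pruning.sparse_gpt.compressor import \
    SparseGptCompressor

model = torchvision.models.resnet18()  # placeholder model

compressor = SparseGptCompressor()
compressor.prepare(model, prune_conv=True, prune_linear=True)

# Accumulate Hessian statistics from a few calibration batches.
compressor.init_hessian()
compressor.register_hessian_hooks()
with torch.no_grad():
    for _ in range(4):  # hypothetical calibration loop
        model(torch.rand(1, 3, 224, 224))
compressor.remove_hessian_hooks()

# 2:4 structured sparsity (50%, n=2, m=4); the default device is cuda:0,
# so pass one explicitly when no GPU is available.
compressor.prune_24(device=torch.device('cpu'))

# Swap the dynamic ops back to plain torch modules.
model = SparseGptCompressor.to_static_model(model)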
16 changes: 14 additions & 2 deletions mmrazor/implementations/pruning/sparse_gpt/ops.py
@@ -17,10 +17,10 @@


class SparseGptMixIn(ModuleProtocol):

# init
"""The core algorithm implementation for SparseGpt."""

def _sparse_gpt_mix_in_init(self):
"""Init mixin."""
self.sparse_gpt_handles = []
self.rows = self.weight_matrix.shape[0]
self.columns = self.weight_matrix.shape[1]
@@ -37,6 +37,7 @@ def weight_matrix(self):

@weight_matrix.setter
def weight_matrix(self, value: torch.Tensor):
"""Set weight."""
with torch.no_grad():
value = value.reshape(self.weight.shape).to(self.weight.device).to(
self.weight.dtype)
@@ -69,6 +70,7 @@ def hessian(self):

@hessian.setter
def hessian(self, value: torch.Tensor):
"""Set hessian."""
with torch.no_grad():
if dist.is_initialized():
if dist.get_rank() == 0:
@@ -82,6 +84,7 @@ def hessian(self, value: torch.Tensor):

@torch.no_grad()
def update_hessian(self, input: torch.Tensor):
"""Update hessian."""
input = self.format_input(input).float()
H_save = self.hessian
H_save = H_save.to(input.device)
@@ -100,6 +103,7 @@ def update_hessian(self, input: torch.Tensor):
self.hessian_batch = self.hessian_batch + B

def register_hessian_hook(self):
"""Register updating hessian hook."""

@torch.no_grad()
def forward_pre_hook(module: Protocol, input: tuple):
@@ -110,10 +114,12 @@ def forward_pre_hook(module: Protocol, input: tuple):
self.sparse_gpt_handles.append(handle)

def remove_hessian_hook(self):
"""Remove updating hessian hook."""
for h in self.sparse_gpt_handles:
h.remove()

def init_hessian(self, device=None):
"""Init hessian."""
if dist.is_initialized():
if dist.get_rank() == 0:
self._hessian = torch.zeros([self.columns, self.columns],
@@ -130,6 +136,7 @@

@torch.no_grad()
def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):
"""The implementation for SparseGPT."""
with torch_setting(dtype=torch.float):
# Converted from /~https://github.com/ist-daslab/sparsegpt

@@ -224,13 +231,15 @@ def prune(self, sparsity, prunen=0, prunem=0, blocksize=128, percdamp=.01):


class SparseGptLinear(DynamicLinear, SparseGptMixIn):
"""Custom Linear for SparseGpt."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._sparse_gpt_mix_in_init()

@classmethod
def convert_from(cls, module: nn.Linear) -> 'DynamicConv2d':
"""Convert to cls from torch's module."""
if module.out_features < module.in_features:
return module
new_module = super().convert_from(module)
@@ -243,13 +252,15 @@ def convert_from(cls, module: nn.Linear) -> 'DynamicConv2d':


class SparseGptConv2d(DynamicConv2d, SparseGptMixIn):
"""Custom Conv2d for SparseGpt."""

def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self._sparse_gpt_mix_in_init()

@classmethod
def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
"""Convert to cls from torch's module."""
new_module = super().convert_from(module)
new_module.load_state_dict(module.state_dict(), strict=False)

@@ -259,6 +270,7 @@ def convert_from(cls, module: nn.Conv2d) -> 'DynamicConv2d':
return new_module

def format_input(self, input: torch.Tensor):
"""Format input shape."""
# input B C H W
input = F.unfold(
input, self.kernel_size, padding=self.padding,
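
The mixin makes each wrapped layer self-contained: it accumulates its own Hessian through a forward pre-hook and prunes itself. A single-layer sketch follows; the layer sizes are arbitrary, and note that convert_from returns nn.Linear modules unchanged when out_features < in_features.

import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt.ops import SparseGptLinear

linear = nn.Linear(16, 64)  # out_features >= in_features, so it is converted
sparse_linear = SparseGptLinear.convert_from(linear)

# The registered pre-hook feeds every input into update_hessian.
sparse_linear.init_hessian()
sparse_linear.register_hessian_hook()
with torch.no_grad():
    sparse_linear(torch.rand(8, 16))
sparse_linear.remove_hessian_hook()

# Unstructured 50% sparsity on this single layer.
sparse_linear.prune(sparsity=0.5)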
14 changes: 14 additions & 0 deletions mmrazor/implementations/pruning/sparse_gpt/utils.py
@@ -15,21 +15,27 @@


class ModuleProtocol(Protocol):
"""Custom module protocol for algorithm mixin."""
weight: torch.Tensor

def forward(self, x):
"""The abstract method."""
pass

def register_forward_hook(self, hook):
"""The abstract method."""
pass

def register_backward_hook(self, hook):
"""The abstract method."""
pass

def register_forward_pre_hook(self, hook):
"""The abstract method."""
pass

def register_buffer(self, name, tensor):
"""The abstract method."""
pass


@@ -53,6 +59,7 @@ def replace_op(model: nn.Module, name: str, module: nn.Module):

def register_efficient_forward_hook(module: nn.Module,
device=torch.device('cuda:0')):
"""Register efficient forward hook."""

def forward_pre_hook(module: nn.Module, input):
module.to(device)
@@ -69,6 +76,7 @@ def forward_hook(module: nn.Module, input, output):
def enable_efficient_forward(model: nn.Module,
device=torch.device('cuda:0'),
wrap_modules=[]):
"""Enable efficient forward."""
handles = []
blocks = []
for name, module in model.named_children():
@@ -85,6 +93,7 @@ def enable_efficient_forward(model: nn.Module,


class memory_efficient_forward:
"""The class for Memory efficient forward."""

def __init__(self,
model: nn.Module,
@@ -101,26 +110,31 @@ def __init__(self,
model.to(device)

def __enter__(self, ):
"""Enter."""
if self.enabled:
handles, blocks = enable_efficient_forward(self.model, self.device,
self.wrap_modules)
print_log(f'enable memory efficient forward for {blocks}')
self.handlers = handles

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit."""
for h in self.handlers:
h.remove()


class torch_setting():
"""Set the default torch dtype setting."""

def __init__(self, dtype=None) -> None:
self.original_dtype = torch.get_default_dtype()
self.dtype = dtype

def __enter__(self):
"""Enter."""
if self.dtype is not None:
torch.set_default_dtype(self.dtype)

def __exit__(self, exc_type, exc_value, exc_traceback):
"""Exit."""
torch.set_default_dtype(self.original_dtype)
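
memory_efficient_forward and torch_setting are both context managers. A brief sketch of how they compose; the hidden __init__ parameter names of memory_efficient_forward are assumptions inferred from the attributes used in __enter__ (model, device, wrap_modules, enabled).

import torch
import torch.nn as nn

from mmrazor.implementations.pruning.sparse_gpt.utils import (
    memory_efficient_forward, torch_setting)

model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))

# Run in float32 regardless of the ambient default dtype.
with torch_setting(dtype=torch.float):
    # Keep each wrapped block on the target device only during its forward.
    with memory_efficient_forward(
            model, device=torch.device('cpu'),
            wrap_modules=[nn.Linear], enabled=True):
        model(torch.rand(2, 8))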
14 changes: 12 additions & 2 deletions mmrazor/implementations/quantization/gptq/compressor.py
@@ -38,6 +38,7 @@ def replace_op(model: nn.Module, name: str, module: nn.Module):


def to_static_model(model: nn.Module):
"""Replace dynamicops with torch modules."""
from mmrazor.structures.subnet.fix_subnet import (export_fix_subnet,
load_fix_subnet)
fix_subnet = export_fix_subnet(model)[0]
@@ -46,8 +47,7 @@ def to_static_model(model: nn.Module):


class GPTQCompressor():

# init
"""The compressor with GPTQ."""

def __init__(self) -> None:
self.model: nn.Module = None
@@ -60,6 +60,7 @@ def prepare(self,
skipped_layers=[],
a_qconfig=None,
**kwargs) -> None:
"""Prepare for compressing model."""
self.model = model
quant_modules: dict = {}
if quant_conv:
@@ -72,19 +73,23 @@ def prepare(self,

@classmethod
def to_static_model(cls, model):
"""Convert replaced op with the original torch model."""
return to_static_model(model)

# hessian

def register_hessian_hooks(self):
"""Register updating hessian hooks for specified ops."""
for module in self.quant_ops:
module.register_hessian_hook()

def remove_hessian_hooks(self):
"""Remove updating hessian hooks for specified ops."""
for module in self.quant_ops:
module.remove_hessian_hook()

def init_hessian(self, device=None):
"""Init hessian."""
for op in self.quant_ops:
op.init_hessian(device=device)

@@ -96,6 +101,7 @@ def quant(self,
actorder=False,
device=torch.device('cuda:0'),
**qconfig):
"""Apply the compression algorithm to the model."""
for name, module in self.named_quant_ops:
try:
original_device = next(module.parameters()).device
@@ -115,6 +121,8 @@ def quant(self,
print_log(f'quant {name} failed as {e}')

def quant_with_default_qconfig(self, groupsize=128, device='cpu'):
"""Apply the compression algorithm to the model with the specified
setting."""
qconfig = dict(bits=4, perchannel=True, sym=False)
self.quant(
groupsize=groupsize, actorder=True, device=device, **qconfig)
@@ -123,13 +131,15 @@ def quant_with_default_qconfig(self, groupsize=128, device='cpu'):

@property
def quant_ops(self):
"""The ops to be applied the algorithm."""
assert self.model is not None
for module in self.model.modules():
if isinstance(module, GPTQMixIn):
yield module

@property
def named_quant_ops(self):
"""The named ops to be applied the algorithm."""
for name, module in self.model.named_modules():
if isinstance(module, GPTQMixIn):
yield name, module
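
The GPTQ compressor mirrors the SparseGPT workflow: prepare, calibrate the Hessians, then quantize. A minimal sketch with placeholder model and data; quant_conv and quant_linear are assumed keyword arguments of prepare, inferred from the collapsed parameter list and the quant_modules branches.

import torch
import torchvision

from mmrazor.implementations.quantization.gptq.compressor import GPTQCompressor

model = torchvision.models.resnet18()  # placeholder model

compressor = GPTQCompressor()
compressor.prepare(model, quant_conv=True, quant_linear=True)

# Accumulate Hessian statistics from a few calibration batches.
compressor.init_hessian()
compressor.register_hessian_hooks()
with torch.no_grad():
    for _ in range(4):  # hypothetical calibration loop
        model(torch.rand(1, 3, 224, 224))
compressor.remove_hessian_hooks()

# 4-bit, per-channel, asymmetric weights with groupsize 128 and actorder.
compressor.quant_with_default_qconfig(device='cpu')
model = GPTQCompressor.to_static_model(model)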
8 changes: 6 additions & 2 deletions mmrazor/implementations/quantization/gptq/custom_autotune.py
@@ -12,6 +12,7 @@


class Autotuner(triton.KernelInterface):
"""Autotuner."""

def __init__(self,
fn,
@@ -63,8 +64,8 @@ def _hook(args):
self.fn = fn

def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
"""Check for conflicts, i.e. meta-parameters both provided as kwargs
and by the autotuner."""
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
@@ -94,6 +95,7 @@ def kernel_call():
return (float('inf'), float('inf'), float('inf'))

def run(self, *args, **kwargs):
"""Run."""
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple(args[i] for i in self.key_idx)
@@ -132,6 +134,7 @@ def run(self, *args, **kwargs):
**config.kwargs)

def prune_configs(self, kwargs):
"""Prune configs."""
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
@@ -154,6 +157,7 @@ def prune_configs(self, kwargs):
return pruned_configs

def warmup(self, *args, **kwargs):
"""Warm up."""
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
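
This diff does not show how Autotuner is attached to a kernel; custom autotuners like this one are normally applied through a decorator factory mirroring triton.autotune. The sketch below assumes custom_autotune exposes such an autotune wrapper; the wrapper and the kernel are illustrations, not part of this commit.

import triton
import triton.language as tl

from mmrazor.implementations.quantization.gptq import custom_autotune

@custom_autotune.autotune(  # assumed wrapper around the Autotuner above
    configs=[
        triton.Config({'BLOCK_N': 64}, num_stages=2, num_warps=4),
        triton.Config({'BLOCK_N': 128}, num_stages=3, num_warps=8),
    ],
    key=['n_elements'],  # re-benchmark when this argument changes
)
@triton.jit
def double_kernel(x_ptr, out_ptr, n_elements, BLOCK_N: tl.constexpr):
    # Each program instance handles one BLOCK_N-sized slice of the input.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * 2.0, mask=mask)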