[Feature] Add GPTQ and uniform interfaces #538

Merged: 26 commits merged on May 24, 2023
Changes from 1 commit
align acc & add save load ckpt & add ut
humu789 committed May 23, 2023
commit f71040e13dcb703b3d50255d2ab7769df6367051
1 change: 1 addition & 0 deletions mmrazor/implementations/quantization/gptq/compressor.py
@@ -108,6 +108,7 @@ def quant(self,
module: GPTQMixIn = module.to(device)
quantizer = Quantizer()
quantizer.configure(**qconfig)
# print_log(f'quant {name}...')
error = module.quant(
quantizer=quantizer,
blocksize=blocksize,
13 changes: 6 additions & 7 deletions mmrazor/implementations/quantization/gptq/gptq.py
@@ -173,7 +173,7 @@ def pack(self, scales, zeros, g_idx=None):
self.scales[self.g_idx[idx]]).to(torch.int)[:, None])
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
intweight = intweight.cpu().numpy().astype(np.uint32)
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]),
dtype=np.uint32)
@@ -189,10 +189,10 @@ def pack(self, scales, zeros, g_idx=None):
raise NotImplementedError('Only 2,4,8 bits are supported.')

qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
self.qweight = torch.from_numpy(qweight).to(self.weight.device)

zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
zeros = zeros.cpu().numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits),
dtype=np.uint32)
i = 0
@@ -207,7 +207,7 @@ def pack(self, scales, zeros, g_idx=None):
raise NotImplementedError('Only 2,4,8 bits are supported.')

qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
self.qzeros = torch.from_numpy(qzeros).to(self.weight.device)

@torch.no_grad()
def quant(self,
@@ -298,7 +298,6 @@ def quant(self,
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
invperm = torch.argsort(perm)
W = W[:, invperm]
Q = Q[:, invperm]
g_idx = g_idx[invperm]

@@ -307,10 +306,10 @@ def quant(self,
zero.append(quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
self.weight_matrix = Q.data
self.weight_matrix = Q.data.to(self.weight_matrix.dtype)
if self.is_custom_kernel:
self.pack(scale, zero, g_idx)

del self.weight
return error

def free(self):
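
The .cpu() / .to(self.weight.device) round trips added to pack() follow the standard PyTorch rule for host-side bit packing: a CUDA tensor cannot be handed to NumPy directly, and torch.from_numpy always returns a CPU tensor, so the packed result has to be moved back to the device the weights live on. A minimal standalone sketch of that pattern (toy shapes, not the GPTQ packing logic itself):

    import numpy as np
    import torch

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    w = torch.randint(0, 16, (8, 4), dtype=torch.int32, device=device)

    # NumPy only sees host memory, so copy to CPU before converting.
    w_np = w.cpu().numpy().astype(np.uint32)

    # ... pack w_np into fewer uint32 words here ...

    # torch.from_numpy yields a CPU tensor; move it back next to the weights
    # so later kernels do not mix devices.
    qweight = torch.from_numpy(w_np.astype(np.int32)).to(device)
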
18 changes: 12 additions & 6 deletions mmrazor/implementations/quantization/gptq/ops.py
@@ -479,12 +479,18 @@ def convert_from(cls, module: nn.Linear, bits, groupsize):

def forward(self, x):
"""Custom forward."""
out_shape = x.shape[:-1] + (self.out_features, )
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros,
self.g_idx, self.bits, self.maxq)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
if torch.all(self.qweight == 0):
out = F.linear(x, self.weight, self.bias)
else:
# import pdb;pdb.set_trace()
out_shape = x.shape[:-1] + (self.out_features, )
out = QuantLinearFunction.apply(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales,
self.qzeros, self.g_idx, self.bits, self.maxq)
out = out + self.bias if self.bias is not None else out
out = out.reshape(out_shape)
# import pdb;pdb.set_trace()
return out


class GPTQLinear(DynamicLinear, GPTQMixIn):
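
The reworked forward in ops.py keeps a dense escape hatch: while qweight is still all zeros (i.e. before quantized data has been packed in or loaded from a checkpoint), the layer falls back to F.linear on the floating-point weight, and only a populated qweight is routed through the Triton kernel. A standalone sketch of that guard, with a hypothetical class name and a placeholder where the real quantized kernel would go:

    import torch
    import torch.nn.functional as F
    from torch import nn

    class DenseFallbackLinear(nn.Linear):
        """Sketch of the guard only, not the MMRazor TritonGPTQLinear."""

        def __init__(self, in_features, out_features, bias=True):
            super().__init__(in_features, out_features, bias=bias)
            # Stays all-zero until quantized weights are packed or loaded.
            self.register_buffer(
                'qweight',
                torch.zeros(out_features, in_features, dtype=torch.int32))

        def forward(self, x):
            if torch.all(self.qweight == 0):
                # Not quantized yet: behave like an ordinary nn.Linear.
                return F.linear(x, self.weight, self.bias)
            # Quantized path: a real implementation calls the custom kernel.
            raise NotImplementedError('quantized kernel goes here')
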
86 changes: 54 additions & 32 deletions projects/mmrazor_large/examples/language_models/LLaMA/llama_gptq.py
@@ -7,7 +7,8 @@

from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.implementations.quantization.gptq import GPTQLinear
from mmrazor.implementations.quantization.gptq import (GPTQLinear,
TritonGPTQLinear)
from mmrazor.utils import print_log


@@ -25,6 +26,13 @@ def disable_observer_linear(model):
module.fix_qparams = True


def del_redundant_attr(model):
print_log('Del redundant weight for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, TritonGPTQLinear):
del module.weight


def get_model(model):

def skip(*args, **kwargs):
@@ -48,7 +56,7 @@ def skip(*args, **kwargs):

parser.add_argument('model', type=str, help='Llama model to load')
parser.add_argument(
'dataset',
'--dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
@@ -69,6 +77,10 @@ def skip(*args, **kwargs):
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
parser.add_argument(
'--dev', type=str, default='cuda:0', help='Use which device.')
parser.add_argument(
'-m',
type=bool,
@@ -77,28 +89,25 @@ def skip(*args, **kwargs):

args = parser.parse_args()

DEV = torch.device('cuda:0')
DEV = args.dev

model = get_model(args.model)
model.to(DEV)
model.eval()
print_log('load model over')

dataloader, testloader = get_loaders(
args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log('load data for infer over')

from mmrazor.implementations.quantization import gptq
compressor = gptq.GPTQCompressor()
# use_triton_ops is True
compressor.prepare(model.model.layers,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
bits=4,
groupsize=128)

# # # quantize activation for linear
compressor.prepare(
model.model.layers,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
bits=4,
groupsize=128)

# # quantize activation for linear
# # a_qconfig = dict(bits=4, perchannel=False, sym=False)
# compressor.prepare(
# model.model.layers,
@@ -108,23 +117,38 @@ def skip(*args, **kwargs):
# # a_qconfig=a_qconfig
# )

compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
if args.quant_ckpt:
del_redundant_attr(model)
model.load_state_dict(torch.load(args.quant_ckpt))
else:
dataloader, testloader = get_loaders(
args.dataset,
seed=args.seed,
model=args.model,
seqlen=model.seqlen)
print_log('load data for infer over')

compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model,
wrap_modules=[LlamaDecoderLayer],
enabled=args.m,
device=DEV):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)

disable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m):
model, wrap_modules=[LlamaDecoderLayer], enabled=args.m,
device=DEV):

# for dataset in ['wikitext2', 'ptb', 'c4']:
for dataset in ['wikitext2']:
@@ -133,8 +157,6 @@ def skip(*args, **kwargs):
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)

if args.save:
# model = compressor.to_static_model(model)
if args.save and not args.quant_ckpt:
print_log(f'save model in {args.save}')
# model.save_pretrained(args.save)
torch.save(model.state_dict(), args.save)
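
Taken together, the new --quant_ckpt and --dev options split llama_gptq.py into two runs: quantize with calibration data and save the state dict, then reload that checkpoint later without re-running calibration. Illustrative invocations only; the model path and checkpoint name below are placeholders:

    # quantize with GPTQ and save the quantized state dict
    python llama_gptq.py /path/to/llama --dataset c4 --dev cuda:0 --save llama_w4g128.pth

    # later run: load the quantized checkpoint and skip calibration
    python llama_gptq.py /path/to/llama --dataset c4 --dev cuda:0 --quant_ckpt llama_w4g128.pth
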
104 changes: 62 additions & 42 deletions projects/mmrazor_large/examples/language_models/OPT/opt_gptq.py
@@ -8,7 +8,8 @@

from mmrazor.implementations.pruning.sparse_gpt.utils import \
memory_efficient_forward
from mmrazor.implementations.quantization.gptq import GPTQLinear
from mmrazor.implementations.quantization.gptq import (GPTQLinear,
TritonGPTQLinear)
from mmrazor.utils import print_log


@@ -26,6 +27,13 @@ def disable_observer_linear(model):
module.fix_qparams = True


def del_redundant_attr(model):
print_log('Del redundant weight for GPTQLinear!')
for _, module in model.named_modules():
if isinstance(module, TritonGPTQLinear):
del module.weight


def get_model(model):

def skip(*args, **kwargs):
@@ -44,10 +52,9 @@ def skip(*args, **kwargs):
import argparse
parser = argparse.ArgumentParser()

parser.add_argument('model', type=str, help='Llama model to load')
parser.add_argument(
'model', type=str, help='OPT model to load; pass `facebook/opt-X`.')
parser.add_argument(
'dataset',
'--dataset',
type=str,
choices=['wikitext2', 'ptb', 'c4'],
help='Where to extract calibration data from.')
@@ -64,10 +71,14 @@ def skip(*args, **kwargs):
parser.add_argument(
'--batch_size',
type=int,
default=64,
default=16,
help='Batchsize for calibration and evaluation.')
parser.add_argument(
'--save', type=str, default='', help='Path to saved model.')
parser.add_argument(
'--quant_ckpt', type=str, default='', help='Quantized ckpt to load.')
parser.add_argument(
'--dev', type=str, default='cuda:0', help='Use which device.')
parser.add_argument(
'-m',
type=bool,
@@ -76,53 +87,63 @@ def skip(*args, **kwargs):

args = parser.parse_args()

DEV = torch.device('cuda:0')
DEV = args.dev

model = get_model(args.model)
model.to(DEV)
model.eval()
print_log('load model over')

dataloader, testloader = get_loaders(
args.dataset, seed=args.seed, model=args.model, seqlen=model.seqlen)
print_log('load data for infer over')

from mmrazor.implementations.quantization import gptq
compressor = gptq.GPTQCompressor()
# # use_triton_ops is True
# compressor.prepare(model.model.layers,
# quant_conv=True,
# use_triton_ops=True,
# quant_linear=True,
# bits=4,
# groupsize=128)

# # quantize activation for linear
# a_qconfig = dict(bits=4, perchannel=False, sym=False)
# use_triton_ops is True
compressor.prepare(
model.model.decoder,
model.model.layers,
quant_conv=True,
use_triton_ops=True,
quant_linear=True,
use_triton_ops=False,
# a_qconfig=a_qconfig
)

compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)
bits=4,
groupsize=128)

# # # quantize activation for linear
# # a_qconfig = dict(bits=4, perchannel=False, sym=False)
# compressor.prepare(
# model.model.decoder,
# quant_conv=True,
# quant_linear=True,
# use_triton_ops=False,
# # a_qconfig=a_qconfig
# )

if args.quant_ckpt:
del_redundant_attr(model)
model.load_state_dict(torch.load(args.quant_ckpt))
else:
dataloader, testloader = get_loaders(
args.dataset,
seed=args.seed,
model=args.model,
seqlen=model.seqlen)
print_log('load data for infer over')

compressor.init_hessian()
enable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m,
device=DEV):
compressor.register_hessian_hooks()
opt_infer(
model,
testloader,
DEV,
batch_size=args.batch_size,
num_samples=args.nsamples)
compressor.remove_hessian_hooks()
compressor.quant_with_default_qconfig(device=DEV)

disable_observer_linear(model)
with memory_efficient_forward(
model, wrap_modules=[OPTDecoderLayer], enabled=args.m):
model, wrap_modules=[OPTDecoderLayer], enabled=args.m, device=DEV):

# for dataset in ['wikitext2', 'ptb', 'c4']:
for dataset in ['wikitext2']:
@@ -131,7 +152,6 @@ def skip(*args, **kwargs):
print_log(dataset)
opt_eval(model, testloader, DEV, batch_size=args.batch_size)

if args.save:
model = compressor.to_static_model(model)
if args.save and not args.quant_ckpt:
print_log(f'save model in {args.save}')
model.save_pretrained(args.save)
torch.save(model.state_dict(), args.save)