From 25278e805f24d5d48eaa0638abb48de1b783a3fb Mon Sep 17 00:00:00 2001
From: Marc Sun
Date: Fri, 12 Jul 2024 18:34:13 +0200
Subject: [PATCH] add qgalore

---
 src/transformers/trainer.py            | 123 +++++++++++++++++++++++++
 src/transformers/training_args.py      |   2 +
 src/transformers/utils/__init__.py     |   1 +
 src/transformers/utils/import_utils.py |   5 +
 4 files changed, 131 insertions(+)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 6c6964f2a46a..ae0aab6e0990 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -155,6 +155,7 @@
     is_ipex_available,
     is_lomo_available,
     is_peft_available,
+    is_q_galore_torch_available,
     is_safetensors_available,
     is_sagemaker_dp_enabled,
     is_sagemaker_mp_enabled,
@@ -1288,6 +1289,128 @@ def get_optimizer_cls_and_kwargs(
             optimizer_cls = torch.optim.Adagrad
         elif args.optim == OptimizerNames.RMSPROP:
             optimizer_cls = torch.optim.RMSprop
+        elif args.optim in [OptimizerNames.QGALORE_ADAMW_8BIT, OptimizerNames.QGALORE_ADAMW_8BIT_LAYERWISE]:
+            if not is_q_galore_torch_available():
+                raise ImportError(
+                    "You need to install `q-galore-torch` in order to use Q-GaLore optimizers."
+                    " Install it with `pip install q-galore-torch`."
+                )
+            from q_galore_torch import QGaLoreAdamW8bit
+
+            is_layerwise = args.optim.lower().endswith("layerwise")
+            if is_layerwise and args.parallel_mode == ParallelMode.DISTRIBUTED:
+                # TODO: check if this is True
+                raise NotImplementedError("Layer-wise QGaLore does not support DDP at this time")
+
+            optimizer_cls = QGaLoreAdamW8bit
+
+            if args.optim_target_modules is None:
+                raise ValueError(
+                    "You need to define `optim_target_modules` in order to properly use QGaLore optimizers"
+                )
+
+            if not isinstance(args.optim_target_modules, (list, str)):
+                raise ValueError(
+                    f"`optim_target_modules` has to be a list of strings, a string corresponding to a regex, or a specific module or 'all-linear', you passed {args.optim_target_modules}"
+                )
+
+            if model is None:
+                raise ValueError("You need to pass a model in order to correctly initialize a QGaLore optimizer.")
+
+            logger.warning(
+                "Activated Q-GaLore fine-tuning. Depending on your model size and hardware, the training might take a while before starting. Please be patient!"
+            )
+
+            all_linear = (
+                isinstance(args.optim_target_modules, str)
+                and args.optim_target_modules.replace("_", "-") == "all-linear"
+            )
+
+            galore_params = []
+            galore_params_names = []
+            for module_name, module in model.named_modules():
+                target_module_exists, is_regex = check_target_module_exists(
+                    args.optim_target_modules, module_name, return_is_regex=True
+                )
+
+                if not isinstance(module, nn.Linear):
+                    # Warn in case we match but it's not a linear layer
+                    if target_module_exists and not is_regex:
+                        logger.warning(
+                            f"{module_name} has been matched but ignored as QGaLore only supports linear layers. Please double check your `optim_target_modules`!"
+                        )
+
+                    continue
+
+                if not target_module_exists and not all_linear:
+                    continue
+
+                galore_params.append(module.weight)
+                galore_params_names.append(module_name + ".weight")
+
+            if len(galore_params) == 0:
+                raise ValueError(
+                    f"None of the target modules were found! ({args.optim_target_modules}). Please make sure to pass a valid `optim_target_modules`."
+                )
+
+            non_galore_params = [p for n, p in model.named_parameters() if n not in galore_params_names]
+
+            # The default args are from the official repository: /~https://github.com/VITA-Group/Q-GaLore
+            galore_optim_kwargs = {
+                "rank": int(optim_args.pop("rank", 256)),
+                "update_proj_gap": int(optim_args.pop("update_proj_gap", 200)),
+                "scale": float(optim_args.pop("scale", 0.25)),
+                "proj_type": optim_args.pop("proj_type", "std"),
+                "quant": optim_args.pop("quant", True),
+                "quant_n_bit": int(optim_args.pop("quant_n_bit", 4)),
+                "quant_group_size": int(optim_args.pop("quant_group_size", 256)),
+                "cos_threshold": float(optim_args.pop("cos_threshold", 0.4)),
+                "gamma_proj": int(optim_args.pop("gamma_proj", 2)),
+                "queue_size": int(optim_args.pop("queue_size", 5)),
+            }
+
+            param_groups = [
+                {"params": non_galore_params},
+                {"params": galore_params, **galore_optim_kwargs},
+            ]
+
+            if is_layerwise:
+                # For layer-wise optimizers, the optimization step is done through post accumulation
+                # gradient hooks. The trick is to first attach these hooks to the model parameters then
+                # create a dummy optimizer that will perform no-ops in the Trainer.
+                # See the original implementation or the nice implementation from @hiyouga
+                # here: /~https://github.com/hiyouga/LLaMA-Factory/commit/8664262cde3919e10eaecbd66e8c5d356856362e#diff-ebe08ab14496dfb9e06075f0fdd36799ef6d1535cc4dd4715b74c4e3e06fe3ba
+                if args.gradient_accumulation_steps != 1:
+                    raise ValueError("Layer-wise QGaLore optimizer does not support gradient accumulation!")
+
+                optimizer_dict = {}
+                for param in non_galore_params:
+                    if param.requires_grad:
+                        param_groups = [{"params": [param]}]
+                        optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+                # TODO: in the original repo, they multiply the update_proj_gap param by 2, to check
+                for param in galore_params:
+                    param_groups = [{"params": [param], **galore_optim_kwargs}]
+                    optimizer_dict[param] = optimizer_cls(param_groups, **optimizer_kwargs)
+
+                def optimizer_hook(param):
+                    if (not hasattr(param, "float_grad")) and param.grad is None:
+                        return
+                    optimizer_dict[param].step()
+                    optimizer_dict[param].zero_grad()
+
+                id_galore_params = [id(p) for p in galore_params]
+
+                # TODO: strange, we are not applying on every param here compared to galore
+                for param in model.parameters():
+                    if id(param) in id_galore_params or param.requires_grad:
+                        setattr(param, "backward_hook", optimizer_hook)
+
+                optimizer_cls = LayerWiseDummyOptimizer
+                optimizer_kwargs.update({"optimizer_dict": optimizer_dict})
+
+            optimizer_kwargs.update({"params": param_groups})
+
         elif args.optim in [
             OptimizerNames.GALORE_ADAMW,
             OptimizerNames.GALORE_ADAMW_8BIT,
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 66e2ea923e8e..0e2efa3c55fd 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -174,6 +174,8 @@ class OptimizerNames(ExplicitEnum):
     GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
     GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
     GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
+    QGALORE_ADAMW_8BIT = "qgalore_adamw_8bit"
+    QGALORE_ADAMW_8BIT_LAYERWISE = "qgalore_adamw_8bit_layerwise"
     LOMO = "lomo"
     ADALOMO = "adalomo"

diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py
index f506a523d1a3..ba72c0118c75 100755
--- a/src/transformers/utils/__init__.py
+++ b/src/transformers/utils/__init__.py
@@ -164,6 +164,7 @@
     is_pytesseract_available,
     is_pytest_available,
     is_pytorch_quantization_available,
+    is_q_galore_torch_available,
     is_quanto_available,
     is_rjieba_available,
     is_sacremoses_available,
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index 7b7b981b8735..ae5ee5b37906 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -99,6 +99,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 _eetq_available = _is_package_available("eetq")
 _galore_torch_available = _is_package_available("galore_torch")
+_q_galore_torch_available = _is_package_available("q_galore_torch")
 _lomo_available = _is_package_available("lomo_optim")
 _torchao_available = _is_package_available("torchao")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
@@ -346,6 +347,10 @@ def is_galore_torch_available():
     return _galore_torch_available


+def is_q_galore_torch_available():
+    return _q_galore_torch_available
+
+
 def is_lomo_available():
     return _lomo_available
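
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new `qgalore_adamw_8bit` / `qgalore_adamw_8bit_layerwise` values added to `OptimizerNames` are expected to be driven through `TrainingArguments`, assuming `q-galore-torch` is installed. The checkpoint, dummy data, target-module names, and hyperparameter values are illustrative placeholders only; `optim_args` reuses the comma-separated `key=value` format already used for the GaLore optimizers.

from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

# Small model whose attention/MLP projections are nn.Linear layers, so they can
# be picked up by `optim_target_modules` (the patch skips non-Linear modules).
model_id = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Tiny in-memory dataset, just enough to exercise the optimizer path.
texts = ["Q-GaLore keeps the gradient projection in low rank and quantizes its state."] * 64
train_dataset = Dataset.from_dict(tokenizer(texts, truncation=True, max_length=64))

args = TrainingArguments(
    output_dir="qgalore-smoke-test",
    max_steps=10,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,  # must stay 1 for the layerwise variant
    optim="qgalore_adamw_8bit",  # or "qgalore_adamw_8bit_layerwise"
    optim_target_modules=["q_proj", "v_proj", "fc1", "fc2"],  # or "all-linear"
    # Optional overrides of the Q-GaLore defaults taken from /~https://github.com/VITA-Group/Q-GaLore
    optim_args="rank=256,update_proj_gap=200,quant_n_bit=4,quant_group_size=256",
    logging_steps=1,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()

With the layerwise variant, the patch attaches per-parameter optimizers through gradient hooks and hands a `LayerWiseDummyOptimizer` to the `Trainer`, which is why gradient accumulation is restricted to 1 in the example above.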