Commit b3e5780
hiyouga authored and Yu Wang committed Sep 5, 2024
1 parent bc36e0f commit b3e5780
Showing 5 changed files with 83 additions and 40 deletions.
15 changes: 8 additions & 7 deletions src/llamafactory/model/model_utils/liger_kernel.py
@@ -30,19 +30,20 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
     if not is_trainable or not model_args.enable_liger_kernel:
         return

-    if getattr(config, "model_type", None) == "gemma":
+    model_type = getattr(config, "model_type", None)
+    if model_type == "gemma":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "gemma2":
+    elif model_type == "gemma2":
         from liger_kernel.transformers import apply_liger_kernel_to_gemma2 as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "llama":
+    elif model_type == "llama":
         from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "mistral":
+    elif model_type == "mistral":
         from liger_kernel.transformers import apply_liger_kernel_to_mistral as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "mixtral":
+    elif model_type == "mixtral":
         from liger_kernel.transformers import apply_liger_kernel_to_mixtral as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "phi3":
+    elif model_type == "phi3":
         from liger_kernel.transformers import apply_liger_kernel_to_phi3 as apply_liger_kernel
-    elif getattr(config, "model_type", None) == "qwen2":
+    elif model_type == "qwen2":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2 as apply_liger_kernel
     else:
         logger.warning("Current model does not support liger kernel.")
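For reference, each branch above only binds a model-specific patcher from liger-kernel to the local name apply_liger_kernel; the call itself presumably happens below the visible part of the hunk. A minimal stand-alone sketch of the same dispatch (the model id is illustrative and the no-argument call relies on liger-kernel's defaults; neither comes from this commit):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("meta-llama/Meta-Llama-3-8B")  # illustrative checkpoint
model_type = getattr(config, "model_type", None)  # "llama" for this config

if model_type == "llama":
    from liger_kernel.transformers import apply_liger_kernel_to_llama as apply_liger_kernel
    apply_liger_kernel()  # monkey-patches the HF Llama modeling classes with Liger's fused kernels
else:
    print(f"no liger kernel patch for model_type={model_type}")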
12 changes: 6 additions & 6 deletions src/llamafactory/model/model_utils/misc.py
@@ -28,19 +28,19 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
r"""
Finds all available modules to apply lora or galore.
"""
model_type = getattr(model.config, "model_type", None)
forbidden_modules = {"lm_head"}

if model.config.model_type == "chatglm":
if model_type == "chatglm":
forbidden_modules.add("output_layer")
elif model.config.model_type == "internlm2":
elif model_type == "internlm2":
forbidden_modules.add("output")
elif model.config.model_type in ["llava", "paligemma"]:
elif model_type in ["llava", "paligemma"]:
forbidden_modules.add("multi_modal_projector")
elif model.config.model_type == "qwen2_vl":
elif model_type == "qwen2_vl":
forbidden_modules.add("merger")

if freeze_vision_tower:
if model.config.model_type == "qwen2_vl":
if model_type == "qwen2_vl":
forbidden_modules.add("visual")
else:
forbidden_modules.add("vision_tower")
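The visible hunk only builds the forbidden_modules set; the remainder of find_all_linear_modules (collapsed above) presumably scans the model for Linear layers whose names are not forbidden. A rough sketch of that idea under that assumption, not the repository's exact code:

import torch

def linear_module_names(model: torch.nn.Module, forbidden_modules: set) -> list:
    module_names = set()
    for name, module in model.named_modules():
        if any(forbidden in name for forbidden in forbidden_modules):
            continue  # skip lm_head, vision_tower, merger, ...
        if isinstance(module, torch.nn.Linear):
            module_names.add(name.split(".")[-1])  # LoRA targets use leaf names such as "q_proj"
    return list(module_names)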
20 changes: 11 additions & 9 deletions src/llamafactory/model/model_utils/moe.py
@@ -39,42 +39,44 @@ def add_z3_leaf_module(model: "PreTrainedModel") -> None:
     if not is_deepspeed_zero3_enabled():
         return

-    if getattr(model.config, "model_type", None) == "dbrx":
+    model_type = getattr(model.config, "model_type", None)
+    if model_type == "dbrx":
         from transformers.models.dbrx.modeling_dbrx import DbrxFFN

         _set_z3_leaf_modules(model, [DbrxFFN])

-    if getattr(model.config, "model_type", None) == "jamba":
+    if model_type == "jamba":
         from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock

         _set_z3_leaf_modules(model, [JambaSparseMoeBlock])

-    if getattr(model.config, "model_type", None) == "jetmoe":
+    if model_type == "jetmoe":
         from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE

         _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE])

-    if getattr(model.config, "model_type", None) == "mixtral":
+    if model_type == "mixtral":
         from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

         _set_z3_leaf_modules(model, [MixtralSparseMoeBlock])

-    if getattr(model.config, "model_type", None) == "qwen2moe":
+    if model_type == "qwen2moe":
         from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

         _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock])


 def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
+    model_type = getattr(config, "model_type", None)
     if model_args.moe_aux_loss_coef is not None:
-        if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]:
+        if model_type in ["jamba", "mixtral", "qwen2_moe"]:
             setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef)

-        elif getattr(config, "model_type", None) == "deepseek":
+        elif model_type == "deepseek":
             setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef)

-        elif getattr(config, "model_type", None) == "jetmoe":
+        elif model_type == "jetmoe":
             setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef)

-    if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
+    if model_type in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]:
         setattr(config, "output_router_logits", is_trainable)
40 changes: 24 additions & 16 deletions src/llamafactory/model/model_utils/visual.py
@@ -89,9 +89,10 @@ def _mm_projector_forward_post_hook(
         return output.to(model_args.compute_dtype)

     if getattr(model, "quantization_method", None):
-        if getattr(model.config, "model_type", None) in ["llava", "paligemma"]:
+        model_type = getattr(model.config, "model_type", None)
+        if model_type in ["llava", "paligemma"]:
             mm_projector: "torch.nn.Module" = getattr(model, "multi_modal_projector")
-        elif getattr(model.config, "model_type", None) == "qwen2_vl":
+        elif model_type == "qwen2_vl":
             mm_projector: "torch.nn.Module" = getattr(getattr(model, "visual"), "merger")
         else:
             return
@@ -104,7 +105,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
r"""
Patches VLMs before loading them.
"""
if getattr(config, "model_type", None) == "llava": # required for ds zero3 and valuehead models
model_type = getattr(config, "model_type", None)
if model_type == "llava": # required for ds zero3 and valuehead models
setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

if getattr(config, "is_yi_vl_derived_model", None):
@@ -116,15 +118,16 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
r"""
Freezes vision tower and language model for VLM full/freeze tuning.
"""
model_type = getattr(config, "model_type", None)
forbidden_modules = set()
if getattr(config, "model_type", None) in ["llava", "paligemma"]:
if model_type in ["llava", "paligemma"]:
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("vision_tower")

if finetuning_args.train_mm_proj_only:
forbidden_modules.add("language_model")

elif getattr(config, "model_type", None) == "qwen2_vl":
elif model_type == "qwen2_vl":
if finetuning_args.freeze_vision_tower:
forbidden_modules.add("visual")

@@ -138,13 +141,14 @@ def get_image_seqlen(config: "PretrainedConfig") -> int:
r"""
Computes the number of special tokens per image.
"""
if getattr(config, "model_type", None) == "llava":
model_type = getattr(config, "model_type", None)
if model_type == "llava":
image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
if getattr(config, "vision_feature_select_strategy", "default") == "full": # add [CLS] token
image_seqlen += 1
elif getattr(config, "model_type", None) == "paligemma":
elif model_type == "paligemma":
image_seqlen = config.vision_config.num_image_tokens
elif getattr(config, "model_type", None) == "qwen2_vl": # variable length
elif model_type == "qwen2_vl": # variable length
image_seqlen = -1

return image_seqlen
@@ -156,12 +160,16 @@ def patch_target_modules(
r"""
Freezes vision tower for VLM LoRA tuning.
"""
if not finetuning_args.freeze_vision_tower:
return target_modules

if getattr(config, "model_type", None) in ["llava", "paligemma"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif getattr(config, "model_type", None) == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
model_type = getattr(config, "model_type", None)
if finetuning_args.freeze_vision_tower:
if model_type in ["llava", "paligemma"]:
return "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))
elif model_type == "qwen2_vl":
return "^(?!.*visual).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules
else:
return target_modules
if model_type == "qwen2_vl":
return "^(?!.*patch_embed).*(?:{}).*".format("|".join(target_modules))
else:
return target_modules
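The strings returned by patch_target_modules are regular expressions with a negative lookahead; PEFT treats a string-valued target_modules as a regex over full module names, so vision modules never receive LoRA adapters. A quick check with made-up, LLaVA-style module names:

import re

target_modules = ["q_proj", "v_proj"]
pattern = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules))

print(bool(re.fullmatch(pattern, "language_model.model.layers.0.self_attn.q_proj")))  # True: gets LoRA
print(bool(re.fullmatch(pattern, "vision_tower.vision_model.encoder.layers.0.self_attn.q_proj")))  # False: stays frozen

The new qwen2_vl branch likewise filters out patch_embed even when the vision tower is trainable, presumably because that module is not a plain Linear layer.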
36 changes: 34 additions & 2 deletions src/llamafactory/train/sft/metric.py
@@ -43,6 +43,9 @@


 def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor":
+    r"""
+    Computes the token with the largest likelihood to reduce memory footprint.
+    """
     if isinstance(logits, (list, tuple)):
         if logits[0].dim() == 3:  # (batch_size, seq_len, vocab_size)
             logits = logits[0]
@@ -56,9 +59,38 @@ def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "tor


 @dataclass
-class ComputeMetrics:
+class ComputeAccuracy:
+    r"""
+    Computes accuracy and supports `batch_eval_metrics`.
+    """
+
+    def _dump(self) -> Optional[Dict[str, float]]:
+        result = None
+        if hasattr(self, "score_dict"):
+            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
+
+        self.score_dict = {"accuracy": []}
+        return result
+
+    def __post_init__(self):
+        self._dump()
+
+    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[Dict[str, float]]:
+        preds, labels = numpify(eval_preds.predictions), numpify(eval_preds.label_ids)
+        for i in range(len(preds)):
+            pred, label = preds[i, :-1], labels[i, 1:]
+            label_mask = label != IGNORE_INDEX
+            self.score_dict["accuracy"].append(np.mean(pred[label_mask] == label[label_mask]))
+
+        if compute_result:
+            return self._dump()
+
+
+@dataclass
+class ComputeSimilarity:
     r"""
-    Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer.
+    Computes text similarity scores and supports `batch_eval_metrics`.
+    Wraps the tokenizer into metric functions, used in CustomSeq2SeqTrainer.
     """

     tokenizer: "PreTrainedTokenizer"
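ComputeAccuracy follows the Trainer's batch_eval_metrics protocol: it is called once per evaluation batch, and _dump() averages and resets the running scores when compute_result is True (roughly, on the last batch). A direct-call example, assuming the class is importable from the path in this commit and that numpify passes NumPy arrays through unchanged; the token ids are made up:

import numpy as np
from transformers import EvalPrediction
from llamafactory.train.sft.metric import ComputeAccuracy

metric = ComputeAccuracy()  # __post_init__ -> _dump() initializes score_dict
preds = np.array([[7, 9, 3, 0]])      # argmax token ids, as produced by eval_logit_processor
labels = np.array([[-100, 7, 8, 3]])  # IGNORE_INDEX (-100) positions are masked out
print(metric(EvalPrediction(predictions=preds, label_ids=labels), compute_result=True))
# {'accuracy': 0.666...}  pred[:-1] = [7, 9, 3] is compared with label[1:] = [7, 8, 3]

In the trainer this pairs with preprocess_logits_for_metrics=eval_logit_processor and batch_eval_metrics=True in the training arguments, so only token ids, not full logit tensors, are kept around during evaluation.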