diff --git a/comps/llms/deployment/docker_compose/compose_text-generation.yaml b/comps/llms/deployment/docker_compose/compose_text-generation.yaml
index fbf503ed6..ad353bb58 100644
--- a/comps/llms/deployment/docker_compose/compose_text-generation.yaml
+++ b/comps/llms/deployment/docker_compose/compose_text-generation.yaml
@@ -45,6 +45,29 @@ services:
       - SYS_NICE
     restart: unless-stopped

+  textgen-gaudi-enhance:
+    image: ${REGISTRY:-opea}/llm-textgen-gaudi-enhance:${TAG:-latest}
+    container_name: llm-textgen-gaudi-enhance-server
+    ports:
+      - ${TEXTGEN_PORT:-9000}:9000
+    volumes:
+      - "${DATA_PATH:-./data}:/data"
+    ipc: host
+    environment:
+      no_proxy: ${no_proxy}
+      http_proxy: ${http_proxy}
+      https_proxy: ${https_proxy}
+      LLM_MODEL_ID: ${LLM_MODEL_ID}
+      HF_TOKEN: ${HF_TOKEN}
+      HABANA_VISIBLE_DEVICES: all
+      OMPI_MCA_btl_vader_single_copy_mechanism: none
+      TOKENIZERS_PARALLELISM: False
+      LOGFLAG: ${LOGFLAG:-False}
+    runtime: habana
+    cap_add:
+      - SYS_NICE
+    restart: unless-stopped
+
   textgen-service-tgi:
     extends: textgen
     container_name: textgen-service-tgi
@@ -100,6 +123,18 @@
     environment:
       LLM_COMPONENT_NAME: ${LLM_COMPONENT_NAME:-OpeaTextGenNative}

+  textgen-native-gaudi-enhance:
+    extends: textgen-gaudi-enhance
+    container_name: textgen-native-gaudi-enhance
+    environment:
+      LLM_COMPONENT_NAME: ${LLM_COMPONENT_NAME:-OpeaTextGenNativeEnhance}
+
+  textgen-native-gaudi-enhance-multimodal:
+    extends: textgen-gaudi-enhance
+    container_name: textgen-native-gaudi-enhance-multimodal
+    environment:
+      LLM_COMPONENT_NAME: ${LLM_COMPONENT_NAME:-OpeaTextGenNativeEnhanceMultimodal}
+
 networks:
   default:
     driver: bridge
diff --git a/comps/llms/src/text-generation/Dockerfile.intel_hpu_enhance b/comps/llms/src/text-generation/Dockerfile.intel_hpu_enhance
new file mode 100644
index 000000000..b9eb2558f
--- /dev/null
+++ b/comps/llms/src/text-generation/Dockerfile.intel_hpu_enhance
@@ -0,0 +1,32 @@
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# HABANA environment
+FROM vault.habana.ai/gaudi-docker/1.19.0/ubuntu22.04/habanalabs/pytorch-installer-2.5.1 AS hpu
+
+ENV LANG=en_US.UTF-8
+
+RUN apt-get update && apt-get install -y --no-install-recommends --fix-missing \
+    git-lfs \
+    libgl1-mesa-glx \
+    libjemalloc-dev
+
+RUN mkdir -p /home/user
+
+RUN git lfs install
+
+COPY comps /home/user/comps
+
+RUN pip install --no-cache-dir --upgrade pip setuptools && \
+    pip install --no-cache-dir --upgrade-strategy eager optimum[habana] && \
+    pip install --no-cache-dir git+/~https://github.com/HabanaAI/DeepSpeed.git@1.19.0
+
+RUN pip install --no-cache-dir git+/~https://github.com/huggingface/optimum-habana.git@transformers_future && \
+    cd /home/user/comps/llms/src/text-generation/ && pip install --no-cache-dir -r requirements.txt && \
+    pip install --no-cache-dir soundfile peft backoff
+
+ENV PYTHONPATH=/root:/home/user
+
+WORKDIR /home/user/comps/llms/src/text-generation/
+
+ENTRYPOINT ["bash", "entrypoint_enhance.sh"]
diff --git a/comps/llms/src/text-generation/entrypoint_enhance.sh b/comps/llms/src/text-generation/entrypoint_enhance.sh
new file mode 100644
index 000000000..6a801ba78
--- /dev/null
+++ b/comps/llms/src/text-generation/entrypoint_enhance.sh
@@ -0,0 +1,29 @@
+#!/usr/bin/env bash
+
+# Copyright (C) 2024 Intel Corporation
+# SPDX-License-Identifier: Apache-2.0
+
+# LLM_MODEL_ID must be a local model path
+llm_name=$LLM_MODEL_ID
+WORKPATH="/home/user/comps/llms/src/text-generation/"
+
+if [[ $llm_name == *"phi-4-multimodel"* ]]; then
+    cd $WORKPATH
+    echo -e "Patching the multimodal model files"
+    cp patch/enhance-multimodal-patch/*.py $llm_name/
+    export PT_HPU_LAZY_MODE=1
+elif [[ $llm_name == *"phi-4"* ]]; then
+    cd $WORKPATH
+    git clone -b transformers_future /~https://github.com/huggingface/optimum-habana
+    cd optimum-habana
+    cp ../patch/optimum-habana-enhance.patch .
+    git apply optimum-habana-enhance.patch
+    pip install -e .
+    cd examples/text-generation/
+    pip install -r requirements.txt
+    cd phi-4-mini-instruct/
+    bash ./01-patch-transformer.sh
+fi
+
+cd $WORKPATH
+python opea_llm_microservice.py
diff --git a/comps/llms/src/text-generation/integrations/native_enhance.py b/comps/llms/src/text-generation/integrations/native_enhance.py
new file mode 100644
index 000000000..eee867ec6
--- /dev/null
+++ b/comps/llms/src/text-generation/integrations/native_enhance.py
@@ -0,0 +1,269 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import sys
+
+sys.path.append("/test/GenAIComps/")
+
+import os
+import threading
+import time
+
+import torch
+from langchain_core.prompts import PromptTemplate
+
+from comps import CustomLogger, GeneratedDoc, OpeaComponent, OpeaComponentRegistry, ServiceType
+from comps.cores.proto.api_protocol import ChatCompletionRequest
+
+from .template import ChatTemplate
+from .utils import initialize_model
+
+logger = CustomLogger("opea_textgen_native")
+logflag = os.getenv("LOGFLAG", False)
+
+MODEL_NAME = os.getenv("LLM_MODEL_ID", "phi/phi-4")
+
+input_sentences = [
+    "DeepSpeed is a machine learning framework",
+    "He is working on",
+    "He has a",
+    "He got all",
+    "Everyone is happy and I can",
+    "The new movie that got Oscar this year",
+    "In the far far distance from our galaxy,",
+    "Peace is the only way",
+]
+
+args_dict = {
+    "device": "hpu",
+    "model_name_or_path": MODEL_NAME,
+    "bf16": True,
+    "max_new_tokens": 128,
+    "max_input_tokens": 128,
+    "batch_size": 1,
+    "warmup": 3,
+    "n_iterations": 1,
+    "local_rank": 0,
+    "use_kv_cache": False,
+    "use_hpu_graphs": True,
+    "dataset_name": None,
+    "column_name": None,
+    "do_sample": False,
+    "num_beams": 1,
+    "trim_logits": False,
+    "seed": 27,
+    "profiling_warmup_steps": 0,
+    "profiling_steps": 0,
+    "profiling_record_shapes": False,
+    "prompt": None,
+    "bad_words": None,
+    "force_words": None,
+    "assistant_model": None,
+    "peft_model": None,
+    "num_return_sequences": 1,
+    "token": None,
+    "model_revision": "main",
+    "attn_softmax_bf16": True,
+    "output_dir": None,
+    "bucket_size": -1,
+    "bucket_internal": False,
+    "dataset_max_samples": -1,
+    "limit_hpu_graphs": True,
+    "reuse_cache": False,
+    "verbose_workers": False,
+    "simulate_dyn_prompt": None,
+    "reduce_recompile": False,
+    "use_flash_attention": True,
+    "flash_attention_recompute": True,
+    "flash_attention_causal_mask": True,
+    "flash_attention_fast_softmax": False,
+    "book_source": False,
+    "torch_compile": False,
+    "ignore_eos": True,
+    "temperature": 1.0,
+    "top_p": 1.0,
+    "top_k": None,
+    "const_serialization_path": None,
+    "disk_offload": False,
+    "trust_remote_code": False,
+    "quant_config": "",
"world_size": 0, + "show_graphs_count": False, + "load_quantized_model_with_inc": False, + "local_quantized_inc_model_path": None, + "load_quantized_model_with_autogptq": False, + "penalty_alpha": None, +} + + +class Args: + def __init__(self, **entries): + self.__dict__.update(entries) + + +model = None +assistant_model = None +tokenizer = None +generation_config = None +args = Args(**args_dict) +initialization_lock = threading.Lock() +initialized = False + + +def generate( + input_query: list, + device="hpu", + use_lazy_mode=True, + use_hpu_graphs=True, + profiling_steps=0, + profiling_warmup_steps=0, + ignore_eos=True, + profiling_record_shapes=False, +): + """Generates sequences from the input sentences and returns them.""" + logger.info(f"[llm - generate] starting to inference with prompt {input_query}") + encode_t0 = time.perf_counter() + + # Tokenization + input_tokens = tokenizer.batch_encode_plus( + input_query, + return_tensors="pt", + padding=True, + return_token_type_ids=False, # token_type_ids is not needed for falcon-three model + ) + encode_duration = time.perf_counter() - encode_t0 + logger.info(f"[llm - generate] input tokenized: {input_tokens}") + + # Move inputs to target device(s) + for t in input_tokens: + logger.info(f"[llm - generate] t: {t}") + if torch.is_tensor(input_tokens[t]): + logger.info("[llm - generate] input[t] is tensor") + logger.info(f"[llm - generate] device: {model.device}") + input_tokens[t] = input_tokens[t].to(model.device) + + logger.info("[llm - generate] inputs transferred.") + + iteration_times = [] + outputs = model.generate( + **input_tokens, + generation_config=generation_config, + assistant_model=assistant_model, + lazy_mode=use_lazy_mode, + hpu_graphs=use_hpu_graphs, + profiling_steps=profiling_steps, + profiling_warmup_steps=profiling_warmup_steps, + ignore_eos=ignore_eos, + iteration_times=iteration_times, + profiling_record_shapes=profiling_record_shapes, + ).cpu() + logger.info("[llm - generate] result generated") + first_token_time = iteration_times[0] + encode_duration + result = tokenizer.batch_decode(outputs, skip_special_tokens=True) + logger.info(f"[llm - generate] result: {result}") + logger.info(f"[llm - generate] Time to first token = {first_token_time*1000}ms") + return result + + +def initialize(): + global model, assistant_model, tokenizer, generation_config, initialized + with initialization_lock: + if not initialized: + # initialize model and tokenizer + import habana_frameworks.torch.hpu as torch_hpu + from optimum.habana.utils import HabanaProfile + + model, assistant_model, tokenizer, generation_config = initialize_model(args, logger) + logger.info("[llm] model and tokenizer initialized.") + + # compilation and model warmup + HabanaProfile.disable() + logger.info("[llm - native] Graph compilation...") + for _ in range(args.warmup): + generate(input_sentences) + logger.info("[llm - native] model warm up finished.") + torch_hpu.synchronize() + HabanaProfile.enable() + logger.info("[llm - native] Ready to inference") + res = generate(["What is Deep Learning?"]) + logger.info(f"[llm - native] test result: {res}") + initialized = True + + +@OpeaComponentRegistry.register("OpeaTextGenNativeEnhance") +class OpeaTextGenNativeEnhance(OpeaComponent): + """A specialized OPEA TextGen component derived from OpeaComponent for interacting with LLM services based on native optimum habana.""" + + def __init__(self, name: str, description: str, config: dict = None): + super().__init__(name, ServiceType.LLM.name.lower(), description, 
config) + initialize() + health_status = self.check_health() + if not health_status: + logger.error("OpeaTextGenNativeEnhance health check failed.") + else: + logger.info("OpeaTextGenNativeEnhance health check success.") + + def check_health(self) -> bool: + """Checks the health of the LLM service. + + Returns: + bool: True if the service is reachable and healthy, False otherwise. + """ + + try: + return initialized + except Exception as e: + logger.error(e) + logger.error("Health check failed") + return False + + async def invoke(self, input: ChatCompletionRequest): + """Invokes the LLM service to generate output for the provided input. + + Args: + input (ChatCompletionRequest): The input text(s). + """ + + message = None + if isinstance(input.messages, str): + message = input.messages + else: # List[Dict] + for input_data in input.messages: + if "role" in input_data and input_data["role"] == "user" and "content" in input_data: + message = input_data["content"] + if logflag: + logger.info(f"Get input text:\n {message}") + if message is None: + logger.error("Don't receive any input text, exit!") + return GeneratedDoc(text=None, prompt=None) + + prompt = message + prompt_template = None + if input.chat_template: + prompt_template = PromptTemplate.from_template(input.chat_template) + input_variables = prompt_template.input_variables + if prompt_template: + if sorted(input_variables) == ["context", "question"]: + prompt = prompt_template.format(question=message, context="\n".join(input.documents)) + elif input_variables == ["question"]: + prompt = prompt_template.format(question=message) + else: + logger.info(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']") + else: + if input.documents: + prompt = ChatTemplate.generate_rag_prompt(message, input.documents) + res = generate([prompt]) + + if logflag: + logger.info(f"[llm - native] inference result: {res}") + return GeneratedDoc(text=res[0], prompt=message) diff --git a/comps/llms/src/text-generation/integrations/native_enhance_multimodal.py b/comps/llms/src/text-generation/integrations/native_enhance_multimodal.py new file mode 100644 index 000000000..aea127c9f --- /dev/null +++ b/comps/llms/src/text-generation/integrations/native_enhance_multimodal.py @@ -0,0 +1,205 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
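+# NOTE: this integration assumes LLM_MODEL_ID points to a local Phi-4 multimodal
+# checkpoint that entrypoint_enhance.sh has already patched with the files under
+# patch/enhance-multimodal-patch/; the checkpoint's remote code is then loaded with
+# trust_remote_code=True before optimum-habana's Gaudi adaptations are applied.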
+import sys
+
+sys.path.append("/test/GenAIComps/")
+
+import os
+import threading
+import time
+
+import habana_frameworks.torch.core as htcore
+import soundfile
+import torch
+from langchain_core.prompts import PromptTemplate
+from PIL import Image
+from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
+
+from comps import CustomLogger, GeneratedDoc, OpeaComponent, OpeaComponentRegistry, ServiceType
+from comps.cores.proto.api_protocol import ChatCompletionRequest
+
+from .template import ChatTemplate
+
+logger = CustomLogger("opea_textgen_native")
+logflag = os.getenv("LOGFLAG", False)
+
+MODEL_NAME = os.getenv("LLM_MODEL_ID", "phi-4-multimodel")
+
+model = None
+processor = None
+generation_config = None
+initialization_lock = threading.Lock()
+initialized = False
+
+kwargs = {}
+kwargs["torch_dtype"] = torch.bfloat16
+
+user_prompt = "<|user|>"
+assistant_prompt = "<|assistant|>"
+prompt_suffix = "<|end|>"
+IMAGE_SPECIAL = "<|endoftext10|>"
+AUDIO_SPECIAL = "<|endoftext11|>"
+sample_prompt = f"{user_prompt}what is the answer for 1+1? Explain it.{prompt_suffix}{assistant_prompt}"
+if logflag:
+    logger.info(f">>> Prompt\n{sample_prompt}")
+
+generation_config = GenerationConfig.from_pretrained(MODEL_NAME, "generation_config.json")
+
+# generation_config.max_new_tokens = args.max_new_tokens
+# generation_config.use_cache = args.use_kv_cache
+generation_config.static_shapes = False  # There's a list of models optimized with static shapes
+generation_config.bucket_size = -1
+generation_config.bucket_internal = False
+# generation_config.do_sample = args.do_sample
+# generation_config.num_beams = args.num_beams
+# generation_config.top_k = args.top_k
+# generation_config.penalty_alpha = args.penalty_alpha
+# generation_config.bad_words_ids = bad_words_ids
+# generation_config.force_words_ids = force_words_ids
+# generation_config.num_return_sequences = args.num_return_sequences
+generation_config.trim_logits = True
+generation_config.attn_softmax_bf16 = False
+generation_config.limit_hpu_graphs = False
+generation_config.clear_hpu_graphs_cache = False
+generation_config.reuse_cache = False
+generation_config.reduce_recompile = False
+# if generation_config.reduce_recompile:
+#     assert generation_config.bucket_size > 0
+generation_config.use_flash_attention = False
+generation_config.flash_attention_recompute = False
+generation_config.flash_attention_causal_mask = False
+generation_config.flash_attention_fast_softmax = False
+# generation_config.trust_remote_code = args.trust_remote_code
+generation_config.valid_sequence_lengths = None
+generation_config.attn_batch_split = False
+generation_config.ignore_eos = None
+
+
+def generate(
+    query,
+):
+    """Generates sequences from the input sentences and returns them."""
+    logger.info(f"[llm - generate] starting to inference with prompt {query}")
+    inputs = processor(query, images=None, return_tensors="pt").to("hpu:0")
+
+    generate_ids = model.generate(
+        **inputs,
+        max_new_tokens=100,
+        generation_config=generation_config,
+    )
+    generate_ids = generate_ids[:, inputs["input_ids"].shape[1] :]
+    response = processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+
+    if logflag:
+        logger.info(response)
+        print(f">>> Response\n{response}")
+
+    return response
+
+
+def initialize():
+    global model, processor, generation_config, initialized
+    with initialization_lock:
+        if not initialized:
+            processor = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_NAME,
+                trust_remote_code=True,
+                torch_dtype="auto",
+                _attn_implementation="sdpa",
+            )
+            model = model.to("hpu")
+            if logflag:
+                logger.info(processor.tokenizer)
+                logger.info(f"model.config._attn_implementation: {model.config._attn_implementation}")
+            logger.info("[llm] model and processor initialized.")
+
+            # Must run after the model is loaded: the checkpoint's custom remote code has to be imported first so that optimum-habana can apply its overrides on top of it.
+            from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
+
+            adapt_transformers_to_gaudi()
+
+            logger.info("[llm - native] Ready to inference")
+            res = generate(sample_prompt)
+            logger.info(f"[llm - native] test result: {res}")
+            initialized = True
+
+
+@OpeaComponentRegistry.register("OpeaTextGenNativeEnhanceMultimodal")
+class OpeaTextGenNativeEnhanceMultimodal(OpeaComponent):
+    """A specialized OPEA TextGen component derived from OpeaComponent for interacting with LLM services based on native optimum habana."""
+
+    def __init__(self, name: str, description: str, config: dict = None):
+        super().__init__(name, ServiceType.LLM.name.lower(), description, config)
+        initialize()
+        health_status = self.check_health()
+        if not health_status:
+            logger.error("OpeaTextGenNativeEnhanceMultimodal health check failed.")
+        else:
+            logger.info("OpeaTextGenNativeEnhanceMultimodal health check success.")
+
+    def check_health(self) -> bool:
+        """Checks the health of the LLM service.
+
+        Returns:
+            bool: True if the service is reachable and healthy, False otherwise.
+        """
+
+        try:
+            return initialized
+        except Exception as e:
+            logger.error(e)
+            logger.error("Health check failed")
+            return False
+
+    async def invoke(self, input: ChatCompletionRequest):
+        """Invokes the LLM service to generate output for the provided input.
+
+        Args:
+            input (ChatCompletionRequest): The input text(s).
+        """
+
+        message = None
+        if isinstance(input.messages, str):
+            message = input.messages
+        else:  # List[Dict]
+            for input_data in input.messages:
+                if "role" in input_data and input_data["role"] == "user" and "content" in input_data:
+                    message = input_data["content"]
+        if logflag:
+            logger.info(f"Get input text:\n {message}")
+        if message is None:
+            logger.error("Don't receive any input text, exit!")
+            return GeneratedDoc(text=None, prompt=None)
+
+        prompt = message
+        prompt_template = None
+        if input.chat_template:
+            prompt_template = PromptTemplate.from_template(input.chat_template)
+            input_variables = prompt_template.input_variables
+        if prompt_template:
+            if sorted(input_variables) == ["context", "question"]:
+                prompt = prompt_template.format(question=message, context="\n".join(input.documents))
+            elif input_variables == ["question"]:
+                prompt = prompt_template.format(question=message)
+            else:
+                logger.info(f"{prompt_template} not used, we only support 2 input variables ['question', 'context']")
+        else:
+            if input.documents:
+                prompt = ChatTemplate.generate_rag_prompt(message, input.documents)
+        res = generate(prompt)
+
+        if logflag:
+            logger.info(f"[llm - native] inference result: {res}")
+        return GeneratedDoc(text=res, prompt=message)
diff --git a/comps/llms/src/text-generation/opea_llm_microservice.py b/comps/llms/src/text-generation/opea_llm_microservice.py
index c59db8e47..e6c0a247c 100644
--- a/comps/llms/src/text-generation/opea_llm_microservice.py
+++ b/comps/llms/src/text-generation/opea_llm_microservice.py
@@ -28,6 +28,10 @@

 if llm_component_name == "OpeaTextGenNative":
     from integrations.native import OpeaTextGenNative
+elif llm_component_name == "OpeaTextGenNativeEnhance":
+    from integrations.native_enhance import OpeaTextGenNativeEnhance
+elif llm_component_name == "OpeaTextGenNativeEnhanceMultimodal":
+    from integrations.native_enhance_multimodal import OpeaTextGenNativeEnhanceMultimodal
 elif llm_component_name == "OpeaTextGenBedrock":
     from integrations.bedrock import OpeaTextGenBedrock
 else:
diff --git a/comps/llms/src/text-generation/patch/enhance-multimodal-patch/configuration_phio.py b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/configuration_phio.py
new file mode 100644
index 000000000..943e2f8a9
--- /dev/null
+++ b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/configuration_phio.py
@@ -0,0 +1,229 @@
+# coding=utf-8
+# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Phi-O model configuration."""
+
+from transformers.configuration_utils import PretrainedConfig
+from transformers.utils import logging
+
+logger = logging.get_logger(__name__)
+
+
+class PhiOConfig(PretrainedConfig):
+    r"""This is the configuration class to store the configuration of a [`PhiOModel`]. It is used to instantiate a Phi-O
+    model according to the specified arguments, defining the model architecture.
+ + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + + Args: + vocab_size (`int`, *optional*, defaults to 200064): + Vocabulary size of the Phi-O model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`PhiOModel`]. + hidden_size (`int`, *optional*, defaults to 3072): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 8192): + Dimension of the MLP representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + resid_pdrop (`float`, *optional*, defaults to 0.0): + Dropout probability for mlp outputs. + embd_pdrop (`int`, *optional*, defaults to 0.0): + The dropout ratio for the embeddings. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio after computing the attention scores. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model might ever be used with. + original_max_position_embeddings (`int`, *optional*, defaults to 4096): + The maximum sequence length that this model was trained with. This is used to determine the size of the + original RoPE embeddings when using long scaling. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon value used for the RMSNorm. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`dict`, *optional*): + The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must + contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and + the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size + divided by the number of attention heads divided by 2. + partial_rotary_factor (`float`, *optional*, defaults to 0.5): + Percentage of the query and keys which will have rotary embedding. 
+ bos_token_id (`int`, *optional*, defaults to 199999): + The id of the "beginning-of-sequence" token. + eos_token_id (`int`, *optional*, defaults to 199999): + The id of the "end-of-sequence" token. + pad_token_id (`int`, *optional*, defaults to 199999): + The id of the padding token. + sliding_window (`int`, *optional*): + Sliding window attention window size. If `None`, no sliding window is applied. + + Example: + + ```python + >>> from transformers import PhiOModel, PhiOConfig + + >>> # Initializing a Phi-O style configuration + >>> configuration = PhiOConfig.from_pretrained("TBA") + + >>> # Initializing a model from the configuration + >>> model = PhiOModel(configuration) + + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "phio" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=200064, + hidden_size=3072, + intermediate_size=8192, + num_hidden_layers=32, + num_attention_heads=32, + num_key_value_heads=None, + resid_pdrop=0.0, + embd_pdrop=0.0, + attention_dropout=0.0, + hidden_act="silu", + max_position_embeddings=4096, + original_max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-5, + use_cache=True, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + partial_rotary_factor=1, + bos_token_id=199999, + eos_token_id=199999, + pad_token_id=199999, + sliding_window=None, + embd_layer: str = "default", + img_processor=None, + audio_processor=None, + vision_lora=None, + speech_lora=None, + **kwargs, + ): + self.embd_layer = embd_layer + self.img_processor = img_processor + self.audio_processor = audio_processor + self.vision_lora = vision_lora + self.speech_lora = speech_lora + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attention_dropout = attention_dropout + self.hidden_act = hidden_act + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.partial_rotary_factor = partial_rotary_factor + self._rope_scaling_adjustment() + self._rope_scaling_validation() + self.sliding_window = sliding_window + + super().__init__( + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) + + def _rope_scaling_adjustment(self): + """Adjust the `type` of the `rope_scaling` configuration for backward compatibility.""" + if self.rope_scaling is None: + return + + rope_scaling_type = self.rope_scaling.get("type", None) + + # For backward compatibility if previous version used "su" or "yarn" + if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: + self.rope_scaling["type"] = "longrope" + + def _rope_scaling_validation(self): + """Validate the `rope_scaling` configuration.""" + if self.rope_scaling is None: + return + + if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: + raise ValueError( + "`rope_scaling` must be a dictionary with 
three fields, `type`, `short_factor` and `long_factor`, " + f"got {self.rope_scaling}" + ) + rope_scaling_type = self.rope_scaling.get("type", None) + rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) + rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) + if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: + raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") + if not ( + isinstance(rope_scaling_short_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) + ): + raise ValueError( + f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" + ) + rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor) + if not len(rope_scaling_short_factor) == rotary_ndims // 2: + raise ValueError( + f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}" + ) + if not ( + isinstance(rope_scaling_long_factor, list) + and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) + ): + raise ValueError( + f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" + ) + if not len(rope_scaling_long_factor) == rotary_ndims // 2: + raise ValueError( + f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}" + ) diff --git a/comps/llms/src/text-generation/patch/enhance-multimodal-patch/modeling_phio.py b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/modeling_phio.py new file mode 100644 index 000000000..a519266f4 --- /dev/null +++ b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/modeling_phio.py @@ -0,0 +1,2524 @@ +# coding=utf-8 +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
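+# NOTE: this module is a patched copy of the Phi-4 multimodal ("Phi-O") remote-code
+# modeling file; entrypoint_enhance.sh copies it into the model directory so that
+# the patched version is loaded in place of the code bundled with the checkpoint.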
+"""PyTorch Phi-O model.""" +import math +import warnings +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import PretrainedConfig +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache +from transformers.generation import GenerationMixin +from transformers.modeling_attn_mask_utils import AttentionMaskConverter +from transformers.modeling_flash_attention_utils import _flash_attention_forward +from transformers.modeling_outputs import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + SequenceClassifierOutputWithPast, + TokenClassifierOutput, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + add_code_sample_docstrings, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) + +from .configuration_phio import PhiOConfig +from .processing_phio import InputMode +from .speech_conformer_encoder import ConformerEncoder +from .vision_siglip_navit import get_siglip_vision_model + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "TBA" +_CONFIG_FOR_DOC = "PhiOConfig" + +# Special token ids +_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>', or we can better name it (in `tokenizer_config.json`) +_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>' +_COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE = [-9999, -1] # For backward compatibility +_COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE = [float("-inf"), -10000] # For backward compatibility + + +class PhiOImageEmbedding(nn.Module): + """Image embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + + # n_embed or hidden_size + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size + if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): + embd_drop = config.embd_pdrop if hasattr(config, "embd_pdrop") else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + logger.info(f"create image tower {config.img_processor}") + enable_gradient_checkpointing = kwargs.get("enable_gradient_checkpointing", False) + + # Load SigLIP model + self.img_processor = get_siglip_vision_model(_flash_attn_2_enabled=False) + + pe_weight = self.img_processor.embeddings.position_embedding.weight + L, D = pe_weight.size() + H = int(math.sqrt(L)) + assert H**2 == L + if H % 2 != 0: # and kwargs.get('image_token_compression_cls', None) is None: + self.img_processor_padding = nn.ReflectionPad2d((0, 1, 0, 1)) + H += 1 + image_dim_out = D + # ((448/14)//2)**2 + self.num_img_tokens = (H // 2) ** 2 + self.base_feat_height_target = H + + if enable_gradient_checkpointing: + self.img_processor.encoder.gradient_checkpointing = True + + self.image_dim_out = image_dim_out + self.img_sizes = None + self.image_attention_mask = None + + # global_gn and sub_gn for hd transform, serves as line separator + self.use_hd_transform = kwargs.get("use_hd_transform", False) + self.with_learnable_separator = kwargs.get("with_learnable_separator", False) + self.hd_transform_order = kwargs.get("hd_transform_order", "glb_sub") + self.freeze_img_processor = kwargs.get("freeze_img_processor", False) + self.crop_size = kwargs.get("crop_size", 336) + logger.info(f"freeze_img_processor = {self.freeze_img_processor}") + + # image 
token compression + self.image_token_compression_cls = kwargs.get("image_token_compression_cls", None) + if self.image_token_compression_cls == "avg_pool_2d": + self.image_token_compression = nn.AvgPool2d(kernel_size=2, stride=2) + self.base_feat_height_reduction = 1 + self.base_feat_height_target = self.base_feat_height_target // 2 + elif self.image_token_compression_cls is None: + self.image_token_compression = None + self.base_feat_height_reduction = 2 + else: + raise NotImplementedError( + f"image_token_compression_cls = {self.image_token_compression_cls}, not implemented" + ) + + # with_hd_transform and with_learnable_separator should have same value + assert ( + self.use_hd_transform == self.with_learnable_separator + ), "use_hd_transform and with_learnable_separator should have same value" + if self.with_learnable_separator: + assert self.use_hd_transform, "learnable separator is only for hd transform" + # 1024 * 4, merge spatial to channel dimension + self.glb_GN = nn.Parameter(torch.zeros([1, 1, self.image_dim_out * self.base_feat_height_reduction**2])) + self.sub_GN = nn.Parameter(torch.zeros([1, 1, 1, self.image_dim_out * self.base_feat_height_reduction**2])) + logger.info(f"learnable separator enabled for hd transform, hd_transform_order = {self.hd_transform_order}") + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.img_projection = nn.Linear(image_dim_out, hidden_size) + elif projection_cls == "mlp" and self.use_hd_transform: + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out * self.base_feat_height_reduction**2, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + elif projection_cls == "mlp": + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + layers = [nn.Linear(image_dim_out, dim_projection)] + for _ in range(1, depth): + layers.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + self.img_projection = nn.Sequential(*layers) + else: + raise NotImplementedError(f"projection_cls = {projection_cls}, not implemented") + + self.vocab_size = config.vocab_size + self.img_features = None + + if isinstance(config.img_processor, dict): + self.layer_idx = config.img_processor.get("layer_idx", -2) + self.type_feature = config.img_processor.get("type_feature", "patch") + else: + self.layer_idx = -2 + self.type_feature = "patch" + + def set_img_features(self, img_features: torch.FloatTensor) -> None: + self.img_features = img_features + + def set_img_sizes(self, img_sizes: torch.LongTensor) -> None: + self.img_sizes = img_sizes + + def set_img_attn_mask(self, image_attention_mask: torch.FloatTensor) -> None: + self.image_attention_mask = image_attention_mask + + def get_img_features(self, img_embeds: torch.FloatTensor, attention_mask=None) -> torch.FloatTensor: + LAYER_IDX = self.layer_idx + TYPE_FEATURE = self.type_feature + + if self.freeze_img_processor: + with torch.no_grad(): + if attention_mask is not None: + img_processor_output = self.img_processor( + img_embeds, output_hidden_states=True, patch_attention_mask=attention_mask + ) + else: + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + else: + if attention_mask is not None: + img_processor_output = self.img_processor( + img_embeds, 
output_hidden_states=True, patch_attention_mask=attention_mask + ) + else: + img_processor_output = self.img_processor(img_embeds, output_hidden_states=True) + img_feature = img_processor_output.hidden_states[LAYER_IDX] + + if TYPE_FEATURE == "patch": + patch_feature = img_feature + if self.image_token_compression is not None: + # reshape to 2D tensor + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + if getattr(self, "img_processor_padding", None) is not None: + patch_feature = self.img_processor_padding(patch_feature) + patch_feature = self.image_token_compression(patch_feature) + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view( + -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1) + ) + elif getattr(self, "img_processor_padding", None) is not None: + width = int(math.sqrt(patch_feature.size(1))) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + # convert to NCHW + patch_feature = patch_feature.permute(0, 3, 1, 2) + patch_feature = self.img_processor_padding(patch_feature) + # convert to NHWC + patch_feature = patch_feature.permute(0, 2, 3, 1) + patch_feature = patch_feature.view( + -1, patch_feature.size(1) * patch_feature.size(2), patch_feature.size(-1) + ) + return patch_feature + + if TYPE_FEATURE == "cls_patch": + if self.image_token_compression is not None: + # reshape to 2D tensor + patch_feature = img_feature[:, 1:] + cls_feature = img_feature[:, 0] + width = math.sqrt(patch_feature.size(1)) + patch_feature = patch_feature.view(-1, width, width, patch_feature.size(-1)) + patch_feature = self.image_token_compression(patch_feature) + patch_feature = patch_feature.view(-1, patch_feature.size(-2) * patch_feature.size(-1)) + img_feature = torch.cat([cls_feature, patch_feature], dim=1) + return img_feature + + logger.info(f"processed img feature size = {img_feature.size()}") + raise NotImplementedError + + def spatiotemporal_pool(self, x, num_img_tokens, batch_size=1, T=1): + + if self.image_pos_embed is not None: + x = x.view(batch_size * T, -1, x.shape[-1]) + num_tokens = x.shape[-2] + h, w = int(num_tokens**0.5), int(num_tokens**0.5) + assert h * w == num_tokens, "only support square feature maps for now" + x = x.view(batch_size * T, h, w, x.shape[-1]) + pos_embed = self.image_pos_embed(x) + x = x + pos_embed + x = x.view(batch_size, T * h * w, x.shape[-1]) + + if self.visual_temporal_embed is not None: + visual_temporal_embed = self.visual_temporal_embed(x.view(batch_size, T, -1, x.shape[-1])[:, :, 0]) + x = x.view(batch_size, T, -1, x.shape[-1]) + visual_temporal_embed.view(1, T, 1, x.shape[-1]) + + new_x = [] + # [bsz, T * H' * W', C] -> [bsz, T, C] + spatial_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=2) + new_x.append(spatial_avg_pool_x) + + # [bsz, T * H' * W', C] -> [bsz, H'*W', C] + temporal_avg_pool_x = x.view(batch_size, T, -1, x.shape[-1]).mean(dim=1) + new_x.append(temporal_avg_pool_x) + + x = torch.cat(new_x, dim=1).view(-1, self.image_dim_out) + num_img_tokens += T + return x, num_img_tokens + + def forward( + self, input_ids: torch.LongTensor, input_embeds: torch.FloatTensor, image_sizes=None, **kwargs + ) -> torch.FloatTensor: + + if isinstance(input_ids, tuple): + # # pipeline parallel + input_ids, input_embeds = input_ids + + img_embeds = input_embeds + if image_sizes is None and 
"image_sizes" in kwargs: + image_sizes = kwargs["image_sizes"] + img_sizes = image_sizes + + if self.img_features is not None: + img_embeds = self.img_features.clone() + self.img_features = None + + if self.img_sizes is not None: + img_sizes = self.img_sizes + + if img_embeds is not None: + # convert to bf16 + img_embeds = img_embeds.to(torch.bfloat16) + + if self.image_attention_mask is not None: + image_attention_mask = self.image_attention_mask.clone() + self.image_attention_mask = None + elif "image_attention_mask" in kwargs: + image_attention_mask = kwargs["image_attention_mask"] + else: + image_attention_mask = None + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + with torch.no_grad(): + positions = torch.nonzero(input_ids == _IMAGE_SPECIAL_TOKEN_ID, as_tuple=False) + positions_tuple = torch.nonzero(input_ids == _IMAGE_SPECIAL_TOKEN_ID, as_tuple=True) + + # logger.info(f'position size: {positions.size()} ...') + fake_image_forward = False + select = False + hd_transform = False + + if isinstance(self.img_projection, nn.Sequential): + target_device = self.img_projection[0].bias.device + target_dtype = self.img_projection[0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.img_projection.bias.device + target_dtype = self.img_projection.bias.dtype + + num_img_tokens = self.num_img_tokens + if len(positions.tolist()) > 0: + if self.use_hd_transform and img_sizes is not None and len(img_sizes): + hd_transform = True + assert ( + img_embeds.ndim == 5 + ), f"(branch 1) img_embeds size: {img_embeds.size()}, expect 5D tensor for hd transform" + # img_embeds: (num_images, max_num_crops, 3, H, W) + # img_sizes: (num_images, 2).view(1, -1) + + bs = img_embeds.shape[0] + # Nx(HW)xC + if image_attention_mask is not None and len(image_attention_mask) > 0: + img_features = self.get_img_features( + img_embeds.flatten(0, 1), + attention_mask=image_attention_mask.type(torch.BoolTensor).flatten(0, 1).to(target_device), + ) + else: + img_features = self.get_img_features(img_embeds.flatten(0, 1)) + + base_feat_height_target = self.base_feat_height_target + base_resolution = self.crop_size + base_feat_height_reduction = self.base_feat_height_reduction + + base_feat_height = base_feat_width = int(np.sqrt(img_features.shape[1])) + + assert ( + base_feat_height == base_feat_height_target and base_feat_width == base_feat_height_target + ), f"base_feat_height: {base_feat_height}, base_feat_width: {base_feat_width}, expect {base_feat_height_target} features for hd transform" + + # bs x max_num_crops x (24x24) x C + img_features = img_features.view(bs, -1, base_feat_height * base_feat_width, self.image_dim_out) + C = self.image_dim_out + H = base_feat_height + + output_imgs = [] + output_len = [] + # training is tensor, inference is list + if isinstance(img_sizes, torch.Tensor): + img_sizes = img_sizes.view(-1, 2) + for _bs in range(bs): + h, w = img_sizes[_bs] + # h = h // base_resolution + # w = w // base_resolution + h = torch.div(h, base_resolution, rounding_mode="trunc") + w = torch.div(w, base_resolution, rounding_mode="trunc") + B_ = h * w + + # 1 x (24x24) x 1024 + global_img_feature = img_features[_bs, :1] + + # 1 x 12 x 12 x 4096 + glb_img = ( + global_img_feature.reshape(1, H, H, C) + .reshape( + 1, + H // base_feat_height_reduction, + base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction, + C, + ) + .contiguous() + .permute(0, 1, 3, 2, 4, 5) + .reshape( + 1, + H // base_feat_height_reduction, + H // 
base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * C, + ) + .contiguous() + ) + temp_glb_GN = self.sub_GN.repeat(1, H // base_feat_height_reduction, 1, 1) + + # 1 x 156 x 4096 + glb_img = torch.cat([glb_img, temp_glb_GN], dim=2).reshape( + 1, -1, base_feat_height_reduction * base_feat_height_reduction * C + ) + + # (max_num_crops-1) x (12x12) x C + sub_img = img_features[_bs, 1:] + # 16x574x1024 + # get rid of padding sub_img + sub_img = sub_img[:B_] + + # (num_crops, 12, 2, 12, 2, 1024) -> (num_crops, 12, 12, 2, 2, 1024) -> (num_crops, 12*12, 4*1024) + sub_img = ( + sub_img.reshape(B_, H, H, C) + .reshape( + B_, + H // base_feat_height_reduction, + base_feat_height_reduction, + H // base_feat_height_reduction, + base_feat_height_reduction, + C, + ) + .contiguous() + .permute(0, 1, 3, 2, 4, 5) + .reshape(B_, -1, base_feat_height_reduction * base_feat_height_reduction * C) + .contiguous() + ) + sub_img = ( + sub_img.reshape( + 1, + h, + w, + base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction, + -1, + ) + .permute(0, 1, 3, 2, 4, 5) + .reshape( + 1, + h * base_feat_height // base_feat_height_reduction, + w * base_feat_width // base_feat_height_reduction, + base_feat_height_reduction * base_feat_height_reduction * C, + ) + ) + + if image_attention_mask is not None and len(image_attention_mask) > 0: + reshaped_image_attention_mask = ( + image_attention_mask[_bs, 1 : B_ + 1, 0::2, 0::2] + .reshape( + 1, + h, + w, + base_feat_height // base_feat_height_reduction, + base_feat_width // base_feat_height_reduction, + ) + .permute(0, 1, 3, 2, 4) + .reshape( + 1, + h * base_feat_height // base_feat_height_reduction, + w * base_feat_width // base_feat_height_reduction, + ) + ) + useful_height = int(reshaped_image_attention_mask[0, :, 0].sum().item()) + useful_width = int(reshaped_image_attention_mask[0, 0, :].sum().item()) + sub_img = sub_img[:, :useful_height, :useful_width] + temp_sub_GN = self.sub_GN.repeat(1, useful_height, 1, 1) + temp_len = ( + int(image_attention_mask[_bs, : B_ + 1, 0::2, 0::2].sum().item()) + + (useful_height + 1) + + base_feat_height // base_feat_height_reduction + ) + else: + temp_sub_GN = self.sub_GN.repeat(1, h * base_feat_height // base_feat_height_reduction, 1, 1) + temp_len = int( + (h * w + 1) * self.num_img_tokens + + 1 + + (h + 1) * base_feat_height // base_feat_height_reduction + ) + + sub_img = torch.cat([sub_img, temp_sub_GN], dim=2).reshape( + 1, -1, base_feat_height_reduction * base_feat_height_reduction * C + ) + # (1, num_img_tokens, 1024*4) + + # glb + sub + if self.hd_transform_order == "glb_sub": + output_imgs.append(torch.cat([glb_img, self.glb_GN, sub_img], dim=1)) + elif self.hd_transform_order == "sub_glb": + output_imgs.append(torch.cat([sub_img, self.glb_GN, glb_img], dim=1)) + else: + raise NotImplementedError(f"hd_transform_order = {self.hd_transform_order}, not implemented") + + # temp_len = int((h*w+1)*144 + 1 + (h+1)*12) + assert ( + temp_len == output_imgs[-1].shape[1] + ), f"temp_len: {temp_len}, output_imgs[-1].shape[1]: {output_imgs[-1].shape[1]}" + output_len.append(temp_len) + + num_img_tokens = output_len + img_set_tensor = [] + for _output_img in output_imgs: + img_feature_proj = self.img_projection(_output_img.to(target_device).to(target_dtype)) + img_set_tensor.append(img_feature_proj) + # logger.info(f'img_embeds size: {img_embeds.size()}, image sizes: {img_sizes} loading time {datetime.now() - start_time}') + # assert sum(num_img_tokens) == len(g_values), 
f'(branch 1) sum(num_img_tokens): {sum(num_img_tokens)}, g_values size: {len(g_values)}, g_values {g_values}' + + else: + raise NotImplementedError + select = True + else: + # # create a fake image tensor + # # TODO: need define image size for different vision model + if self.training: + img_embeds = torch.zeros( + 1, 3, self.crop_size, self.crop_size, dtype=torch.bfloat16, device=input_ids.device + ) + + tt = self.get_img_features(img_embeds).to(target_device).to(target_dtype).reshape(-1, 1024) + if self.use_hd_transform: + img_set_tensor = self.img_projection( + tt.reshape(-1, self.image_dim_out * self.base_feat_height_reduction**2) + * self.glb_GN[0] + * self.sub_GN[0, 0] + ) + else: + img_set_tensor = self.img_projection(tt) # adapted visual features. + fake_image_forward = True + + # we use the token embedding layer from the huggingface model, this is REQUIRED to make sure we are using the loaded weights. + hidden_states = kwargs["wte"](input_ids) + + if select: + if hd_transform: + # new implementation without in-place operation + # Ref: https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py#L233 + # Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put.html + # Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html#torch.Tensor.index_put_ + # img_set_tensor: a list of tensors, each tensor has shape (1, N_tokens, C) + assert all( + [_img_set_tensor.shape[0] == 1 for _img_set_tensor in img_set_tensor] + ), "img_set_tensor should have shape (1, N_tokens, C)" + # Shape: (merged_N_tokens, C) + merged_img_set_tensor = torch.cat(img_set_tensor, dim=1).squeeze(0) + merged_img_set_tensor = merged_img_set_tensor.to(hidden_states.dtype).to(hidden_states.device) + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: /~https://github.com/pytorch/pytorch/issues/132715 + with torch.autocast(device_type=hidden_states.device.type, enabled=False): + new_hidden_states = hidden_states.index_put( + indices=positions_tuple, values=merged_img_set_tensor, accumulate=False + ) + hidden_states = new_hidden_states + else: + raise NotImplementedError + + if fake_image_forward and self.training: + hidden_states = ( + hidden_states + (0 * img_set_tensor[0].to(hidden_states.dtype).to(hidden_states.device)).sum() + ) + if self.drop is not None: + hidden_states = self.drop(hidden_states) + + return hidden_states + + +class PhiOAudioEmbedding(nn.Module): + """Audio embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + self.config = config + # n_embed or hidden_size for text LM + hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size + + if hasattr(config, "embd_pdrop") or hasattr(config, "embed_pdrop"): + embd_drop = config.embd_pdrop if hasattr(config, "embd_pdrop") else config.embed_pdrop + self.drop = nn.Dropout(embd_drop) + else: + self.drop = None + + audio_dim_out = None # Set this variable according to the actual audio processor + logger.info(f"create audio processor {config.audio_processor}") + self.layer_idx = -2 + + if isinstance(config.audio_processor, dict) and config.audio_processor.get("name", None) == "cascades": + encoder_config = config.audio_processor.get("config", None) + assert encoder_config is not None + self.encoder = ConformerEncoder(**encoder_config) + + # fake initialization, create encoder_embedding layer only so that + # in decoding, all parameters can be loaded in from_pretrained_function + # 
in training, we do post init after from_pretrained function to make sure the correct initialization + self.encoder.post_init({}) + + audio_dim_out = encoder_config["attention_dim"] + n_mels = encoder_config["input_size"] + else: + raise NotImplementedError + + assert audio_dim_out is not None, "Remember to set values for audio_dim_out" + self.audio_dim_out = audio_dim_out + self.audio_dim_in = n_mels + + self.freeze_audio_processor = kwargs.get("freeze_audio_processor", False) + logger.info(f"freeze_audio_processor = {self.freeze_audio_processor}") + + self.downsample_rate = kwargs.get("downsample_rate", 1) + + enable_gradient_checkpointing = kwargs.get("enable_gradient_checkpointing", False) + if enable_gradient_checkpointing: + self.encoder.gradient_checkpointing_enable() + logger.info("gradient checkpointing enabled for audio processor") + + projection_cls = kwargs.get("projection_cls", "linear") + if projection_cls == "linear": + self.audio_projection = nn.Linear(audio_dim_out, hidden_size) + elif projection_cls == "mlp": + # follow llava-v1.5's implementation + # (do not use image_projection and image_proj_norm) + dim_projection = hidden_size + depth = 2 + self.linear_downsample_rate = self.downsample_rate + + layers_for_speech = [nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)] + for _ in range(1, depth): + layers_for_speech.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + audio_projection_for_speech = nn.Sequential(*layers_for_speech) + + layers_for_vision = [nn.Linear(audio_dim_out * self.linear_downsample_rate, dim_projection)] + for _ in range(1, depth): + layers_for_vision.extend([nn.GELU(), nn.Linear(dim_projection, dim_projection)]) + audio_projection_for_vision = nn.Sequential(*layers_for_vision) + + self.audio_projection = nn.ModuleDict( + {"speech": audio_projection_for_speech, "vision": audio_projection_for_vision} + ) + else: + raise NotImplementedError(f"projection_cls = {projection_cls}, not implemented") + + self.vocab_size = config.vocab_size + self.input_embeds = None + self.audio_embed_sizes = None + + def post_init(self, audio_config): + # execute after the from_pretrained() initialization of the phio model + if audio_config.get("name", None) == "cascades": + init_model_config = audio_config.get("init_model", {}) + self.encoder.post_init(init_model_config) + # remove the init model in config so it is not saved in the config. + # This might affect the model loading in resuming training and decoding. 
+ if "init_model" in audio_config: + audio_config.pop("init_model") + + def set_audio_embeds(self, input_embeds: torch.FloatTensor) -> None: + self.input_embeds = input_embeds + + def set_audio_embed_sizes(self, audio_embed_sizes: torch.LongTensor) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def get_audio_features( + self, input_embeds: torch.FloatTensor, audio_attention_mask: torch.Tensor, audio_projection_mode: str = "speech" + ): + + if self.freeze_audio_processor: + with torch.no_grad(): + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) + else: + audio_features, masks = self.encoder(input_embeds, audio_attention_mask) + + if isinstance(self.audio_projection, nn.Sequential): + audio_set_tensor = self.audio_projection(audio_features) + elif isinstance(self.audio_projection, nn.ModuleDict): + audio_set_tensor = self.audio_projection[audio_projection_mode](audio_features) + else: + raise NotImplementedError + + return audio_set_tensor + + def forward( + self, + input_ids: torch.LongTensor, + input_embeds: torch.FloatTensor, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode="speech", + **kwargs, + ) -> torch.FloatTensor: + """ + arguments: + input_ids: input text ids (B, U) + input_embeds: audio features (B, T, D) B: num audios in a sequence + """ + if self.input_embeds is not None: + input_embeds = self.input_embeds.clone() + if self.audio_embed_sizes is not None: + audio_embed_sizes = self.audio_embed_sizes.clone() + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + MAX_INPUT_ID = int(1e9) + + with torch.no_grad(): + positions = torch.nonzero(input_ids == _AUDIO_SPECIAL_TOKEN_ID, as_tuple=False) + positions_tuple = torch.nonzero(input_ids == _AUDIO_SPECIAL_TOKEN_ID, as_tuple=True) + + if isinstance(self.audio_projection, nn.Sequential): + target_device = self.audio_projection[0].bias.device + target_dtype = self.audio_projection[0].bias.dtype + elif isinstance(self.audio_projection, nn.ModuleDict): + target_device = self.audio_projection[audio_projection_mode][0].bias.device + target_dtype = self.audio_projection[audio_projection_mode][0].bias.dtype + else: # It's a single nn.Linear layer + target_device = self.audio_projection.bias.device + target_dtype = self.audio_projection.bias.dtype + + if input_embeds is not None: + input_embeds = input_embeds.to(target_device).to(target_dtype) + + if len(positions.tolist()) > 0: + audio_set_tensor = self.get_audio_features(input_embeds, audio_attention_mask, audio_projection_mode) + else: + # # create an audio tensor + # To do: not sure if this is required for text only input + if self.training: + audio_embeds = torch.zeros(1, 500, self.audio_dim_in).to(target_device).to(target_dtype) + audio_attention_mask = audio_embeds.new_ones(audio_embeds.size()[:2]).long() + audio_set_tensor = self.get_audio_features(audio_embeds, audio_attention_mask, audio_projection_mode) + + hidden_states = kwargs["wte"](input_ids) + + if len(positions.tolist()) > 0: + + assert audio_embed_sizes.sum().item() == len( + positions + ), f"please ensure the encoder outputs have the same length as defined in input_ids! 
\n audio_embed_sizes.sum().item(): {audio_embed_sizes.sum().item()} \n len(positions): {len(positions)} \n audio_embed_sizes: {audio_embed_sizes} \n positions: {positions} \n input_ids.shape \n {input_ids.shape}" + + # new implementation without in-place operation + # Ref: https://huggingface.co/microsoft/Phi-3.5-vision-instruct/blob/4a0d683eba9f1d0cbfb6151705d1ee73c25a80ca/modeling_phi3_v.py#L233 + # Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put.html + # Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.index_put_.html#torch.Tensor.index_put_ + # audio_set_tensor: shape (N_audios, N_padded_tokens, C) + # Shape: (merged_N_tokens, C) + merged_audio_set_tensor = torch.cat( + [audio_set_tensor[i, : audio_embed_sizes[i], :] for i in range(len(audio_embed_sizes))], dim=0 + ) + merged_audio_set_tensor = merged_audio_set_tensor.to(hidden_states.dtype).to(hidden_states.device) + # Temporarily disable autocast to avoid issue on bf16 tensors + # Ref: /~https://github.com/pytorch/pytorch/issues/132715 + with torch.autocast(device_type=hidden_states.device.type, enabled=False): + new_hidden_states = hidden_states.index_put( + indices=positions_tuple, values=merged_audio_set_tensor, accumulate=False + ) + hidden_states = new_hidden_states + else: + if self.training: + hidden_states = ( + hidden_states + (0 * audio_set_tensor[:, 0].to(hidden_states.dtype).to(hidden_states.device)).sum() + ) + + if self.drop is not None: + hidden_states = self.drop(hidden_states) + + return hidden_states + + +class PhiOImageAudioEmbedding(nn.Module): + """Image-audio embedding.""" + + def __init__(self, config: PretrainedConfig, **kwargs) -> None: + super().__init__() + + self.vocab_size = config.vocab_size + + self.image_input_id = kwargs.get("image_input_id", -1) + self.audio_input_id = kwargs.get("audio_input_id", -10000) + assert self.image_input_id != self.audio_input_id, "image_input_id and audio_input_id should be different" + + self.image_embd_layer_kwargs = kwargs["image_embd_layer"] + self.image_embed = PhiOImageEmbedding(config, **self.image_embd_layer_kwargs) + self.audio_embd_layer_kwargs = kwargs["audio_embd_layer"] + self.audio_embed = PhiOAudioEmbedding(config, **self.audio_embd_layer_kwargs) + + self.input_image_embeds = None + self.image_sizes = None + self.image_attention_mask = None + self.input_audio_embeds = None + self.audio_embed_sizes = None + + def post_init(self, audio_config): + # post init for audio embedding + # ref: model.model.embed_tokens_extend.post_init(audio_config) in phyagi/getters/model.py + self.audio_embed.post_init(audio_config) + + def set_input_image_embeds(self, input_image_embeds: torch.FloatTensor) -> None: + self.input_image_embeds = input_image_embeds + + def set_image_sizes(self, image_sizes: torch.LongTensor) -> None: + self.image_sizes = image_sizes + + def set_img_attn_mask(self, image_attention_mask: torch.FloatTensor) -> None: + self.image_attention_mask = image_attention_mask + + def set_input_audio_embeds(self, input_audio_embeds: torch.FloatTensor) -> None: + self.input_audio_embeds = input_audio_embeds + + def set_audio_embed_sizes(self, audio_embed_sizes: torch.LongTensor) -> None: + self.audio_embed_sizes = audio_embed_sizes + + def forward( + self, + input_ids: torch.LongTensor, + input_embeds, + input_image_embeds: torch.FloatTensor = None, + input_audio_embeds: torch.FloatTensor = None, + image_sizes=None, + image_attention_mask=None, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode="speech", + 
wte=None, + ) -> torch.FloatTensor: + MAX_INPUT_ID = int(1e9) + assert -MAX_INPUT_ID < self.audio_input_id < self.image_input_id + + # override image and audio embeddings and sizes from object itself + # this is for inference + # ref: phyagi/eval/utils/text_generation_vision_audio_pipeline.py + if self.input_image_embeds is not None: + assert input_image_embeds is None + input_image_embeds = self.input_image_embeds.clone() + # NOTE weijian: set input_image_embeds to None after first call in for eval stage + # during evaluation, it will call model's forward() multiple times + # the first time input_ids contains the prompt (including <|image_{}|>) and input_embeds exists + # from the second time, the input_ids will only contain the generated text + # thus, the input_image_embeds is no longer needed + self.input_image_embeds = None + + if self.image_sizes is not None: + assert image_sizes is None + image_sizes = self.image_sizes + + if self.input_audio_embeds is not None: + assert input_audio_embeds is None + input_audio_embeds = self.input_audio_embeds.clone() + self.input_audio_embeds = None + + if self.audio_embed_sizes is not None: + assert audio_embed_sizes is None + audio_embed_sizes = self.audio_embed_sizes.clone() + + if input_image_embeds is not None: + # convert to bf16 + input_image_embeds = input_image_embeds.to(torch.bfloat16) + + if self.image_attention_mask is not None: + assert image_attention_mask is None + image_attention_mask = self.image_attention_mask.clone() + self.image_attention_mask = None + + if input_audio_embeds is not None: + # convert to bf16 + input_audio_embeds = input_audio_embeds.to(torch.bfloat16) + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + # backward compatibility + with torch.no_grad(): + new_input_ids = input_ids.clone() + new_input_ids[ + (input_ids >= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[0]) + & (input_ids <= _COMPATIBLE_IMAGE_SPECIAL_TOKEN_ID_RANGE[1]) + ] = _IMAGE_SPECIAL_TOKEN_ID + new_input_ids[ + (input_ids >= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[0]) + & (input_ids <= _COMPATIBLE_AUDIO_SPECIAL_TOKEN_ID_RANGE[1]) + ] = _AUDIO_SPECIAL_TOKEN_ID + input_ids = new_input_ids + + with torch.no_grad(): + image_position_mask = input_ids == _IMAGE_SPECIAL_TOKEN_ID + non_image_position_mask = ~image_position_mask + + assert input_embeds is None + if self.training: + assert input_image_embeds is not None and input_audio_embeds is not None + + # copy the input ids since they will be modified in place in image_embed and audio_embed + image_hidden_states = self.image_embed( + input_ids=input_ids, + input_embeds=input_image_embeds, + image_sizes=image_sizes, + wte=wte, + image_attention_mask=image_attention_mask, + ) + audio_hidden_states = self.audio_embed( + input_ids=input_ids, + input_embeds=input_audio_embeds, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + wte=wte, + audio_projection_mode=audio_projection_mode, + ) + + # merge image and audio hidden states + # NOTE weijian: for non-image-audio tokens, here we use audio hidden states + # actually, in the debug code above, the non-image-audio tokens from image_hidden_states and audio_hidden_states should be the same + hidden_states = image_hidden_states * image_position_mask.to(torch.bfloat16).unsqueeze( + -1 + ) + audio_hidden_states * non_image_position_mask.to(torch.bfloat16).unsqueeze(-1) + + return hidden_states + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 +class 
PhiORMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """PhiORMSNorm is equivalent to T5LayerNorm.""" + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 +class PhiORotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) + self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + self.inv_freq.to(x.device) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See /~https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos() + sin = emb.sin() + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class PhiOSuScaledRotaryEmbedding(PhiORotaryEmbedding): + def __init__(self, dim, config, device=None): + warnings.warn( + "The class PhiOSuScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers. 
Please" + " use PhiOLongRoPEScaledRotaryEmbedding instead.", + FutureWarning, + ) + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + # Force float32 since bfloat16 loses precision on long contexts + # See /~https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class PhiOYarnScaledRotaryEmbedding(PhiORotaryEmbedding): + def __init__(self, dim, config, device=None): + warnings.warn( + "The class PhiOYarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers", + FutureWarning, + ) + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See /~https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / 
self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = 0.1 * math.log(scale) + 1.0 + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +class PhiOLongRoPEScaledRotaryEmbedding(PhiORotaryEmbedding): + def __init__(self, dim, config, device=None): + super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) + + self.short_factor = config.rope_scaling["short_factor"] + self.long_factor = config.rope_scaling["long_factor"] + self.original_max_position_embeddings = config.original_max_position_embeddings + + @torch.no_grad() + def forward(self, x, position_ids, seq_len=None): + seq_len = seq_len or torch.max(position_ids) + 1 + if seq_len > self.original_max_position_embeddings: + ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) + else: + ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) + + inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim + self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) + + inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) + position_ids_expanded = position_ids[:, None, :].float() + + # Force float32 since bfloat16 loses precision on long contexts + # See /~https://github.com/huggingface/transformers/pull/29285 + device_type = x.device.type + device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" + with torch.autocast(device_type=device_type, enabled=False): + freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) + emb = torch.cat((freqs, freqs), dim=-1) + + scale = self.max_position_embeddings / self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) + + cos = emb.cos() * scaling_factor + sin = emb.sin() * scaling_factor + return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`, *optional*): + Deprecated and unused. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. 
+ Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos.unsqueeze(unsqueeze_dim) + sin = sin.unsqueeze(unsqueeze_dim) + + rotary_dim = cos.shape[-1] + q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] + k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] + + q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) + k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) + return q_embed, k_embed + + +class PhiOMLP(nn.Module): + def __init__(self, config): + super().__init__() + + self.config = config + self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) + self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) + + self.activation_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + up_states = self.gate_up_proj(hidden_states) + + gate, up_states = up_states.chunk(2, dim=-1) + up_states = up_states * self.activation_fn(gate) + + return self.down_proj(up_states) + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). + + The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class PhiOAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + def __init__(self, config: PhiOConfig, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.original_max_position_embeddings = config.original_max_position_embeddings + self.rope_theta = config.rope_theta + self.rope_scaling = config.rope_scaling + self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+ ) + + op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) + self._init_rope() + + def _init_rope(self): + if self.rope_scaling is None: + self.rotary_emb = PhiORotaryEmbedding( + self.rotary_ndims, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + if scaling_type == "longrope": + self.rotary_emb = PhiOLongRoPEScaledRotaryEmbedding(self.rotary_ndims, self.config) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." 
+ ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights += causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class PhiOFlashAttention2(PhiOAttention): + """Phi-O flash attention module. + + This module inherits from `PhiOAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: /~https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
+ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # PhiOFlashAttention2 attention does not support output_attentions + + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = ( + max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len + ) + + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_dropout = self.attention_dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. 
+ + if query_states.dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.qkv_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = _flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + position_ids=position_ids, + dropout=attn_dropout, + sliding_window=getattr(self.config, "sliding_window", None), + use_top_left_mask=self._flash_attn_uses_top_left_mask, + is_causal=self.is_causal, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi +# TODO @Arthur no longer copied from LLama after static cache + +import habana_frameworks.torch.hpu as ht +from habana_frameworks.torch.hpex.kernels import FusedSDPA + + +class PhiOSdpaAttention(PhiOAttention): + """PhiO attention module using torch.nn.functional.scaled_dot_product_attention. + + This module inherits from + `PhiOAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + + # Adapted from PhiOAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "PhiOModel is using PhiOSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
+ ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + qkv = self.qkv_proj(hidden_states) + query_pos = self.num_heads * self.head_dim + query_states = qkv[..., :query_pos] + key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] + value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: /~https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "hpu" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + scale = 1.0 / math.sqrt(self.head_dim) + + # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment + # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. + # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
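+        # Descriptive note (added for readers of this patch): the block below is the HPU-specific
+        # dispatch. It calls Habana's fused scaled-dot-product-attention kernel (FusedSDPA, imported
+        # above from habana_frameworks.torch.hpex.kernels) inside an ht.sdp_kernel context with
+        # recompute disabled, instead of calling torch.nn.functional.scaled_dot_product_attention
+        # directly; the equivalent torch SDPA call is kept commented out underneath as a reference.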
+ + with ht.sdp_kernel(enable_recompute=False): + attn_output = FusedSDPA.apply( + query_states, + key_states, + value_states, + causal_mask, + self.attention_dropout if self.training else 0.0, + is_causal, + scale, + ) + # attn_output = torch.nn.functional.scaled_dot_product_attention( + # query_states, + # key_states, + # value_states, + # attn_mask=causal_mask, + # dropout_p=self.attention_dropout if self.training else 0.0, + # is_causal=is_causal, + # ) + # import pdb + + # pdb.set_trace() + torch.hpu.synchronize() + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +PHIO_ATTENTION_CLASSES = { + "eager": PhiOAttention, + "flash_attention_2": PhiOFlashAttention2, + "sdpa": PhiOSdpaAttention, +} + + +class PhiODecoderLayer(nn.Module): + def __init__(self, config: PhiOConfig, layer_idx: int): + super().__init__() + + self.config = config + self.self_attn = PHIO_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) + + self.mlp = PhiOMLP(config) + self.input_layernorm = PhiORMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) + self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) + self.post_attention_layernorm = PhiORMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): + input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + position_ids (`torch.LongTensor` of shape `({0})`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range + `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). 
+ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + attn_outputs, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = residual + self.resid_attn_dropout(attn_outputs) + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + self.resid_mlp_dropout(hidden_states) + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +PHIO_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`PhiOConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Phi-O model outputting raw hidden-states without any specific head on top.", + PHIO_START_DOCSTRING, +) +class PhiOPreTrainedModel(PreTrainedModel): + config_class = PhiOConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["PhiODecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + _version = "0.0.5" + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +PHIO_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. 
Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): + Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` + returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. + + Two formats are allowed: + - a [`~cache_utils.Cache`] instance, see our + [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); + - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of + shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy + cache format. + + The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the + legacy cache format will be returned. + + If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't + have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` + of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, + this tensor is not affected by padding. It is used to update the cache in the correct position and to infer + the complete sequence length. 
+""" + + +@add_start_docstrings( + "The bare Phi-O model outputting raw hidden-states without any specific head on top.", + PHIO_START_DOCSTRING, +) +class PhiOModel(PhiOPreTrainedModel): + """Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiODecoderLayer`] + + Args: + config: PhiOConfig + """ + + def __init__(self, config: PhiOConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + + self.embed_tokens_extend = None + if isinstance(config.embd_layer, dict): + embedding_config = {"embedding_cls": config.embd_layer["embedding_cls"], **config.embd_layer} + self.embed_tokens_extend = PhiOImageAudioEmbedding(config, **embedding_config) + + self.layers = nn.ModuleList( + [PhiODecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = PhiORMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHIO_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + input_image_embeds: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + input_audio_embeds: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + audio_projection_mode=None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # kept for BC (non `Cache` `past_key_values` inputs) + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + if past_key_values is None: + past_key_values = DynamicCache() + else: + past_key_values = DynamicCache.from_legacy_cache(past_key_values) + logger.warning_once( + "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " + "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " + "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" + ) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens_extend( + input_ids=input_ids, + input_embeds=inputs_embeds, + input_image_embeds=input_image_embeds, + input_audio_embeds=input_audio_embeds, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + wte=self.embed_tokens, + ) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + causal_mask = self._update_causal_mask( + attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + decoder_layer.__call__, + hidden_states, + causal_mask, + position_ids, + past_key_values, + output_attentions, + use_cache, + cache_position, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache = layer_outputs[2 if output_attentions else 1] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if return_legacy_cache: + next_cache = next_cache.to_legacy_cache() + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. 
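+        # Descriptive note (added for readers of this patch): the mask handling below follows the
+        # upstream Phi-3/Mistral implementation it was copied from. It returns None when the SDPA
+        # backend can rely on its `is_causal` flag, and otherwise builds a 4D additive causal mask via
+        # _prepare_4d_causal_attention_mask_with_cache_position, which also applies the optional
+        # sliding-window masking.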
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if ( + self.config._attn_implementation == "sdpa" + and not (using_static_cache or using_sliding_window_cache) + and not output_attentions + ): + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + sliding_window=self.config.sliding_window, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + min_dtype = torch.finfo(dtype).min + sequence_length = input_tensor.shape[1] + # SlidingWindowCache or StaticCache + if using_sliding_window_cache or using_static_cache: + target_length = past_key_values.get_max_cache_shape() + # DynamicCache or no cache + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + config=self.config, + past_key_values=past_key_values, + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. + # Details: /~https://github.com/pytorch/pytorch/issues/110213 + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phi3 + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + config: PhiOConfig, + past_key_values: Cache, + ): + """Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. 
+ batch_size (`torch.Tensor`): + Batch size. + config (`PhiOConfig`): + The model's configuration class + past_key_values (`Cache`): + The cache class that is being used currently to generate + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device) + diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + if config.sliding_window is not None: + # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also + # the check is needed to verify is current checkpoint was trained with sliding window or not + if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: + sliding_attend_mask = torch.arange(target_length, device=device) <= ( + cache_position.reshape(-1, 1) - config.sliding_window + ) + diagonal_attend_mask.bitwise_or_(sliding_attend_mask) + causal_mask *= diagonal_attend_mask + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + if attention_mask.shape[-1] > target_length: + attention_mask = attention_mask[:, :target_length] + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + return causal_mask + + +class PhiOForCausalLM(PhiOPreTrainedModel, GenerationMixin): + _tied_weights_keys = ["lm_head.weight"] + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi + def __init__(self, config): + super().__init__(config) + self.model = PhiOModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + # LoRA related settings + assert getattr(config, "vision_lora", None) is not None + from peft import LoraConfig, get_peft_model + + vision_lora_config = LoraConfig( + r=config.vision_lora["r"], + lora_alpha=config.vision_lora["lora_alpha"], + target_modules=config.vision_lora["layer"], + lora_dropout=config.vision_lora["dp"], + task_type="CAUSAL_LM", + ) + peft_model = get_peft_model(self.model, vision_lora_config, adapter_name="vision") + self.config.vision_lora["r"] = config.vision_lora["r"] + self.config.vision_lora["lora_alpha"] = config.vision_lora["lora_alpha"] + self.config.vision_lora["layer"] = config.vision_lora["layer"] + self.config.vision_lora["dp"] = config.vision_lora["dp"] + + assert getattr(config, "speech_lora", None) is not None + speech_lora_config = LoraConfig( + r=config.speech_lora["r"], + lora_alpha=config.speech_lora["lora_alpha"], + target_modules=config.speech_lora["layer"], + lora_dropout=config.speech_lora["dp"], + task_type="CAUSAL_LM", + ) + peft_model.base_model.active_adapter.append("speech") + peft_model.add_adapter("speech", speech_lora_config) + self.config.speech_lora["r"] = config.speech_lora["r"] + self.config.speech_lora["lora_alpha"] = config.speech_lora["lora_alpha"] + 
self.config.speech_lora["layer"] = config.speech_lora["layer"] + self.config.speech_lora["dp"] = config.speech_lora["dp"] + + def set_lora_adapter(self, adapter_name) -> None: + from peft.tuners.lora.layer import LoraLayer + + for module in self.modules(): + if isinstance(module, LoraLayer): + if module.merged: + warnings.warn("Adapter cannot be set when the model is merged. Unmerging the model first.") + module.unmerge() + module.set_adapter(adapter_name) + module._disable_adapters = False + + def unset_lora_adapter(self) -> None: + # Ref: peft/tuners/tuners_utils.py - enable_adapters() + # Ref: peft/tuners/lora/layer.py + from peft.tuners.lora.layer import LoraLayer + + for module in self.modules(): + if isinstance(module, LoraLayer): + # disable grads on all adapter layers + # TODO weijian: may use enable_adapters() instead + for layer_name in module.adapter_layer_names: + layer = getattr(module, layer_name) + layer.requires_grad_(False) + module._disable_adapters = True + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings + def get_input_embeddings(self): + return self.model.embed_tokens + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings + def get_output_embeddings(self): + return self.lm_head + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder + def set_decoder(self, decoder): + self.model = decoder + + # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder + def get_decoder(self): + return self.model + + # Ignore copy + @add_start_docstrings_to_model_forward(PHIO_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + input_image_embeds: Optional[torch.FloatTensor] = None, + image_sizes: Optional[torch.LongTensor] = None, + image_attention_mask=None, + input_audio_embeds: Optional[torch.FloatTensor] = None, + audio_embed_sizes=None, + audio_attention_mask=None, + input_mode=None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + trim_logits: Optional[bool] = False, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. 
If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, PhiOForCausalLM + + >>> model = PhiOForCausalLM.from_pretrained("TBA") + >>> tokenizer = AutoTokenizer.from_pretrained("TBA") + + >>> prompt = "This is an example script ." + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' + ```""" + if ( + use_cache + and self.config.rope_scaling + and cache_position is not None + and cache_position[0] == self.config.original_max_position_embeddings + ): + logger.warning( + f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed." + ) + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if isinstance(input_mode, torch.Tensor): + assert len(input_mode) == 1 + input_mode = input_mode[0].item() + input_mode = InputMode(input_mode) + + if input_mode in [InputMode.VISION_SPEECH, InputMode.VISION]: + self.set_lora_adapter("vision") + audio_projection_mode = "vision" + elif input_mode == InputMode.SPEECH: + self.set_lora_adapter("speech") + audio_projection_mode = "speech" + elif input_mode == InputMode.LANGUAGE: + self.unset_lora_adapter() + audio_projection_mode = "speech" + else: + raise ValueError(f"Invalid input_mode: {input_mode}") + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + input_image_embeds=input_image_embeds, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + input_audio_embeds=input_audio_embeds, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + audio_projection_mode=audio_projection_mode, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + num_logits_to_keep = 0 + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function(logits, labels, self.vocab_size) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + 
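# --- Illustrative sketch (reviewer note, not part of the patch) ---------------------
# How forward() above maps the multimodal input_mode onto a LoRA adapter and an audio
# projection mode. The enum values mirror InputMode from processing_phio.py; the helper
# itself is a toy stand-in for the set_lora_adapter()/unset_lora_adapter() calls.
from enum import Enum

class ToyInputMode(Enum):
    LANGUAGE = 0
    VISION = 1
    SPEECH = 2
    VISION_SPEECH = 3

def pick_adapter(mode: ToyInputMode):
    if mode in (ToyInputMode.VISION, ToyInputMode.VISION_SPEECH):
        return "vision", "vision"      # (active LoRA adapter, audio_projection_mode)
    if mode is ToyInputMode.SPEECH:
        return "speech", "speech"
    if mode is ToyInputMode.LANGUAGE:
        return None, "speech"          # pure text: every adapter is disabled
    raise ValueError(f"Invalid input_mode: {mode}")

print(pick_adapter(ToyInputMode(3)))   # ('vision', 'vision')
# -------------------------------------------------------------------------------------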
attention_mask=None, + inputs_embeds=None, + input_image_embeds=None, + image_sizes=None, + image_attention_mask=None, + input_audio_embeds=None, + audio_embed_sizes=None, + audio_attention_mask=None, + input_mode=None, + cache_position=None, + position_ids=None, + use_cache=True, + num_logits_to_keep=None, + **kwargs, + ): + # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the + # process + + # When the first time input length reached long and short factor switching point, enforce re-compute cache + # It will cause downside of slower at this single token position, however, better than current failure. + if ( + past_key_values + and self.config.rope_scaling + and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 + ): + past_length = cache_position[0] + if past_length <= self.config.original_max_position_embeddings: + past_key_values = None + + model_inputs = super().prepare_inputs_for_generation( + input_ids=input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + input_image_embeds=input_image_embeds, + image_sizes=image_sizes, + image_attention_mask=image_attention_mask, + input_audio_embeds=input_audio_embeds, + audio_embed_sizes=audio_embed_sizes, + audio_attention_mask=audio_attention_mask, + input_mode=input_mode, + cache_position=cache_position, + position_ids=position_ids, + use_cache=use_cache, + num_logits_to_keep=num_logits_to_keep, + **kwargs, + ) + return model_inputs + + +@add_start_docstrings( + """ + The [`PhiOModel`] with a sequence classification head on top (linear layer). + + [`PhiOForSequenceClassification`] uses the last token in order to do the classification, as other causal models + (e.g. GPT-2) do. + + Since it does classification on the last token, it requires to know the position of the last token. If a + `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If + no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the + padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in + each row of the batch). 
+ """, + PHIO_START_DOCSTRING, +) +# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi, LLAMA->PHI, self.transformer->self.model, transformer_outputs->model_outputs +class PhiOForSequenceClassification(PhiOPreTrainedModel): + def __init__(self, config): + super().__init__(config) + self.num_labels = config.num_labels + self.model = PhiOModel(config) + self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + @add_start_docstrings_to_model_forward(PHIO_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + r"""Labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = model_outputs[0] + logits = self.score(hidden_states) + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths = -1 + else: + if input_ids is not None: + # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility + sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 + sequence_lengths = sequence_lengths % input_ids.shape[-1] + sequence_lengths = sequence_lengths.to(logits.device) + else: + sequence_lengths = -1 + + pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] + + loss = None + if labels is not None: + loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) + + if not return_dict: + output = (pooled_logits,) + model_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=model_outputs.past_key_values, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) + + +@add_start_docstrings( + """ + [`PhiOModel`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. 
for + Named-Entity-Recognition (NER) tasks. + """, + PHIO_START_DOCSTRING, +) +# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi,MPT->PHI,self.transformer->self.model,transformer_outputs->model_outputs +class PhiOForTokenClassification(PhiOPreTrainedModel): + def __init__(self, config: PhiOConfig): + super().__init__(config) + self.num_labels = config.num_labels + + self.model = PhiOModel(config) + if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: + classifier_dropout = config.classifier_dropout + elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: + classifier_dropout = config.hidden_dropout + else: + classifier_dropout = 0.1 + self.dropout = nn.Dropout(classifier_dropout) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(PHIO_INPUTS_DOCSTRING) + @add_code_sample_docstrings( + checkpoint=_CHECKPOINT_FOR_DOC, + output_type=TokenClassifierOutput, + config_class=_CONFIG_FOR_DOC, + ) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + attention_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + labels: Optional[torch.Tensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **deprecated_arguments, + ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: + r"""Labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): + + Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., + config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If + `config.num_labels > 1` a classification loss is computed (Cross-Entropy). + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + model_outputs = self.model( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = model_outputs[0] + hidden_states = self.dropout(hidden_states) + logits = self.classifier(hidden_states) + + loss = None + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + batch_size, seq_length = labels.shape + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length)) + + if not return_dict: + output = (logits,) + model_outputs[2:] + return ((loss,) + output) if loss is not None else output + + return TokenClassifierOutput( + loss=loss, + logits=logits, + hidden_states=model_outputs.hidden_states, + attentions=model_outputs.attentions, + ) diff --git a/comps/llms/src/text-generation/patch/enhance-multimodal-patch/processing_phio.py b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/processing_phio.py new file mode 100644 index 000000000..8777c1669 --- /dev/null +++ b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/processing_phio.py @@ -0,0 +1,727 @@ +# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. 
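# --- Illustrative sketch (reviewer note, not part of the patch) ---------------------
# The last-non-padding-token pooling used by PhiOForSequenceClassification above; the
# modulo keeps the index computation ONNX-friendly. pad_token_id and shapes are toy values.
import torch

pad_token_id = 0
input_ids = torch.tensor([[5, 7, 9, 0, 0],
                          [3, 4, 0, 0, 0]])
logits = torch.randn(2, 5, 3)                                # (batch, seq_len, num_labels)

sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
sequence_lengths = sequence_lengths % input_ids.shape[-1]
pooled_logits = logits[torch.arange(input_ids.shape[0]), sequence_lengths]
print(sequence_lengths.tolist(), pooled_logits.shape)        # [2, 1] torch.Size([2, 3])
# -------------------------------------------------------------------------------------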
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Processor class for PhiO.""" +import math +import re +from enum import Enum +from typing import List, Optional, Tuple, Union + +import numpy as np +import scipy +import torch +import torchvision +from torch.nn.utils.rnn import pad_sequence +from transformers import AutoFeatureExtractor, AutoImageProcessor +from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_utils import ImageInput, make_list_of_images, valid_images +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PaddingStrategy, TextInput, TruncationStrategy +from transformers.utils import TensorType, logging + +logger = logging.get_logger(__name__) + +# Special tokens +_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN = r"<\|image_\d+\|>" # For backward compatibility +_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN = r"<\|audio_\d+\|>" # For backward compatibility +_IMAGE_SPECIAL_TOKEN = "<|endoftext10|>" +_AUDIO_SPECIAL_TOKEN = "<|endoftext11|>" +_IMAGE_SPECIAL_TOKEN_ID = 200010 # '<|endoftext10|>', or we can better name it (in `tokenizer_config.json`) +_AUDIO_SPECIAL_TOKEN_ID = 200011 # '<|endoftext11|>' + + +class InputMode(Enum): + LANGUAGE = 0 + VISION = 1 + SPEECH = 2 + VISION_SPEECH = 3 + + +class PhiOImageProcessor(BaseImageProcessor): + r"""Constructs a PhiO image processor.""" + + model_input_names = ["input_image_embeds", "image_sizes", "image_attention_mask"] + + def __init__( + self, + dynamic_hd, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.dynamic_hd = dynamic_hd + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=384, mask_size=27, use_thumbnail=True): + orig_width, orig_height = image.size + + w_crop_num = math.ceil(orig_width / float(image_size)) + h_crop_num = math.ceil(orig_height / float(image_size)) + if w_crop_num * h_crop_num > max_num: + + aspect_ratio = orig_width / orig_height + + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, 
image_size + ) + + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + print(target_aspect_ratio) + else: + target_width = image_size * w_crop_num + target_height = image_size * h_crop_num + target_aspect_ratio = (w_crop_num, h_crop_num) + + # Calculate the ratio + ratio_width = target_width / orig_width + ratio_height = target_height / orig_height + if ratio_width < ratio_height: + new_size = (target_width, int(orig_height * ratio_width)) + padding_width = 0 + padding_height = target_height - int(orig_height * ratio_width) + else: + new_size = (int(orig_width * ratio_height), target_height) + padding_width = target_width - int(orig_width * ratio_height) + padding_height = 0 + + attention_mask = torch.ones((int(mask_size * target_aspect_ratio[1]), int(mask_size * target_aspect_ratio[0]))) + if padding_width >= 14: + attention_mask[:, -math.floor(padding_width / 14) :] = 0 + if padding_height >= 14: + attention_mask[-math.floor(padding_height / 14) :, :] = 0 + assert attention_mask.sum() > 0 + + if min(new_size[1], target_height) < 10 or min(new_size[0], target_width) < 10: + raise ValueError(f"the aspect ratio is very extreme {new_size}") + + image = torchvision.transforms.functional.resize( + image, + [new_size[1], new_size[0]], + ) + + resized_img = torchvision.transforms.functional.pad( + image, [0, 0, padding_width, padding_height], fill=[255, 255, 255] + ) + + return resized_img, attention_mask + + def pad_to_max_num_crops(self, images, max_crops=5): + """ + images: B x 3 x H x W, B<=max_crops + """ + B, _, H, W = images.shape + if B < max_crops: + pad = torch.zeros(max_crops - B, 3, H, W, dtype=images.dtype, device=images.device) + images = torch.cat([images, pad], dim=0) + return images + + def pad_mask_to_max_num_crops(self, masks, max_crops=5): + B, H, W = masks.shape + if B < max_crops: + pad = torch.ones(max_crops - B, H, W, dtype=masks.dtype, device=masks.device) + masks = torch.cat([masks, pad], dim=0) + return masks + + def preprocess( + self, + images: ImageInput, + return_tensors: Optional[Union[str, TensorType]] = None, + ): + """ + Args: + images (`ImageInput`): + Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If + passing in images with pixel values between 0 and 1, set `do_rescale=False`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be one of: + - Unset: Return a list of `np.ndarray`. + - `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`. + - `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`. + - `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`. + - `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`. + """ + images = make_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + # Basic settings. 
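# --- Illustrative sketch (reviewer note, not part of the patch) ---------------------
# The crop-grid arithmetic of dynamic_preprocess above for the simple case where the
# ceil-based grid already fits within max_num tiles. The 1280x720 input, image_size=448
# and dynamic_hd=12 are assumed toy values.
import math

orig_width, orig_height = 1280, 720
image_size, max_num = 448, 12                       # dyhd_base_resolution and dynamic_hd
w_crop_num = math.ceil(orig_width / image_size)     # 3
h_crop_num = math.ceil(orig_height / image_size)    # 2
assert w_crop_num * h_crop_num <= max_num           # 6 tiles -> no aspect-ratio search needed
target_width = image_size * w_crop_num              # 1344; the image is resized and padded to this
target_height = image_size * h_crop_num             # 896
print((w_crop_num, h_crop_num), (target_width, target_height))
# -------------------------------------------------------------------------------------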
+ img_processor = torchvision.transforms.Compose( + [ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + ] + ) + dyhd_base_resolution = 448 + + # Dynamic HD + base_resolution = dyhd_base_resolution + images = [image.convert("RGB") for image in images] + # cover 384 and 448 resolution + mask_resolution = base_resolution // 14 + elems, image_attention_masks = [], [] + for im in images: + elem, attention_mask = self.dynamic_preprocess( + im, max_num=self.dynamic_hd, image_size=base_resolution, mask_size=mask_resolution + ) + elems.append(elem) + image_attention_masks.append(attention_mask) + hd_images = [img_processor(im) for im in elems] + global_image = [ + torch.nn.functional.interpolate( + im.unsqueeze(0).float(), + size=(base_resolution, base_resolution), + mode="bicubic", + ).to(im.dtype) + for im in hd_images + ] + shapes = [[im.size(1), im.size(2)] for im in hd_images] + mask_shapes = [[mask.size(0), mask.size(1)] for mask in image_attention_masks] + global_attention_mask = [torch.ones((1, mask_resolution, mask_resolution)) for _ in hd_images] + hd_images_reshape = [ + im.reshape(1, 3, h // base_resolution, base_resolution, w // base_resolution, base_resolution) + .permute(0, 2, 4, 1, 3, 5) + .reshape(-1, 3, base_resolution, base_resolution) + .contiguous() + for im, (h, w) in zip(hd_images, shapes) + ] + attention_masks_reshape = [ + mask.reshape(1, h // mask_resolution, mask_resolution, w // mask_resolution, mask_resolution) + .permute(0, 1, 3, 2, 4) + .reshape(-1, mask_resolution, mask_resolution) + .contiguous() + for mask, (h, w) in zip(image_attention_masks, mask_shapes) + ] + downsample_attention_masks = [ + mask[:, 0::2, 0::2] + .reshape( + 1, + h // mask_resolution, + w // mask_resolution, + mask_resolution // 2 + mask_resolution % 2, + mask_resolution // 2 + mask_resolution % 2, + ) + .permute(0, 1, 3, 2, 4) + for mask, (h, w) in zip(attention_masks_reshape, mask_shapes) + ] + downsample_attention_masks = [ + mask.reshape(mask.size(1) * mask.size(2), mask.size(3) * mask.size(4)) + for mask in downsample_attention_masks + ] + num_img_tokens = [ + 256 + 1 + int(mask.sum().item()) + int(mask[:, 0].sum().item()) + 16 for mask in downsample_attention_masks + ] + + hd_images_reshape = [ + torch.cat([_global_image] + [_im], dim=0) for _global_image, _im in zip(global_image, hd_images_reshape) + ] + hd_masks_reshape = [ + torch.cat([_global_mask] + [_mask], dim=0) + for _global_mask, _mask in zip(global_attention_mask, attention_masks_reshape) + ] + max_crops = max([img.size(0) for img in hd_images_reshape]) + image_transformed = [self.pad_to_max_num_crops(im, max_crops) for im in hd_images_reshape] + image_transformed = torch.stack(image_transformed, dim=0) + mask_transformed = [self.pad_mask_to_max_num_crops(mask, max_crops) for mask in hd_masks_reshape] + mask_transformed = torch.stack(mask_transformed, dim=0) + + returned_input_image_embeds = image_transformed + returned_image_sizes = torch.tensor(shapes, dtype=torch.long) + returned_image_attention_mask = mask_transformed + returned_num_img_tokens = num_img_tokens + + data = { + "input_image_embeds": returned_input_image_embeds, + "image_sizes": returned_image_sizes, + "image_attention_mask": returned_image_attention_mask, + "num_img_tokens": returned_num_img_tokens, + } + + return BatchFeature(data=data, tensor_type=return_tensors) + + +AudioInput = Tuple[Union[np.ndarray, torch.Tensor], int] +AudioInputs = List[AudioInput] + + +def speechlib_mel(sample_rate, n_fft, 
n_mels, fmin=None, fmax=None): + """Create a Mel filter-bank the same as SpeechLib FbankFC. + + Args: + sample_rate (int): Sample rate in Hz. number > 0 [scalar] + n_fft (int): FFT size. int > 0 [scalar] + n_mel (int): Mel filter size. int > 0 [scalar] + fmin (float): lowest frequency (in Hz). If None use 0.0. + float >= 0 [scalar] + fmax: highest frequency (in Hz). If None use sample_rate / 2. + float >= 0 [scalar] + + Returns + out (numpy.ndarray): Mel transform matrix + [shape=(n_mels, 1 + n_fft/2)] + """ + + bank_width = int(n_fft // 2 + 1) + if fmax is None: + fmax = sample_rate / 2 + if fmin is None: + fmin = 0 + assert fmin >= 0, "fmin cannot be negative" + assert fmin < fmax <= sample_rate / 2, "fmax must be between (fmin, samplerate / 2]" + + def mel(f): + return 1127.0 * np.log(1.0 + f / 700.0) + + def bin2mel(fft_bin): + return 1127.0 * np.log(1.0 + fft_bin * sample_rate / (n_fft * 700.0)) + + def f2bin(f): + return int((f * n_fft / sample_rate) + 0.5) + + # Spec 1: FFT bin range [f2bin(fmin) + 1, f2bin(fmax) - 1] + klo = f2bin(fmin) + 1 + khi = f2bin(fmax) + + khi = max(khi, klo) + + # Spec 2: SpeechLib uses trianges in Mel space + mlo = mel(fmin) + mhi = mel(fmax) + m_centers = np.linspace(mlo, mhi, n_mels + 2) + ms = (mhi - mlo) / (n_mels + 1) + + matrix = np.zeros((n_mels, bank_width), dtype=np.float32) + for m in range(0, n_mels): + left = m_centers[m] + center = m_centers[m + 1] + right = m_centers[m + 2] + for fft_bin in range(klo, khi): + mbin = bin2mel(fft_bin) + if left < mbin < right: + matrix[m, fft_bin] = 1.0 - abs(center - mbin) / ms + + return matrix + + +class PhiOAudioFeatureExtractor(SequenceFeatureExtractor): + model_input_names = ["input_audio_embeds", "audio_embed_sizes"] + + def __init__(self, audio_compression_rate, audio_downsample_rate, audio_feat_stride, **kwargs): + feature_size = 80 + sampling_rate = 16000 + padding_value = 0.0 + super().__init__(feature_size, sampling_rate, padding_value, **kwargs) + + self.compression_rate = audio_compression_rate + self.qformer_compression_rate = audio_downsample_rate + self.feat_stride = audio_feat_stride + + self._eightk_method = "fillzero" + self._mel = speechlib_mel(16000, 512, 80, fmin=None, fmax=7690).T + + self._hamming400 = np.hamming(400) # for 16k audio + self._hamming200 = np.hamming(200) # for 8k audio + + def duration_to_frames(self, duration): + """Duration in s, estimated frames.""" + frame_rate = 10 + + num_frames = duration * 1000 // frame_rate + return num_frames + + def __call__( + self, + audios: List[AudioInput], + return_tensors: Optional[Union[str, TensorType]] = None, + ): + # Ref: /~https://github.com/huggingface/transformers/blob/v4.47.0/src/transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py#L161 + returned_input_audio_embeds = [] + returned_audio_embed_sizes = [] + + for audio_data, sample_rate in audios: + audio_embeds = self._extract_features(audio_data, sample_rate) + audio_frames = len(audio_embeds) * self.feat_stride + audio_embed_size = self._compute_audio_embed_size(audio_frames) + + returned_input_audio_embeds.append(torch.tensor(audio_embeds)) + returned_audio_embed_sizes.append(torch.tensor(audio_embed_size).long()) + + returned_input_audio_embeds = pad_sequence(returned_input_audio_embeds, batch_first=True) + returned_audio_embed_sizes = torch.stack(returned_audio_embed_sizes, dim=0) + + data = { + "input_audio_embeds": returned_input_audio_embeds, + "audio_embed_sizes": returned_audio_embed_sizes, + } + + return 
BatchFeature(data=data, tensor_type=return_tensors) + + def _extract_spectrogram(self, wav, fs): + """Extract spectrogram features from waveform. + + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. + """ + if wav.ndim > 1: + wav = np.squeeze(wav) + + # by default, we extract the mean if stereo + if len(wav.shape) == 2: + wav = wav.mean(1) + + # Resample to 16000 or 8000 if needed + if fs > 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 16000) + fs = 16000 + elif 8000 < fs < 16000: + wav = scipy.signal.resample_poly(wav, 1, fs // 8000) + fs = 8000 + elif fs < 8000: + raise RuntimeError(f"Unsupported sample rate {fs}") + + if fs == 8000: + if self._eightk_method == "resample": + # Input audio is 8 kHz. Convert to 16 kHz before feature + # extraction + wav = scipy.signal.resample_poly(wav, 2, 1) + fs = 16000 + # Do nothing here for fillzero method + elif fs != 16000: + # Input audio is not a supported sample rate. + raise RuntimeError(f"Input data using an unsupported sample rate: {fs}") + + preemphasis = 0.97 + + if fs == 8000: + n_fft = 256 + win_length = 200 + hop_length = 80 + fft_window = self._hamming200 + elif fs == 16000: + n_fft = 512 + win_length = 400 + hop_length = 160 + fft_window = self._hamming400 + + # Spec 1: SpeechLib cut remaining sample insufficient for a hop + n_batch = (wav.shape[0] - win_length) // hop_length + 1 + # Here we don't use stride_tricks since the input array may not satisfy + # memory layout requirement and we need writeable output + # Here we only use list of views before copy to destination + # so it is more efficient than broadcasting + y_frames = np.array( + [wav[_stride : _stride + win_length] for _stride in range(0, hop_length * n_batch, hop_length)], + dtype=np.float32, + ) + + # Spec 2: SpeechLib applies preemphasis within each batch + y_frames_prev = np.roll(y_frames, 1, axis=1) + y_frames_prev[:, 0] = y_frames_prev[:, 1] + y_frames = (y_frames - preemphasis * y_frames_prev) * 32768 + + S = np.fft.rfft(fft_window * y_frames, n=n_fft, axis=1).astype(np.complex64) + + if fs == 8000: + # Need to pad the output to look like 16 kHz data but with zeros in + # the 4 to 8 kHz bins. + frames, bins = S.shape + padarray = np.zeros((frames, bins)) + S = np.concatenate((S[:, 0:-1], padarray), axis=1) # Nyquist bin gets set to zero + + spec = np.abs(S).astype(np.float32) + return spec + + def _extract_features(self, wav, fs): + """Extract log filterbank features from waveform. + + Args: + wav (1D array): waveform of the input + fs (int): sampling rate of the waveform, 16000 or 8000. + If fs=8000, the waveform will be resampled to 16000Hz. + Output: + log_fbank (2D array): a TxD matrix of log Mel filterbank features. + D=80, and T is the number of frames. 
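# --- Illustrative sketch (reviewer note, not part of the patch) ---------------------
# Back-of-the-envelope arithmetic for the audio front end above: one second of 16 kHz
# audio yields (wav_len - win_length) // hop_length + 1 frames of 80-dim log-Mel, and
# _compute_audio_embed_size then applies two ceil-divisions. The compression rates and
# feat_stride below are assumed toy values, not read from any config.
import math

wav_len, win_length, hop_length = 16000, 400, 160
num_frames = (wav_len - win_length) // hop_length + 1            # 98
feat_stride, compression_rate, qformer_compression_rate = 1, 8, 1
audio_frames = num_frames * feat_stride
embed_size = math.ceil(math.ceil(audio_frames / compression_rate) / qformer_compression_rate)
print(num_frames, embed_size)                                    # 98 13
# -------------------------------------------------------------------------------------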
+ """ + spec = self._extract_spectrogram(wav, fs) + spec_power = spec**2 + + fbank_power = np.clip(spec_power.dot(self._mel), 1.0, None) + log_fbank = np.log(fbank_power).astype(np.float32) + + return log_fbank + + def _compute_audio_embed_size(self, audio_frames): + integer = audio_frames // self.compression_rate + remainder = audio_frames % self.compression_rate + + result = integer if remainder == 0 else integer + 1 + + integer = result // self.qformer_compression_rate + remainder = result % self.qformer_compression_rate + result = integer if remainder == 0 else integer + 1 # qformer compression + + return result + + +class PhiOProcessor(ProcessorMixin): + r"""Constructs a PhiO processor which raps an image processor, a audio processor, and a GPT tokenizer into a single processor. + + [`PhiOProcessor`] offers all the functionalities of [`PhiOImageProcessor`] and [`GPT2Tokenizer`]. See the + [`~PhiOProcessor.__call__`] and [`~PhiOProcessor.decode`] for more information. + + Args: + image_processor ([`PhiOImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`GPT2Tokenizer`], *optional*): + The tokenizer is a required input. + """ + + attributes = ["image_processor", "audio_processor", "tokenizer"] + tokenizer_class = "GPT2TokenizerFast" + image_processor_class = "AutoImageProcessor" # PhiOImageProcessor will be registered later + audio_processor_class = "AutoFeatureExtractor" # PhiOAudioFeatureExtractor will be registered later + + def __init__(self, image_processor, audio_processor, tokenizer): + self.image_processor = image_processor + self.audio_processor = audio_processor + self.tokenizer = tokenizer + + def __call__( + self, + text: Union[TextInput, List[TextInput]], + images: Optional[ImageInput] = None, + audios: Optional[AudioInputs] = None, + padding: Union[bool, str, PaddingStrategy] = False, + truncation: Optional[Union[bool, str, TruncationStrategy]] = None, + max_length=None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + ) -> BatchFeature: + """Main method to prepare for the model one or several sequences(s) and image(s). This method forards the `text` + and `kwargs` arguments to GPT2Tokenizer's [`~GPT2Tokenizer.__call__`] if `text` is not `None` to encode + the text. To prepare the image(s), this method forwards the `images` and `kwrags` arguments to + PhiOImageProcessor's [`~PhiOImageProcessor.__call__`] if `images` is not `None`. Please refer to the doctsring + of the above two methods for more information. + + Args: + text (`str`, `List[str]`, `List[List[str]]`): + The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings + (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set + `is_split_into_words=True` (to lift the ambiguity with a batch of sequences). + images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): + The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch + tensor. Both channels-first and channels-last formats are supported. + padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`): + Select a strategy to pad the returned sequences (according to the model's padding side and padding + index) among: + - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single + sequence if provided). 
+ - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum + acceptable input length for the model if that argument is not provided. + - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different + lengths). + max_length (`int`, *optional*): + Maximum length of the returned list and optionally padding length (see above). + truncation (`bool`, *optional*): + Activates truncation to cut input sequences longer than `max_length` to `max_length`. + return_tensors (`str` or [`~utils.TensorType`], *optional*): + If set, will return tensors of a particular framework. Acceptable values are: + + - `'tf'`: Return TensorFlow `tf.constant` objects. + - `'pt'`: Return PyTorch `torch.Tensor` objects. + - `'np'`: Return NumPy `np.ndarray` objects. + - `'jax'`: Return JAX `jnp.ndarray` objects. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **input_ids** -- List of token ids to be fed to a model. + - **input_image_embeds** -- Pixel values to be fed to a model. + - **image_sizes** -- List of tuples specifying the size of each image in `input_image_embeds`. + - **image_attention_mask** -- List of attention masks for each image in `input_image_embeds`. + - **input_audio_embeds** -- Audio embeddings to be fed to a model. + - **audio_embed_sizes** -- List of integers specifying the size of each audio in `input_audio_embeds`. + - **attention_mask** -- List of indices specifying which tokens should be attended to by the model. + """ + image_inputs = self.image_processor(images, return_tensors=return_tensors) if images is not None else {} + audio_inputs = self.audio_processor(audios, return_tensors=return_tensors) if audios is not None else {} + inputs = self._convert_images_audios_text_to_inputs( + image_inputs, + audio_inputs, + text, + padding=padding, + truncation=truncation, + max_length=max_length, + return_tensors=return_tensors, + ) + + # idenfity the input mode + if len(image_inputs) > 0 and len(audio_inputs) > 0: + input_mode = InputMode.VISION_SPEECH + elif len(image_inputs) > 0: + input_mode = InputMode.VISION + elif len(audio_inputs) > 0: + input_mode = InputMode.SPEECH + else: + input_mode = InputMode.LANGUAGE + inputs["input_mode"] = torch.tensor([input_mode.value], dtype=torch.long) + + return inputs + + @property + def special_image_token_id(self): + return self.tokenizer.convert_tokens_to_ids(self.special_image_token) + + def get_special_image_token_id(self): + return self.tokenizer.convert_tokens_to_ids(self.special_image_token) + + def _convert_images_audios_text_to_inputs( + self, images, audios, text, padding=False, truncation=None, max_length=None, return_tensors=None + ): + # prepare image id to image input ids + if len(images) > 0: + input_image_embeds = images["input_image_embeds"] + image_sizes = images["image_sizes"] + image_attention_mask = images["image_attention_mask"] + num_img_tokens = images["num_img_tokens"] + else: + input_image_embeds = torch.tensor([]) + image_sizes = torch.tensor([]) + image_attention_mask = torch.tensor([]) + num_img_tokens = [] + + # prepare audio id to audio input ids + if len(audios) > 0: + input_audio_embeds = audios["input_audio_embeds"] + audio_embed_sizes = audios["audio_embed_sizes"] + else: + input_audio_embeds = torch.tensor([]) + audio_embed_sizes = torch.tensor([]) + + # Replace certain special tokens for compatibility + # Ref: https://stackoverflow.com/questions/11475885/python-replace-regex + processed_text = 
re.sub(_COMPATIBLE_IMAGE_SPECIAL_TOKEN_PATTERN, _IMAGE_SPECIAL_TOKEN, text) + processed_text = re.sub(_COMPATIBLE_AUDIO_SPECIAL_TOKEN_PATTERN, _AUDIO_SPECIAL_TOKEN, processed_text) + + input_ids = self.tokenizer(processed_text).input_ids + i = 0 + img_cnt, audio_cnt = 0, 0 # only needed for later assertion + image_token_count_iter = iter(num_img_tokens) + audio_embed_size_iter = iter(audio_embed_sizes.tolist()) + while i < len(input_ids): + token_id = input_ids[i] + if token_id == _AUDIO_SPECIAL_TOKEN_ID: + token_count = next(audio_embed_size_iter) + audio_cnt += 1 + elif token_id == _IMAGE_SPECIAL_TOKEN_ID: + token_count = next(image_token_count_iter) + img_cnt += 1 + else: + i += 1 + continue + tokens = [token_id] * token_count + input_ids = input_ids[:i] + tokens + input_ids[i + 1 :] + i += token_count + input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0) + + # If the below assertion fails, it might be that input pure-text + # messages contain image/audio special tokens literally + # (<|endoftext10|>, <|endoftext11|>). + assert img_cnt == len(num_img_tokens), ( + f"Number of image tokens in prompt_token_ids ({img_cnt}) " + f"does not match number of images ({len(num_img_tokens)})" + ) + assert audio_cnt == len(audio_embed_sizes), ( + f"Number of audio tokens in prompt_token_ids ({audio_cnt}) " + f"does not match number of audios ({len(audio_embed_sizes)})" + ) + + # prepare attention mask + attention_mask = torch.ones_like(input_ids, dtype=torch.long) + + # prepare batch feature + data = { + "input_ids": input_ids, + "input_image_embeds": input_image_embeds, + "image_sizes": image_sizes, + "image_attention_mask": image_attention_mask, + "input_audio_embeds": input_audio_embeds, + "audio_embed_sizes": audio_embed_sizes, + "attention_mask": attention_mask, + } + + return BatchFeature(data=data) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.batch_decode`]. + + Please + refer to the docstring of this method for more information. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """This method forwards all its arguments to GPT2Tokenizer's [`~PreTrainedTokenizer.decode`]. + + Please refer to + the docstring of this method for more information. 
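# --- Illustrative sketch (reviewer note, not part of the patch) ---------------------
# The placeholder-expansion step of _convert_images_audios_text_to_inputs above: each
# image/audio special token id is repeated so the sequence reserves one slot per
# embedding. 200010 is the image special-token id defined above; the other ids and the
# token count are toy values.
IMAGE_TOKEN_ID = 200010
input_ids = [1, 2, IMAGE_TOKEN_ID, 3]
num_img_tokens = [4]                       # as computed by the image processor

expanded, counts = [], iter(num_img_tokens)
for tok in input_ids:
    expanded.extend([tok] * next(counts) if tok == IMAGE_TOKEN_ID else [tok])
print(expanded)                            # [1, 2, 200010, 200010, 200010, 200010, 3]
# -------------------------------------------------------------------------------------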
+ """ + return self.tokenizer.decode(*args, **kwargs) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + audio_processor_input_names = self.audio_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names + audio_processor_input_names)) + + +AutoImageProcessor.register("PhiOImageProcessor", PhiOImageProcessor) +AutoFeatureExtractor.register("PhiOAudioFeatureExtractor", PhiOAudioFeatureExtractor) diff --git a/comps/llms/src/text-generation/patch/enhance-multimodal-patch/speech_conformer_encoder.py b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/speech_conformer_encoder.py new file mode 100644 index 000000000..35c2c9cb9 --- /dev/null +++ b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/speech_conformer_encoder.py @@ -0,0 +1,2854 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +#!/usr/bin/env python3 + +# activation_checkpointing.py +"""Helper function for activation checkpointing.""" + +from functools import partial +from typing import Callable, Dict, Union + +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + CheckpointImpl, + checkpoint_wrapper, + offload_wrapper, +) + +# utils.py +"""cascade basic blocks""" + +import math +import random +from typing import Optional, Tuple + +import backoff +import numpy as np +import torch +import torch.nn.functional as F +from torch import Tensor, nn + +# conformer_encoder.py +"""ConformerEncoder Module""" + +import abc +from typing import List, Literal + +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointWrapper +from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel + + +# activation_checkpointing.py +def validate_checkpointing_config(activation_checkpointing): + """Validate activation checkpointing configuration.""" + if isinstance(activation_checkpointing, str): + assert activation_checkpointing in ( + "", + "checkpoint", + "offload", + ), "activation_checkpointing has to be a dict or a str in ('', 'checkpoint', 'offload')." + elif isinstance(activation_checkpointing, dict): + assert activation_checkpointing.get("module", "transformer") in ( + "transformer", + "attention", + ), "module in activation_checkpointing has to be in ('transformer', 'attention')." 
+ else: + raise ValueError("activation_checkpointing has to be a str or dict.") + + +def embedding_checkpoint_wrapper( + activation_checkpointing: Union[str, Dict], +) -> Callable: + """Return encoder embedding activation checkpoint wrapper.""" + validate_checkpointing_config(activation_checkpointing) + + if isinstance(activation_checkpointing, str): + if activation_checkpointing: + if activation_checkpointing == "offload": + return offload_wrapper + return partial(checkpoint_wrapper) + return lambda x: x + + if isinstance(activation_checkpointing, dict): + enabled = activation_checkpointing.get("embed", False) + if enabled: + offloading = activation_checkpointing.get("offload", False) + if offloading: + return offload_wrapper + impl = ( + CheckpointImpl.REENTRANT + if activation_checkpointing.get("reentrant", False) + else CheckpointImpl.NO_REENTRANT + ) + return partial(checkpoint_wrapper, checkpoint_impl=impl) + return lambda x: x + raise ValueError("Invalid activation_checkpointing config") + + +def encoder_checkpoint_wrapper( + activation_checkpointing: Union[str, Dict], + layer_cls: type, + idx: int = 0, +) -> Callable: + """Return encoder activation checkpoint wrapper.""" + validate_checkpointing_config(activation_checkpointing) + + if isinstance(activation_checkpointing, str): + if activation_checkpointing: + if activation_checkpointing == "offload": + return offload_wrapper + return partial(checkpoint_wrapper) + return lambda x: x + + if isinstance(activation_checkpointing, dict): + target_layer_cls = activation_checkpointing.get("module", "transformer") + if target_layer_cls.lower() == "transformer": + target_layer_cls = ( + "EncoderLayer", + "ConformerEncoderLayer", + ) + elif target_layer_cls.lower() == "attention": + target_layer_cls = ("MultiHeadedAttention", "MultiHeadAttention") + checkpointing_interval = activation_checkpointing.get("interval", 1) + offloading = activation_checkpointing.get("offload", False) + impl = ( + CheckpointImpl.REENTRANT if activation_checkpointing.get("reentrant", True) else CheckpointImpl.NO_REENTRANT + ) + + if idx % checkpointing_interval == 0 and layer_cls.__name__ in target_layer_cls: + if offloading: + return offload_wrapper + return partial(checkpoint_wrapper, checkpoint_impl=impl) + return lambda x: x + + raise ValueError("Invalid activation_checkpointing config") + + +def attn_checkpointing(activation_checkpointing: Union[str, Dict], i) -> Union[str, Dict]: + """Return activation checkpointing config for attention layer.""" + if isinstance(activation_checkpointing, str): + return "" + + if isinstance(activation_checkpointing, dict): + target_layer_cls = activation_checkpointing.get("module", "transformer") + checkpointing_interval = activation_checkpointing.get("interval", 1) + if target_layer_cls == "attention" and i % checkpointing_interval == 0: + return activation_checkpointing + return "" + + raise ValueError("Invalid activation_checkpointing config") + + +# utils.py +class Block(nn.Module): + """Block abstract module.""" + + def __init__(self, input_size, output_size): + super().__init__() + self.input_size = input_size + self.output_size = output_size + + +def get_activation(name="relu"): + """Select an activation function by name. + + Args: + name: str + activation function name, + one of ["relu", "gelu", "swish", "sigmoid"], + default "relu". 
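# --- Illustrative sketch (reviewer note, not part of the patch) ---------------------
# The two configuration shapes accepted by the checkpointing helpers above, plus the
# interval rule from encoder_checkpoint_wrapper: layer i is wrapped only when
# i % interval == 0 and its class name matches the configured module kind.
cfg_as_str = "offload"                                  # "", "checkpoint" or "offload"
cfg_as_dict = {"module": "transformer",                 # or "attention"
               "interval": 2, "offload": False, "reentrant": False}

wrapped_layers = [i for i in range(6) if i % cfg_as_dict["interval"] == 0]
print(wrapped_layers)                                   # [0, 2, 4]
# -------------------------------------------------------------------------------------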
+ """ + name = name.lower() + if name == "relu": + return nn.ReLU(inplace=True) + if name == "gelu": + return nn.GELU() + if name == "swish": + return Swish() + if name == "sigmoid": + return torch.nn.Sigmoid() + return nn.Identity() + + +def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0): + """ + The function is very important for Transformer Transducer Streaming mode + Args: + xs_len (int): sequence length + chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48]. It also supports adaptive chunk size [0,10,15,45] + left_window (int): how many left chunks can be seen + right_window (int): how many right chunks can be seen. It is used for chunk overlap model. + Returns: + mask (torch.Tensor): a mask tensor for streaming model + Torch 1.0.1 + tensor([[1., 1., 0., 0.], + [0., 1., 1., 0.], + [0., 0., 1., 1.]]) + Torch 1.4.1 + tensor([[True., True., False., False.], + [False., True., True., False.], + [False., False., True., True.]]) + """ + chunk_start_idx = torch.Tensor(chunk_start_idx).long() # first idx of each chunk, such as [0,18,36,48]. + start_pad = torch.nn.functional.pad( + chunk_start_idx, (1, 0) + ) # append 0 to the beginning, so it becomes [0, 0, 18, 36, 48] + end_pad = torch.nn.functional.pad( + chunk_start_idx, (0, 1), value=x_len + ) # append x_len to the end, so it becomes [0,18,36,48, x_len] + seq_range = torch.arange(0, x_len).unsqueeze(-1) # seq_range size: [x_len, 1] + idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1] # idx size: [x_len] + boundary = end_pad[idx] # boundary size: [x_len] + seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1) # seq_range_expand size [x_len, x_len] + idx_left = idx - left_window + idx_left[idx_left < 0] = 0 + boundary_left = start_pad[idx_left] + mask_left = seq_range_expand >= boundary_left.unsqueeze(-1) + idx_right = idx + right_window + idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx) + boundary_right = end_pad[idx_right] + mask_right = seq_range_expand < boundary_right.unsqueeze(-1) + return mask_left & mask_right + + +class Swish(nn.Module): + """Implement Swish activation module. + + From https://arxiv.org/pdf/2005.03191.pdf + """ + + def __init__(self) -> None: + super().__init__() + self.act_fn = nn.Sigmoid() + + def forward(self, x: Tensor) -> Tensor: + """Apply Swish function. + + Args: + x: torch.Tensor + Input. + """ + return x * self.act_fn(x) + + +class GLU(nn.Module): + """Implement Gated Linear Unit (GLU) module.""" + + def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None: + super().__init__() + self.dim = dim + self.act_name = act_name.lower() + + if self.act_name == "relu": + self.act_fn = nn.ReLU(inplace=True) + elif self.act_name == "gelu": + self.act_fn = nn.GELU() + elif self.act_name == "swish": + self.act_fn = Swish() + elif self.act_name == "sigmoid": + self.act_fn = nn.Sigmoid() + else: + self.act_fn = nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + """GLU forward + Apply Swish function on the first half of input matrices + with sigmoid of the second half. + + Args: + x: torch.Tensor + Input. + """ + half_x, gate = x.chunk(2, dim=self.dim) + return half_x * self.act_fn(gate) + + +# TODO: Abdel, this can be improved using GLU module +class GLUPointWiseConv(nn.Module): + """GLUPointWiseConv module + used for conformer architecture, + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + output_dim: int + output channel size. 
+ kernel_size: int + kernel size + glu_type: str, optional + activation function one of + ["sigmoid", "relu", "gelu"] + default "sigmoid". + bias_in_glu: bool, optional + use addtive bias in glu + causal: bool, optional + if set to True, padding is set to the half of + kernel size, ie, convolution can't see future frames. + default False. + + """ + + def __init__(self, input_dim, output_dim, kernel_size, glu_type="sigmoid", bias_in_glu=True, causal=False): + super().__init__() + + self.glu_type = glu_type + self.output_dim = output_dim + self.bias_in_glu = bias_in_glu + if causal: + self.ext_pw_conv_1d = nn.Conv1d(input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1)) + else: + self.ext_pw_conv_1d = nn.Conv1d(input_dim, output_dim * 2, kernel_size, 1, padding=(kernel_size - 1) // 2) + + if glu_type == "sigmoid": + self.glu_act = nn.Sigmoid() + elif glu_type == "relu": + self.glu_act = nn.ReLU() + elif glu_type == "gelu": + self.glu_act = nn.GELU() + elif glu_type == "swish": + self.glu_act = Swish() + else: + raise ValueError(f"Unsupported activation type {self.glu_act}") + + if bias_in_glu: + self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1)) + self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1)) + + def forward(self, x): + """ + Args: + x: torch.Tensor + input tensor + """ + # to be consistent with GLULinear, we assume the input always has the #channel (#dim) in the last dimension of the tensor, so need to switch the dimension first for 1D-Conv case + x = x.permute([0, 2, 1]) + x = self.ext_pw_conv_1d(x) + if self.glu_type == "bilinear": + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * ( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * (x[:, self.output_dim : self.output_dim * 2, :]) + else: + if self.bias_in_glu: + x = (x[:, 0 : self.output_dim, :] + self.b1) * self.glu_act( + x[:, self.output_dim : self.output_dim * 2, :] + self.b2 + ) + else: + x = (x[:, 0 : self.output_dim, :]) * self.glu_act(x[:, self.output_dim : self.output_dim * 2, :]) + + x = x.permute([0, 2, 1]) + return x + + +class DepthWiseSeperableConv1d(nn.Module): + """DepthWiseSeperableConv1d module used in Convnet module + for the conformer, for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + kernel_size: int + kernel_size + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + padding: int, optional + padding for the conv1d, + default: 0. 
+ + """ + + def __init__( + self, + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=0, + ): + super().__init__() + + self.dw_conv = nn.Conv1d( + input_dim, + input_dim * depthwise_multiplier, + kernel_size, + 1, + padding=padding, + groups=input_dim, + ) + + if depthwise_seperable_out_channel != 0: + self.pw_conv = nn.Conv1d(input_dim * depthwise_multiplier, depthwise_seperable_out_channel, 1, 1, 0) + else: + self.pw_conv = nn.Identity() + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + + def forward(self, x): + """ + + Args: + x: torch.Tensor + input tensor + """ + x = self.dw_conv(x) + if self.depthwise_seperable_out_channel != 0: + x = self.pw_conv(x) + return x + + +class ConvModule(nn.Module): + """ConvModule Module for the conformer block. + for more details see: + https://arxiv.org/pdf/2005.08100v1.pdf + + Args: + input_dim: int + input channel size. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation. + default False + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + chunk_size: int, optional + chunk size for cnn. default 18 + activation: str, optional + activation function used in ConvModule, + default: "relu". + glu_type: str, optional + activation function used for the glu, + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + export: bool, optional, + if set to True, padding is equal to 0. This is for inference, + or onnx export. Typically this is set by the export program or + the decoder program, and it isn't present in your config file. 
+ default False + """ + + def __init__( + self, + input_dim, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal=False, + batch_norm=False, + chunk_se=0, + chunk_size=18, + activation="relu", + glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + export=False, + ): + super().__init__() + self.layer_norm = nn.LayerNorm(input_dim) + self.input_dim = input_dim + self.ext_pw_out_channel = ext_pw_out_channel + self.ext_pw_kernel_size = ext_pw_kernel_size + self.depthwise_seperable_out_channel = depthwise_seperable_out_channel + self.glu_type = glu_type + self.bias_in_glu = bias_in_glu + self.linear_glu_in_convm = linear_glu_in_convm + self.causal = causal + + self._add_ext_pw_layer() + + self.batch_norm = batch_norm + self.kernel_size = kernel_size + + if batch_norm: + self.bn_layer = nn.BatchNorm1d(input_dim) + + self.act = get_activation(activation) + self.dropout = nn.Dropout(dropout_rate) + self.export = export + + if causal: + if export: # Inference only. + padding = 0 # A cache is concatenated to the left. No padding in the kernel. + else: + # Training only. Padding will be added symmetrically on both sides. + # After convolution, clip off kernel_size-1 points on the right. + padding = kernel_size - 1 + else: + padding = (kernel_size - 1) // 2 + + self.dw_sep_conv_1d = DepthWiseSeperableConv1d( + input_dim, + depthwise_seperable_out_channel, + kernel_size, + depthwise_multiplier, + padding=padding, + ) + + if depthwise_seperable_out_channel != 0: + if input_dim != depthwise_seperable_out_channel: + self.ln2 = nn.Linear(depthwise_seperable_out_channel, input_dim) + else: + if depthwise_multiplier != 1: + self.ln2 = nn.Linear(input_dim * depthwise_multiplier, input_dim) + + def _add_ext_pw_layer(self): + """This function is an extension of __init__ function + and dedicated to the convolution module creation + of the conformer.""" + self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = nn.Identity() # jit hacks. + self.squeeze_excitation = nn.Identity() # jit. + self.apply_ln1 = self.fix_len1 = False # jit. + + if self.ext_pw_out_channel != 0: + if self.causal: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1), + ) + if self.ext_pw_kernel_size > 1: + self.fix_len1 = True + else: + self.fix_len1 = False + else: + self.ext_pw_conv_1d = nn.Conv1d( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + 1, + padding=(self.ext_pw_kernel_size - 1) // 2, + ) + self.fix_len1 = False + + if self.linear_glu_in_convm: + self.glu = GLULinear(self.input_dim, self.ext_pw_out_channel, self.glu_type, self.bias_in_glu) + else: + self.glu = GLUPointWiseConv( + self.input_dim, + self.ext_pw_out_channel, + self.ext_pw_kernel_size, + self.glu_type, + self.bias_in_glu, + self.causal, + ) + + if self.input_dim != self.ext_pw_out_channel: + self.apply_ln1 = True + self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim) + else: + self.apply_ln1 = False + else: + self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3)) + self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3)) + + def forward(self, x): + """ConvModule Forward. + + Args: + x: torch.Tensor + input tensor. 
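+
+ Illustrative shape sketch (arbitrary sizes): the input is channel-last,
+ shape (batch, time, input_dim), and in this non-causal configuration the
+ output keeps that shape.
+
+ >>> conv = ConvModule(
+ ... input_dim=64,
+ ... ext_pw_out_channel=128,
+ ... depthwise_seperable_out_channel=64,
+ ... ext_pw_kernel_size=1,
+ ... kernel_size=3,
+ ... depthwise_multiplier=1,
+ ... dropout_rate=0.0,
+ ... )
+ >>> conv(torch.randn(2, 50, 64)).shape
+ torch.Size([2, 50, 64])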
+ """ + x = self.layer_norm(x) + + if self.ext_pw_out_channel != 0: + x = self.glu(x) + if self.causal and self.ext_pw_kernel_size > 1: + x = x[:, : -(self.ext_pw_kernel_size - 1), :] + if self.apply_ln1: + x = self.ln1(x) + else: + x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0] + x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1] + x = x_0 + x_1 + + x = x.permute([0, 2, 1]) + + x = self.dw_sep_conv_1d(x) + if self.causal and self.kernel_size > 1: + x = x[:, :, : -(self.kernel_size - 1)] + if hasattr(self, "ln2"): + x = x.permute([0, 2, 1]) + x = self.ln2(x) + x = x.permute([0, 2, 1]) + if self.batch_norm: + x = self.bn_layer(x) + x = self.act(x) + + if self.ext_pw_out_channel != 0: + x = self.ext_pw_conv_1d(x) + if self.fix_len1: + x = x[:, :, : -(self.ext_pw_kernel_size - 1)] + + if self.apply_ln1: + x = x.permute([0, 2, 1]) + x = self.ln1(x) + x = x.permute([0, 2, 1]) + + x = x.permute([0, 2, 1]) + else: + x = x.unsqueeze(1).permute([0, 1, 3, 2]) + x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2] + x = x.squeeze(1) + + x = self.dropout(x) + return x + + +class GLULinear(nn.Module): + """Linear + GLU module. + + Args: + input_dim: int + input size + output_dim: int + output size. + glu_type: + activation function name used in glu module. + default "sigmoid" (swish function). + bias_in_glu: bool, optional + If True, the addtive bias is added. Default False. + """ + + def __init__( + self, + input_dim, + output_dim, + glu_type="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu) + self.glu_act = GLU(-1, glu_type) + + def forward(self, x): + """GLULinear forward. + + Args: + x: torch.Tensor + inpute tensor. + """ + x = self.linear(x) + return self.glu_act(x) + + +class FeedForward(nn.Module): + """FeedForward Module. + For more details see Conformer paper: + https://arxiv.org/pdf/2005.08100.pdf + + Args: + d_model: int + input size. + d_inner: int + output size. + dropout_rate: float, + dropout rate. + activation: str, + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "sigmoid". + bias_in_glu: bool, optional + """ + + def __init__( + self, + d_model, + d_inner, + dropout_rate, + activation="sigmoid", + bias_in_glu=True, + ): + super().__init__() + self.d_model = d_model + self.d_inner = d_inner + + self.layer_norm = nn.LayerNorm(d_model) + module = GLULinear(d_model, d_inner, activation, bias_in_glu) + self.net = nn.Sequential( + module, + nn.Dropout(dropout_rate), + nn.Linear(d_inner, d_model), + nn.Dropout(dropout_rate), + ) + + def forward(self, x): + """FeedForward forward function. + + Args: + x: torch.Tensor + input tensor. + """ + out = self.net(self.layer_norm(x)) + + return out + + +#### positional encoding starts here +def _pre_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): + """Perform pre-hook in load_state_dict for backward compatibility. + + Note: + We saved self.pe until v.0.5.2 but we have omitted it later. + Therefore, we remove the item "pe" from `state_dict` for backward compatibility. 
+ """ + k = prefix + "pe" + if k in state_dict: + state_dict.pop(k) + + +class T5RelativeAttentionLogitBias(nn.Module): + """ + This module implements the relative position bias described in Section 2.1 of + the T5 paper: https://arxiv.org/pdf/1910.10683.pdf + + The Huggingface implementation is used as a reference + /~https://github.com/huggingface/transformers/blob/v4.30.0/src/transformers/models/t5/modeling_t5.py#L435 + + Modifies attention as Q*K^T + B, where B is a learned scalar bias based on relative position + of the query and key. It is HxNxN, where H is the number of heads, N is the sequence length. + + I've made these modifications to the original T5 bias: + - Skipping of the bucketing step. Original T5 bias converted rel position distances into + logarithmically increasing buckets. This is supposed to help with length generalization. + - I just directly use rel position index as bias values, as we don't need length + generalization (40s max is good enough for ASR encoder), and it keeps ONNX export simple. + - I've also extended it so that biases can be asymmetric, the default implementation treats + L->R and R->L the same. Asymmetric was found to yield better results in my experiments. + + Args: + num_heads: int + Number of attention heads + num_buckets: int + Number of buckets to use for relative attention bias. This is the size of the learnable + bias parameter. Bucketing is not yet supported, so this defaults to -1 which means + no bucketing is used (max_distance determines size of bias param). + max_distance: int + Maximum distance to use for relative attention bias. With num_buckets=-1, this directly + controls the max size of the bias parameter. When num_buckets > 0 is supported, this + will control the maximum distance for logarithmic bucketing after which all positions + are in the same bucket. + symmetric: bool + Whether to use symmetric or asymmetric biases. symmetric=False uses 2x number of bias + params to distinguish L->R from R->L. This was found to be better for the encoder. 
+ """ + + def __init__(self, num_heads, num_buckets=-1, max_distance=1000, symmetric=False): + super().__init__() + self.num_heads = num_heads + self.num_buckets = num_buckets + self.max_distance = max_distance + self.symmetric = symmetric + self._skip_bucketing = self.num_buckets < 0 + if self._skip_bucketing: + self.num_buckets = max_distance + else: + raise NotImplementedError("T5 attention bias with bucketed positions is not yet tested") + if not self.symmetric: + self.num_buckets *= 2 + self.bias_values = nn.Embedding(self.num_buckets, self.num_heads) + + def forward(self, x): + # instantiate bias compatible with shape of x + maxpos = x.size(1) + context_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[:, None] + memory_position = torch.arange(maxpos, device=x.device, dtype=torch.long)[None, :] + relative_position = memory_position - context_position + # clipping to a maximum distance using ops that play well with ONNX export + relative_position = relative_position.masked_fill(relative_position < -self.max_distance, -self.max_distance) + relative_position = relative_position.masked_fill( + relative_position > self.max_distance - 1, self.max_distance - 1 + ) + + # mapping from relative position to index in the bias parameter + if self._skip_bucketing: + bias_idx = relative_position + else: + bias_idx = self._bucket_relative_position(relative_position) + if self.symmetric: + bias_idx = bias_idx.abs() + else: + bias_idx += self.num_buckets // 2 + + t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H] + t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(0) # [1, H, L, L] + + return t5_rel_att_bias + + def _bucket_relative_position(self, relative_position): + # this is a placeholder (isn't tested, likely buggy) using HuggingFace implem as a reference + # this also needs to be extended to support asymmetric +/- ve positions + relative_buckets = 0 + if not self.causal: + num_buckets //= 2 + relative_buckets += (relative_position > 0).to(torch.long) * num_buckets + relative_position = torch.abs(relative_position) + else: + relative_position = -torch.min(relative_position, torch.zeros_like(relative_position)) + # now relative_position is in the range [0, inf) + + # half of the buckets are for exact increments in positions + max_exact = num_buckets // 2 + is_small = relative_position < max_exact + + # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance + relative_position_if_large = max_exact + ( + torch.log(relative_position.float() / max_exact) + / math.log(self.max_distance / max_exact) + * (num_buckets - max_exact) + ).to(torch.long) + relative_position_if_large = torch.min( + relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1) + ) + + relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) + return relative_buckets + + +class AbsolutePositionalEncoding(nn.Module): + """Absolute Positional encoding module. + This module implement Absolute sinusoidal positional encoding + from: https://arxiv.org/pdf/1706.03762.pdf + + Args: + d_model: int + Input embedding size. 
+ dropout_rate: float + dropout rate + max_len: int, optional + Maximum input length sequence, Default 5000 + + """ + + def __init__(self, d_model, dropout_rate, max_len=5000): + """Construct an PositionalEncoding object.""" + super().__init__() + self.d_model = d_model + self.xscale = math.sqrt(self.d_model) + self.dropout = torch.nn.Dropout(p=dropout_rate) + self.pe = None + self.extend_pe(torch.tensor(0.0).expand(1, max_len)) + self._register_load_state_dict_pre_hook(_pre_hook) + + def extend_pe(self, x): + """Reset the positional encodings. + + Args: + x: torch.Tensor + """ + if self.pe is not None: + if self.pe.size(1) >= x.size(1): + if self.pe.dtype != x.dtype or self.pe.device != x.device: + self.pe = self.pe.to(dtype=x.dtype, device=x.device) + return + pe = torch.zeros(x.size(1), self.d_model) + position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, self.d_model, 2, dtype=torch.float32) * -(math.log(10000.0) / self.d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + pe = pe.unsqueeze(0) + self.pe = pe.to(device=x.device, dtype=x.dtype) + + def forward(self, x: torch.Tensor): + """Add positional encoding. + + Args: + x: torch.Tensor + Input tensor. shape is (batch, time, ...) + + Returns: + torch.Tensor: Encoded tensor. Its shape is (batch, time, ...) + """ + self.extend_pe(x) + x = x * self.xscale + self.pe[:, : x.size(1)] + return self.dropout(x) + + +#### forward embedding layers starts here + + +@backoff.on_exception(backoff.expo, Exception, max_tries=10) +def np_loadtxt_with_retry(filepath): + """np.loadtxt with retry. + + Args: + filepath: str + file path to the numpy array. + """ + result = np.loadtxt(filepath, dtype="f") + return result + + +class MeanVarianceNormLayer(nn.Module): + """Mean/variance normalization layer. + + Will subtract mean and multiply input by inverted standard deviation. + Typically used as a very first layer in a model. + + Args: + input_size: int + layer input size. + """ + + def __init__(self, input_size): + super().__init__() + self.input_size = input_size + self.register_buffer("global_mean", torch.zeros(input_size)) + self.register_buffer("global_invstd", torch.ones(input_size)) + self.global_mean: Optional[Tensor] + self.global_invstd: Optional[Tensor] + + def forward(self, input_: Tensor) -> Tensor: + """MeanVarianceNormLayer Forward. + + Args: + input_: torch.Tensor + input tensor. + """ + return (input_ - self.global_mean) * self.global_invstd + + def load_mean_invstd(self, mean_file, invstd_file, cuside_features=False): + """Load feature mean and variance used for normalization. + + Args: + mean_file: str + path to the feature mean statistics file. + invstd_file: str + path to the features inverted standard deviation + statistics file. + cuside_features: bool + Boolean that indicates CUSIDE is being used. 
+ The statistics of CUSIDE features are copied + from the normal features + """ + self.global_mean.data = torch.from_numpy(np_loadtxt_with_retry(mean_file)) + self.global_invstd.data = torch.from_numpy(np_loadtxt_with_retry(invstd_file)) + + if cuside_features: + self.global_mean.data = torch.cat((self.global_mean.data, self.global_mean.data), 0) + self.global_invstd.data = torch.cat((self.global_invstd.data, self.global_invstd.data), 0) + + +class CausalConv1D(nn.Conv1d): + """A causal version of nn.Conv1d where each step would have limited access to locations on its right or left + All arguments are the same as nn.Conv1d except padding. + + If padding is set None, then paddings are set automatically to make it a causal convolution where each location would not see any steps on its right. + + If padding is set as a list (size of 2), then padding[0] would be used as left padding and padding[1] as right padding. + It would make it possible to control the number of steps to be accessible on the right and left. + This mode is not supported when stride > 1. padding[0]+padding[1] should be equal to (kernel_size - 1). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + self.cache_drop_size = None + if padding is None: + self._left_padding = kernel_size - 1 + self._right_padding = stride - 1 + else: + if stride != 1 and padding != kernel_size - 1: + raise ValueError("No striding allowed for non-symmetric convolutions!") + if isinstance(padding, int): + self._left_padding = padding + self._right_padding = padding + elif isinstance(padding, list) and len(padding) == 2 and padding[0] + padding[1] == kernel_size - 1: + self._left_padding = padding[0] + self._right_padding = padding[1] + else: + raise ValueError(f"Invalid padding param: {padding}!") + + self._max_cache_len = self._left_padding + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode, + device=device, + dtype=dtype, + ) + + def update_cache(self, x, cache=None): + if cache is None: + new_x = F.pad(x, pad=(self._left_padding, self._right_padding)) + next_cache = cache + else: + new_x = F.pad(x, pad=(0, self._right_padding)) + new_x = torch.cat([cache, new_x], dim=-1) + if self.cache_drop_size > 0: + next_cache = new_x[:, :, : -self.cache_drop_size] + else: + next_cache = new_x + next_cache = next_cache[:, :, -cache.size(-1) :] + return new_x, next_cache + + def forward(self, x, cache=None): + x, cache = self.update_cache(x, cache=cache) + x = super().forward(x) + if cache is None: + return x + else: + return x, cache + + +class CausalConv2D(nn.Conv2d): + """A causal version of nn.Conv2d where each location in the 2D matrix would have no access to locations on its right or down + All arguments are the same as nn.Conv2d except padding which should be set as None.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + stride: int = 1, + padding: Union[str, int] = 0, + dilation: int = 1, + groups: int = 1, + bias: bool = True, + padding_mode: str = "zeros", + device=None, + dtype=None, + ) -> None: + if padding is not None: + raise ValueError("Argument padding should be set to None for CausalConv2D.") + self._left_padding = 
kernel_size - 1
+ self._right_padding = stride - 1
+
+ padding = 0
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ device,
+ dtype,
+ )
+
+ def forward(
+ self,
+ x,
+ ):
+ if self.training:
+ x = F.pad(
+ x,
+ pad=(
+ self._left_padding,
+ self._right_padding,
+ self._left_padding,
+ self._right_padding,
+ ),
+ )
+ else:
+ x = F.pad(
+ x,
+ pad=(self._left_padding, self._right_padding, 0, 0),
+ )
+ x = super().forward(x)
+ return x
+
+
+class NemoConvSubsampling(torch.nn.Module):
+ """Convolutional subsampling module, taken from NeMo ASR
+ (/~https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)
+
+ Striding Subsampling: "Speech-Transformer: A No-Recurrence Sequence-to-Sequence Model for
+ Speech Recognition" by Linhao Dong et al. (https://ieeexplore.ieee.org/document/8462506)
+
+
+ Compared with the EncoderConv2D (`input_layer: custom`), this is a much simplified approach,
+ and uses no LayerNorm and far fewer Conv2Ds. Moreover, depthwise convolutions are used to reduce
+ FLOPs, but the first layer is kept as a regular convolution so as not to degrade accuracy.
+
+ `Striding` and `dw_striding` are the same except that the latter uses depthwise convolutions
+ after the first layer, whereas the former does not.
+
+ Args:
+ subsampling_factor (int): Time reduction factor
+ feat_in (int): size of the input features
+ feat_out (int): size of the output features
+ subsampling (str): The subsampling technique, choose from
+ {"striding", "dw_striding", "striding_conv1d", "dw_striding_conv1d"}
+ conv_channels (int): Number of channels for the convolution layers, default is 256.
+ subsampling_conv_chunking_factor (int): Input chunking factor which can be -1 (no chunking),
+ 1 (auto), or a power of 2.
Default is 1 + activation (Module): activation function, default is nn.ReLU() + is_causal (bool): whether to use causal Conv1/2D, where each step will have limited access + to locations on its right or left + """ + + def __init__( + self, + feat_in, + feat_out, + subsampling_factor=4, + subsampling="dw_striding", + conv_channels=256, + subsampling_conv_chunking_factor=1, + activation=nn.ReLU(), + is_causal=False, + ): + super().__init__() + self._subsampling = subsampling + self._conv_channels = conv_channels + self._feat_in = feat_in + self._feat_out = feat_out + + if subsampling_factor % 2 != 0: + raise ValueError("Sampling factor should be a multiply of 2!") + self._sampling_num = int(math.log(subsampling_factor, 2)) + self.subsampling_factor = subsampling_factor + self.is_causal = is_causal + self.subsampling_causal_cond = subsampling in ("dw_striding", "striding", "striding_conv1d") + + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + in_channels = 1 + layers = [] + + if subsampling == "dw_striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + # Layer 1 + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + groups=in_channels, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ) + ) + + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding": + self._stride = 2 + self._kernel_size = 3 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv2D( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv2d( + in_channels=in_channels, + out_channels=conv_channels, + kernel_size=self._kernel_size, 
+ stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + if self.is_causal: + self._left_padding = self._kernel_size - 1 + self._right_padding = self._stride - 1 + self._max_cache_len = subsampling_factor + 1 + else: + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + self._max_cache_len = 0 + + for i in range(self._sampling_num): + if self.is_causal: + layers.append( + CausalConv1D( + in_channels=in_channels, + out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=None, + ) + ) + else: + layers.append( + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=feat_out if self._sampling_num == i + 1 else conv_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + ) + ) + layers.append(activation) + in_channels = conv_channels + + elif subsampling == "dw_striding_conv1d": + in_channels = feat_in + + self._stride = 2 + self._kernel_size = 5 + self._ceil_mode = False + + self._left_padding = (self._kernel_size - 1) // 2 + self._right_padding = (self._kernel_size - 1) // 2 + + # Layer 1 + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=feat_out if self._sampling_num == 1 else conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + in_channels = conv_channels + layers.append(activation) + + for i in range(self._sampling_num - 1): + layers.extend( + [ + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=in_channels, + kernel_size=self._kernel_size, + stride=self._stride, + padding=self._left_padding, + groups=in_channels, + ), + torch.nn.Conv1d( + in_channels=in_channels, + out_channels=feat_out if self._sampling_num == i + 2 else conv_channels, + kernel_size=1, + stride=1, + padding=0, + groups=1, + ), + ] + ) + layers.append(activation) + in_channels = conv_channels + + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + if subsampling in ["dw_striding", "striding"]: + in_length = torch.tensor(feat_in, dtype=torch.float) + out_length = calc_length( + lengths=in_length, + all_paddings=self._left_padding + self._right_padding, + kernel_size=self._kernel_size, + stride=self._stride, + ceil_mode=self._ceil_mode, + repeat_num=self._sampling_num, + ) + self.out = torch.nn.Linear(conv_channels * int(out_length), feat_out) + self.conv2d_subsampling = True + elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]: + self.out = None + self.conv2d_subsampling = False + else: + raise ValueError(f"Not valid sub-sampling: {subsampling}!") + + self.conv = torch.nn.Sequential(*layers) + + def get_sampling_frames(self): + return [1, self.subsampling_factor] + + def get_streaming_cache_size(self): + return [0, self.subsampling_factor + 1] + + def forward(self, x, mask): + """Forward method for NeMo subsampling. 
+ + Args: + x[Batch, Time, Filters]: torch.Tensor + input tensor + x_mask: torch.Tensor + input mask + + Returns: + x: torch.Tensor + Resulting tensor from subsampling (B, T // time_reduction_factor, feat_out) + pad_mask: torch.Tensor + tensor of padded hidden state sequences (B, 1, T // time_reduction_factor) + """ + # Unsqueeze Channel Axis + if self.conv2d_subsampling: + x = x.unsqueeze(1) + # Transpose to Channel First mode + else: + x = x.transpose(1, 2) + + # split inputs if chunking_factor is set + if self.subsampling_conv_chunking_factor != -1 and self.conv2d_subsampling: + if self.subsampling_conv_chunking_factor == 1: + # if subsampling_conv_chunking_factor is 1, we split only if needed + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see /~https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + if torch.numel(x) > x_ceil: + need_to_split = True + else: + need_to_split = False + else: + # if subsampling_conv_chunking_factor > 1 we always split + need_to_split = True + + if need_to_split: + x, success = self.conv_split_by_batch(x) + if not success: # if unable to split by batch, try by channel + if self._subsampling == "dw_striding": + x = self.conv_split_by_channel(x) + else: + x = self.conv(x) # try anyway + else: + x = self.conv(x) + else: + x = self.conv(x) + + # Flatten Channel and Frequency Axes + if self.conv2d_subsampling: + b, c, t, f = x.size() + x = self.out(x.transpose(1, 2).reshape(b, t, -1)) + # Transpose to Channel Last mode + else: + x = x.transpose(1, 2) + + if mask is None: + return x, None + + max_audio_length = x.shape[1] + feature_lens = mask.sum(1) + padding_length = torch.ceil(feature_lens / self.subsampling_factor) + if self.is_causal and self.subsampling_causal_cond: + feature_lens_remainder = feature_lens % self.subsampling_factor + padding_length[feature_lens_remainder != 1] += 1 + pad_mask = torch.arange(0, max_audio_length, device=x.device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) + return x, pad_mask.unsqueeze(1) + + def reset_parameters(self): + # initialize weights + if self._subsampling == "dw_striding": + with torch.no_grad(): + # init conv + scale = 1.0 / self._kernel_size + dw_max = (self._kernel_size**2) ** -0.5 + pw_max = self._conv_channels**-0.5 + + torch.nn.init.uniform_(self.conv[0].weight, -scale, scale) + torch.nn.init.uniform_(self.conv[0].bias, -scale, scale) + + for idx in range(2, len(self.conv), 3): + torch.nn.init.uniform_(self.conv[idx].weight, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx].bias, -dw_max, dw_max) + torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max, pw_max) + torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max, pw_max) + + # init fc (80 * 64 = 5120 from /~https://github.com/kssteven418/Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/src/models/conformer_encoder.py#L487 + fc_scale = (self._feat_out * self._feat_in / self._sampling_num) ** -0.5 + torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale) + torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale) + + def conv_split_by_batch(self, x): + """Tries to split input by batch, run conv and concat results.""" + b, _, _, _ = x.size() + if b == 1: # can't split if batch size is 1 + return x, False + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see 
/~https://github.com/pytorch/pytorch/issues/80020 + x_ceil = 2**31 / self._conv_channels * self._stride * self._stride + p = math.ceil(math.log(torch.numel(x) / x_ceil, 2)) + cf = 2**p + + new_batch_size = b // cf + if new_batch_size == 0: # input is too big + return x, False + + return torch.cat([self.conv(chunk) for chunk in torch.split(x, new_batch_size, 0)]), True + + def conv_split_by_channel(self, x): + """For dw convs, tries to split input by time, run conv and concat results.""" + x = self.conv[0](x) # full conv2D + x = self.conv[1](x) # activation + + for i in range(self._sampling_num - 1): + _, c, t, _ = x.size() + + if self.subsampling_conv_chunking_factor > 1: + cf = self.subsampling_conv_chunking_factor + else: + # avoiding a bug / feature limiting indexing of tensors to 2**31 + # see /~https://github.com/pytorch/pytorch/issues/80020 + p = math.ceil(math.log(torch.numel(x) / 2**31, 2)) + cf = 2**p + + new_c = int(c // cf) + if new_c == 0: + new_c = 1 + + new_t = int(t // cf) + if new_t == 0: + new_t = 1 + + x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c, x) # conv2D, depthwise + + # splitting pointwise convs by time + x = torch.cat([self.conv[i * 3 + 3](chunk) for chunk in torch.split(x, new_t, 2)], 2) # conv2D, pointwise + x = self.conv[i * 3 + 4](x) # activation + return x + + def channel_chunked_conv(self, conv, chunk_size, x): + """Performs channel chunked convolution.""" + + ind = 0 + out_chunks = [] + for chunk in torch.split(x, chunk_size, 1): + step = chunk.size()[1] + + if self.is_causal: + chunk = nn.functional.pad( + chunk, + pad=( + self._kernel_size - 1, + self._stride - 1, + self._kernel_size - 1, + self._stride - 1, + ), + ) + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=0, + groups=step, + ) + else: + ch_out = nn.functional.conv2d( + chunk, + conv.weight[ind : ind + step, :, :, :], + bias=conv.bias[ind : ind + step], + stride=self._stride, + padding=self._left_padding, + groups=step, + ) + out_chunks.append(ch_out) + ind += step + + return torch.cat(out_chunks, 1) + + def change_subsampling_conv_chunking_factor(self, subsampling_conv_chunking_factor: int): + if ( + subsampling_conv_chunking_factor != -1 + and subsampling_conv_chunking_factor != 1 + and subsampling_conv_chunking_factor % 2 != 0 + ): + raise ValueError("subsampling_conv_chunking_factor should be -1, 1, or a power of 2") + self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor + + +def calc_length(lengths, all_paddings, kernel_size, stride, ceil_mode, repeat_num=1): + """Calculates the output length of a Tensor passed through a convolution or max pooling layer.""" + add_pad: float = all_paddings - kernel_size + one: float = 1.0 + for i in range(repeat_num): + lengths = torch.div(lengths.to(dtype=torch.float) + add_pad, stride) + one + if ceil_mode: + lengths = torch.ceil(lengths) + else: + lengths = torch.floor(lengths) + return lengths.to(dtype=torch.int) + + +#### multihead attention starts here +class AttModule(nn.Module): + """Attention abstraction module.""" + + def __init__(self): + super().__init__() + self.export_mode = False + + def set_export(self, mode=True): + """Set the export mode.""" + self.export_mode = mode + + def forward( + self, + x: Tensor, + memory: Optional[Tensor] = None, + pos_emb: Optional[Tensor] = None, + att_mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]: + """AttModule forward. 
+ + Args: + x: torch.Tensor + input tensor. + memory: torch.Tensor, optional + memory tensor. + pos_emb: torch.Tensor, optional + positional encoder embedding. + att_mask: torch.Tensor, optional + attention mask tensor. + """ + return x, memory, pos_emb, att_mask + + +class AttBlock(Block, AttModule): + """Attention Block module to support both Attention and Block module.""" + + def memory_dims(self, max_len=False): + """Memory dimensions.""" + return (1, self.input_size) + + +def masked_softmax( + scores, + mask: Optional[Tensor], +): + if mask is not None: + mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2) + scores = scores.masked_fill(mask, -torch.inf) + attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) # (batch, head, time1, time2) + else: + attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) + return attn + + +class MultiHeadedAttention(nn.Module): + """Multi-Head Attention layer with optional relative position embedding and GLU. + + Args: + n_head: int + the number of heads. + n_feat: int + input size features. + dropout_rate: float + dropout rate. + use_LN: bool + apply layer norm or not + dropout_at_output: bool + whether to apply dropout at output + attention_inner_dim: int, optional + the attention dimension used in the class, + it can be different from the input dimension n_feat. + default: -1 (equal to n_feat). + use_pt_scaled_dot_product_attention: bool, optional + if set True, use pytorch scaled dot product attention in training. NOTE: this will NOT + be used in ONNX decoding due to a lack of support. In that case, we use the original + attention implementation, which shows no regression. + default: False. + n_value: int, optional + if set to values other than -1, use a different dimension for value. With the default value (i.e. -1), it is backward compatible. + group_size: int, optional. must divide `n_head` + if group_size > 1: GQA + if group_size = 1: MHA + if group_size = n_head: MQA + """ + + inv_sqrt_d_k: torch.jit.Final[float] + h: torch.jit.Final[int] + h_k: torch.jit.Final[int] + g: torch.jit.Final[int] + + def __init__( + self, + n_head, + n_feat, + dropout_rate, + attention_inner_dim=-1, + glu_type="swish", + bias_in_glu=True, + use_pt_scaled_dot_product_attention=False, + n_value=-1, + group_size: int = 1, + ): + super().__init__() + if n_value == -1: + n_value = n_feat + if attention_inner_dim == -1: + attention_inner_dim = n_feat + assert attention_inner_dim % n_head == 0 + + # We assume d_v always equals d_k + self.d_k = attention_inner_dim // n_head + self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k) + self.h = n_head + assert n_head % group_size == 0, "group_size must divide n_head" + self.g = group_size + self.h_k = n_head // group_size + + self.linear_q = nn.Linear(n_feat, attention_inner_dim) + self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size) + self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size) + self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value) + + self.attn = torch.jit.Attribute(None, Optional[Tensor]) + self.dropout = nn.Dropout(p=dropout_rate) + self.dropout_rate = dropout_rate + self.use_pt_scaled_dot_product_attention = use_pt_scaled_dot_product_attention + + if use_pt_scaled_dot_product_attention and group_size > 1: + raise ValueError("Cannot use PT Scaled Attention with GQA") + + # Torchscript eager quantization. Note that these functions below are + # NOOPs and have very little impact on performance unless quantization is + # enabled. 
+ self.quant_q = torch.ao.quantization.QuantStub() + self.quant_x = torch.ao.quantization.QuantStub() + self.dequant = torch.ao.quantization.DeQuantStub() + self.ffunc = torch.ao.nn.quantized.FloatFunctional() + + def forward( + self, + query: Tensor, + key: Tensor, + value: Tensor, + pos_k: Tensor, + pos_v: Tensor, + mask: Optional[Tensor], + relative_attention_bias: Optional[Tensor] = None, + ): + """Compute 'Scaled Dot Product Attention'. + + Args: + query: torch.Tensor + query tensor (batch, time1, size) + key: torch.Tensor + key tensor (batch, time2, size) + value: torch.Tensor + value tensor (batch, time1, size) + pos_k: torch.Tensor + key tensor used for relative positional embedding. + pos_v: torch.Tensor + value tensor used for relative positional embedding. + mask: torch.Tensor + mask tensor (batch, time1, time2) + relative_attention_bias: torch.Tensor + bias added to attention logits w.r.t. relative positions (1, n_head, time1, time2) + """ + n_batch = query.size(0) + + q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) # (b, t, d) + k = self.linear_k(key).view(n_batch, -1, self.h_k, self.d_k) # (b, t, d) + v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k) + q = ( + q.transpose(1, 2) + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting() + else q.transpose(1, 2) * self.inv_sqrt_d_k + ) + k = k.transpose(1, 2) # (batch, head_k, time2, d_k) + v = v.transpose(1, 2) # (batch, head_k, time2, d_k) + + if self.use_pt_scaled_dot_product_attention and not torch.jit.is_scripting(): + attn_mask = None + if mask is not None: + mask = mask.unsqueeze(1) + if relative_attention_bias is not None: + attn_mask = mask + relative_attention_bias + else: + attn_mask = mask + if mask.dtype != q.dtype: + attn_mask = attn_mask.to(q.dtype) + + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): + x = torch.nn.functional.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.dropout_rate, + ) + else: + if self.h != self.h_k: + q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k) + A = torch.einsum("b g h t d, b h s d -> b h t s", q, k) + else: + A = torch.matmul(q, k.transpose(-2, -1)) + if pos_k is not None: + if self.h != self.h_k: + B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k) + else: + reshape_q = q.contiguous().view(n_batch * self.h, -1, self.d_k).transpose(0, 1) # (t1,nh,dk) + B = torch.matmul(reshape_q, pos_k.transpose(-2, -1)) # pos_k: (t1,dk,t2) + B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0), pos_k.size(1)) + scores = A + B + else: + scores = A + + if relative_attention_bias is not None: + scores = scores + relative_attention_bias + + attn = masked_softmax(scores, mask) # (batch, head, time1, time2) + + self.attn = attn + + p_attn = self.dropout(attn) + x = torch.matmul(p_attn.to(v.dtype), v) # (batch, head, time1, d_k) + if pos_v is not None: + reshape_attn = ( + p_attn.contiguous().view(n_batch * self.h, pos_v.size(0), pos_v.size(1)).transpose(0, 1) + ) # (t1, bh, t2) + + attn_v = ( + torch.matmul(reshape_attn, pos_v) + .transpose(0, 1) + .contiguous() + .view(n_batch, self.h, pos_v.size(0), self.d_k) + ) + x = x + attn_v + x = x.transpose(1, 2).contiguous().view(n_batch, -1, self.h_k * self.d_k) # (batch, time1, d_model) + + return self.linear_out(x) # (batch, time1, d_model) + + +def unfold_tensor(xs_pad, max_seq_len): + """For a given tensor with shape of (N, T, D), if sequence length T is longer than max_seq_len, + this function unfold it 
to a (NT', max_seq_len, D) where T' is T // max_seq_len. + + Args: + xs_pad: N, T, D + """ + _, _, D = xs_pad.shape + xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T + # N x D x 1 x T => N x (D x max_seq_len) x T' + xs_pad = F.unfold( + xs_pad[..., None, :], + kernel_size=(1, max_seq_len), + stride=(1, max_seq_len), + ) + + new_bsz, _, slen = xs_pad.shape + # N x D x max_seq_len x T' + xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen) + # N x T' x max_seq_len x D + xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous() + # NT' x max_seq_len x D + xs_pad = xs_pad.view(-1, max_seq_len, D) + return xs_pad + + +# conformer_encoder.py +class MultiSequential(torch.nn.Sequential): + """Multi-input multi-output torch.nn.Sequential.""" + + @torch.jit.ignore + def forward(self, *args): + """Forward method implementation.""" + for m in self: + args = m(*args) + return args + + +def repeat(repeat_num, module_gen_fn): + """Repeat module N times. + + :param int repeat_num: repeat time + :param function module_gen_fn: function to generate module + :return: repeated modules + :rtype: MultiSequential + """ + return MultiSequential(*[module_gen_fn(i) for i in range(repeat_num)]) + + +class ConformerEncoderLayer(nn.Module): + """ConformerEncoder Layer module. + for more details see conformer paper: + https://arxiv.org/abs/2005.08100 + This module implement the Conformer block layer. + + Args: + d_model: int + attention dim. + ext_pw_out_channel: int + if > 0, ext_pw_out_channel is a dim channel size + for the last pointwise conv after swish activation. + depthwise_seperable_out_channel: int + if set different to 0, the number of depthwise_seperable_out_channel + will be used as a channel_out of the second conv1d layer. + otherwise, it equal to 0, the second conv1d layer is skipped. + depthwise_multiplier: int + number of input_dim channels duplication. this value + will be used to compute the hidden channels of the Conv1D. + n_head: int + the number of heads for multihead attention module. + d_ffn: int + output size of the feed_forward blocks. + ext_pw_kernel_size: int + kernel size of the conv pointwise of the conformer. + kernel_size: int + kernel size. + dropout_rate: float + dropout rate. + causal: bool, optional + if set to True, convolution have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + activation: str, optional + activation function name, + one of ["relu", "swish", "sigmoid"], + sigmoid activation is only used with "glu_in_fnn=True", + default "relu". + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + chunk_size: int, optional + chunk_size for cnn. default 18 + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu". + conv_glu_type: str, optional + activation function used for the glu inside + the ConvModule part of the conformer. + default: "sigmoid". + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, used GLUPointWiseConv module. + default to False. + attention_innner_dim: int, otional + if equal to -1, attention dim for linears k/q/v is + equal to d_model. 
otherwise attention_innner_dim is used. + default -1. + attention_glu_type: str, optional + activation function for glu used in the multihead attention, + default "swish". + activation_checkpointing: str, optional + a dictionarry of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + export: bool, optional + if set to True, it remove the padding from convolutional layers + and allow the onnx conversion for inference. + default False. + use_pt_scaled_dot_product_attention: bool, optional + if set to True, use pytorch's scaled dot product attention implementation in training. + attn_group_sizes: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attn_group_sizes < attention_heads = Grouped-Query Attention + attn_group_sizes = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + d_model=512, + ext_pw_out_channel=0, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + n_head=4, + d_ffn=2048, + ext_pw_kernel_size=1, + kernel_size=3, + dropout_rate=0.1, + causal=False, + batch_norm=False, + activation="relu", + chunk_se=0, + chunk_size=18, + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_innner_dim=-1, + attention_glu_type="swish", + activation_checkpointing="", + export=False, + use_pt_scaled_dot_product_attention=False, + attn_group_sizes: int = 1, + ): + super().__init__() + + self.feed_forward_in = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.self_attn = encoder_checkpoint_wrapper( + activation_checkpointing, + MultiHeadedAttention, + )( + MultiHeadedAttention( + n_head, + d_model, + dropout_rate, + attention_innner_dim, + attention_glu_type, + bias_in_glu, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + group_size=attn_group_sizes, + ) + ) + self.conv = ConvModule( + d_model, + ext_pw_out_channel, + depthwise_seperable_out_channel, + ext_pw_kernel_size, + kernel_size, + depthwise_multiplier, + dropout_rate, + causal, + batch_norm, + chunk_se, + chunk_size, + conv_activation, + conv_glu_type, + bias_in_glu, + linear_glu_in_convm, + export=export, + ) + + self.feed_forward_out = FeedForward( + d_model=d_model, + d_inner=d_ffn, + dropout_rate=dropout_rate, + activation=activation, + bias_in_glu=bias_in_glu, + ) + + self.layer_norm_att = nn.LayerNorm(d_model) + self.layer_norm = nn.LayerNorm(d_model) + + def forward( + self, + x, + pos_k, + pos_v, + mask, + relative_attention_bias: Optional[Tensor] = None, + ): + """ConformerEncoder forward. + + Args: + x: torch.Tensor + input feature of shape (batch, max_time_in, size) + pos_k: torch.Tensor + positional key embedding. + mask: torch.Tensor + mask for x (batch, max_time_in) + relative_attention_bias: Optional[torch.Tensor] + bias added to attention logits w.r.t. 
relative positions (1, n_head, time1, time2) + """ + x = x + 0.5 * self.feed_forward_in(x) + norm_x = self.layer_norm_att(x) + + x = x + self.self_attn( + norm_x, + norm_x, + norm_x, + pos_k, + pos_v, + mask, + relative_attention_bias=relative_attention_bias, + ) + x = x + self.conv(x) + x = x + 0.5 * self.feed_forward_out(x) + + out = self.layer_norm(x) + + return out, pos_k, pos_v, mask + + +class TransformerEncoderBase(abc.ABC, nn.Module): + """The Base class for Transformer based encoders. + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + time_reduction: int, optional + time reduction factor + default 4 + dropout_rate: float, optional + dropout rate. default 0.1 + padding_idx: int, optional + padding index for input_layer=embed + default -1 + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) + implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see transformer_base.py) + positional_dropout_rate: float, optional + dropout rate after positional encoding. default 0.0 + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default None + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True). + if True or feat_time, the extra padding is added into non full + supraframe utts in batch. 
+ Default: none + attention_group_size: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + def __init__( + self, + input_size, + chunk_size, + left_chunk, + attention_dim=256, + attention_heads=4, + input_layer="nemo_conv", + cnn_out=-1, + cnn_layer_norm=False, + time_reduction=4, + dropout_rate=0.0, + padding_idx=-1, + relative_attention_bias_args=None, + positional_dropout_rate=0.0, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__() + self.input_size = input_size + self.input_layer = input_layer + self.chunk_size = chunk_size + self.left_chunk = left_chunk + self.attention_dim = attention_dim + self.num_heads = attention_heads + self.attention_group_size = attention_group_size + self.time_reduction = time_reduction + self.nemo_conv_settings = nemo_conv_settings + self.encoder_embedding_config = encoder_embedding_config + + if self.input_layer == "nemo_conv": + default_nemo_conv_settings = { + "subsampling": "dw_striding", + "subsampling_factor": self.time_reduction, + "feat_in": input_size, + "feat_out": attention_dim, + "conv_channels": 256, + "subsampling_conv_chunking_factor": 1, + "activation": nn.ReLU(), + "is_causal": False, + } + # Override any of the defaults with the incoming, user settings + if nemo_conv_settings: + default_nemo_conv_settings.update(nemo_conv_settings) + for i in ["subsampling_factor", "feat_in", "feat_out"]: + assert i not in nemo_conv_settings, "{i} should be specified outside of the NeMo dictionary" + + self.embed = NemoConvSubsampling( + **default_nemo_conv_settings, + ) + else: + raise ValueError("unknown input_layer: " + input_layer) + + self.pos_emb = AbsolutePositionalEncoding(attention_dim, positional_dropout_rate) + + self.relative_attention_bias_type = ( + relative_attention_bias_args.get("type") if relative_attention_bias_args else None + ) + if self.relative_attention_bias_type == "t5": + assert self.num_heads % self.attention_group_size == 0, "attention_group_size must divide n_head" + self.relative_attention_bias_layer = T5RelativeAttentionLogitBias( + self.num_heads // self.attention_group_size, + max_distance=relative_attention_bias_args.get("t5_bias_max_distance", 1000), + symmetric=relative_attention_bias_args.get("t5_bias_symmetric", False), + ) + else: + raise NotImplementedError + + def post_init(self, init_model_config): + + pretrained_speech_encoder_path = init_model_config.get("pretrained_speech_encoder_path", None) + if pretrained_speech_encoder_path: + model_state = torch.load(pretrained_speech_encoder_path, map_location="cpu") + encoder_state_dict = {} + for k, v in model_state.items(): + if "encoder." 
in k: + tmp_k = k.replace("encoder.", "") + encoder_state_dict[tmp_k] = v + + if hasattr(self, "encoder_embedding"): + del self.encoder_embedding + self.load_state_dict(encoder_state_dict) + + if not hasattr(self, "encoder_embedding"): + self.encoder_embedding = MeanVarianceNormLayer(self.encoder_embedding_config["input_size"]) + + mean_file = init_model_config.get("mean_file", None) + invstd_file = init_model_config.get("invstd_file", None) + if mean_file is not None and invstd_file is not None: + self.encoder_embedding.load_mean_invstd(mean_file, invstd_file) + + def compute_lens_change(self, feature_lens): + """feature_lens: int + return updated feature lens. + + This used to return a different lambda function for each case that computed + the right thing. That does not work within Torchscript. If you really + need this to be faster, create nn.Module()-s for all the cases and return + one of them. Torchscript does support that. + """ + if self.input_layer == "nemo_conv": + # Handle the special causal case + subsampling_causal_cond = self.nemo_conv_settings.get("subsampling", "dw_striding") in [ + "dw_striding", + "striding", + "striding_conv1d", + ] + is_causal = self.nemo_conv_settings.get("is_causal", False) + if is_causal and subsampling_causal_cond: + lens_change = ( + torch.ceil(feature_lens / self.time_reduction).long() + if isinstance(feature_lens, Tensor) + else math.ceil(feature_lens / self.time_reduction) + ) + feature_lens_remainder = feature_lens % self.time_reduction + if isinstance(feature_lens, Tensor): + lens_change[feature_lens_remainder != 1] += 1 + elif feature_lens_remainder != 1: + lens_change += 1 + return lens_change + ceil_func = math.ceil if isinstance(feature_lens, int) else torch.ceil + return ceil_func(feature_lens / self.time_reduction) + + @abc.abstractmethod + def forward(self): + """Abstract forward method implementation.""" + + def _chunk_size_selection(self, chunk_size=None, left_chunk=None): + """If chunk size is a list, we will randomly select a chunk size.""" + + if chunk_size is None: + chunk_size = self.chunk_size + if left_chunk is None: + left_chunk = self.left_chunk + if isinstance(chunk_size, list): + # Variable chunk size during training + chunk_size_index = int(torch.randint(low=0, high=len(chunk_size), size=(1,))) + chunk_size_train_eff = chunk_size[chunk_size_index] + if not isinstance(left_chunk, list): + raise ValueError("Since chunk_size is a list, left_chunk must be a list") + if len(left_chunk) != len(chunk_size): + raise ValueError("The length of left_chunk must be the same as length of chunk_size.") + left_chunk_train_eff = left_chunk[chunk_size_index] + else: + chunk_size_train_eff = chunk_size + left_chunk_train_eff = left_chunk + + return chunk_size_train_eff, left_chunk_train_eff + + def _get_embed_class(self, embed): + # pylint: disable=protected-access + is_embed_using_act_chkpt = isinstance(embed, CheckpointWrapper) + is_embed_fsdp_wrapped = isinstance(embed, FullyShardedDataParallel) + embed_class = embed + if is_embed_using_act_chkpt: + embed_class = embed._checkpoint_wrapped_module + if is_embed_fsdp_wrapped: + embed_class = embed.module + return embed_class + + def _forward_embeddings_core(self, input_tensor, masks): + embed_class = self._get_embed_class(self.embed) + assert isinstance(embed_class, NemoConvSubsampling) + input_tensor, masks = self.embed(input_tensor, masks) + return input_tensor, masks + + def _position_embedding(self, input_tensor): + pos_k = None + pos_v = None + if self.relative_attention_bias_layer is 
None: + input_tensor = self.pos_emb(input_tensor) # default to add abs sinusoid embedding + return pos_k, pos_v + + def _streaming_mask(self, seq_len, batch_size, chunk_size, left_chunk): + chunk_size_train_eff, left_chunk_train_eff = self._chunk_size_selection(chunk_size, left_chunk) + + # Create mask matrix for streaming + # S stores start index. if chunksize is 18, s is [0,18,36,....] + chunk_start_idx = np.arange(0, seq_len, chunk_size_train_eff) + # avoid randomness when run evaluation or decoding + if self.training and np.random.rand() > 0.5: + # Either first or last chunk is not complete. + # If only the last one is not complete, EOS is not effective + chunk_start_idx = seq_len - chunk_start_idx + chunk_start_idx = chunk_start_idx[::-1] + chunk_start_idx = chunk_start_idx[:-1] + chunk_start_idx = np.insert(chunk_start_idx, 0, 0) + + enc_streaming_mask = ( + adaptive_enc_mask(seq_len, chunk_start_idx, left_window=left_chunk_train_eff) + .unsqueeze(0) + .expand([batch_size, -1, -1]) + ) + return enc_streaming_mask + + def forward_embeddings(self, xs_pad, masks, chunk_size_nc=None, left_chunk_nc=None): + """Forwarding the inputs through the top embedding layers. + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + input mask + chunk_size_nc: (optional, default is None) chunk size for non-causal layers + left_chunk_nc: (optional, default is None) # of left chunks for non-causal layers + """ + # pylint: disable=R0915 + # get new lens. + seq_len = int(self.compute_lens_change(xs_pad.shape[1])) + if seq_len <= 0: + raise ValueError( + f"""The sequence length after time reduction is invalid: {seq_len}. + Your input feature is too short. Consider filtering out the very + short sentence from data loader""", + ) + + batch_size = xs_pad.shape[0] + + enc_streaming_mask = self._streaming_mask(seq_len, batch_size, self.chunk_size, self.left_chunk) + + if xs_pad.is_cuda: + enc_streaming_mask = enc_streaming_mask.cuda() + xs_pad = xs_pad.cuda() + + input_tensor = xs_pad + input_tensor, masks = self._forward_embeddings_core(input_tensor, masks) + + streaming_mask = enc_streaming_mask + if streaming_mask is not None and masks is not None: + hs_mask = masks & streaming_mask + elif masks is not None: + hs_mask = masks + else: + hs_mask = streaming_mask + + if chunk_size_nc is not None: + enc_streaming_mask_nc = self._streaming_mask(seq_len, batch_size, chunk_size_nc, left_chunk_nc) + if xs_pad.is_cuda: + enc_streaming_mask_nc = enc_streaming_mask_nc.cuda() + if masks is not None: + hs_mask_nc = masks & enc_streaming_mask_nc + else: + hs_mask_nc = enc_streaming_mask_nc + else: + hs_mask_nc = None + + pos_k, pos_v = self._position_embedding(input_tensor) + + if chunk_size_nc is None: + return input_tensor, pos_k, pos_v, hs_mask, masks + return input_tensor, pos_k, pos_v, hs_mask, masks, hs_mask_nc + + def get_offset(self): + """Returns offset used when retaining inputs for decoding. + + This is essentially, how many additional frames have to be added to + the front-end CNN input to ensure it can produce a single output. + So if the "padding" parameter is 0, typically offset will be > 0. + """ + return get_offset(self.input_layer, self.time_reduction) + + +def get_offset(input_layer: str, time_reduction: int): + """Get an offset. We will use the offset for determining #frames of a subsampled feature. 
+ + Args: + input_layer (str): Type of an input layer + time_reduction (int): time reduction factor for downsampling a feature + Returns: + int: offset + """ + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4: + return 3 + if input_layer in ("conv2d",) and time_reduction == 6: + return 1 + if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8: + return 7 + return 0 + + +class ConformerEncoder(TransformerEncoderBase): + """ConformerEncoder module. + see original paper for more details: + https://arxiv.org/abs/2005.08100 + + Please set causal = True in streaming model + Args: + input_size: int + input feature dimension. + chunk_size: int, list(int) + Number of frames for each chunk + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training + Some examples for the 2 cases: + chunk_size = 12 + chunk_size = [6, 8, 12, 24] + left_chunk: int, list(int) + Number of chunks used for masking in streaming mode. + This variable can take 2 forms: + int: Used for inference, or single chunk size training + list(int) : Used only for variable chunk size training. When + chunk_size is a list, left_chunk must be a list with same length. + Some examples for the 2 cases: + left_chunk = 6 + left_chunk = [12, 9, 6, 3] + num_lang: int + This parameter is used to store the number of languages in the lang_dict, + only used for multiseed/multilingual models. default None. + attention_dim: int, optional + attention dimension. default 256. + attention_heads: int, optional + the number of heads. default 4 + linear_units: + the number of units of position-wise feed forward. + default 2048 + num_blocks: + number of Transformer layers. default 6 + dropout_rate: float, optional + dropout rate. default 0.1 + input_layer: str, optional + input layer type before Conformer, + one of ["linear", "conv2d", "custom", "vgg2l", "embed"], + default "conv2d" + causal: bool, optional + if set to True, convolutions have no access + to future frames. default False. + batch_norm: bool, optional + if set to True, apply batchnorm before activation + in ConvModule layer of the conformer. + default False + cnn_out: int, optional + the number of CNN channels before Conformer. + default -1. + cnn_layer_norm: bool, optional + layer norm between Conformer and the first CNN. + default False. + ext_pw_out_channel: int, optional + the number of channels for CNN + before depthwise_seperable_CNN. + If 0 then use linear. default 0. + ext_pw_kernel_size: int, optional + kernel size of N before depthwise_seperable_CNN. + only works for ext_pw_out_channel > 0. + default 1 + depthwise_seperable_out_channel: int, optional + the number of channels for + depthwise_seperable_CNN. + default 256. + depthwise_multiplier: int, optional + the number of multiplier for + depthwise_seperable_CNN. + default 1. + chunk_se: int, optional + 0 for offline SE. + 1 for streaming SE, where mean is computed + by accumulated history until current chunk_se. + 2 for streaming SE, where mean is computed + by only the current chunk. + default 0. + kernel_size: int, optional + the number of kernels for depthwise_seperable_CNN. + default 3. + activation: str, optional + FeedForward block activation. + one of ["relu", "swish", "sigmoid"] + default "relu". + conv_activation: str, optional + activation function used in ConvModule part + of the conformer, default "relu".
+ conv_glu_type: str, optional + activation function of the GLU used in depthwise_seperable_CNN, + default "sigmoid" + bias_in_glu: bool, optional + if set to True, use additive bias in the weight module + before GLU. default True + linear_glu_in_convm: bool, optional + if set to True, use GLULinear module, + otherwise, use GLUPointWiseConv module. + default to False. + attention_glu_type: str + only works for glu_in_attention != 0 + default "swish". + export: bool, optional + if set to True, it removes the padding from convolutional layers + and allows the onnx conversion for inference. + default False. + activation_checkpointing: str, optional + a dictionary of {"module","interval","offload"}, where + "module": str + accept ["transformer", "attention"] to select + which module should do activation checkpointing. + "interval": int, default 1, + interval of applying activation checkpointing, + interval = 1 means that we apply checkpointing + on every layer (if activation), otherwise, + we apply it every x interval. + "offload": bool, default False, + if set to True, we offload activation to cpu and + reload it during backward, otherwise, + we recalculate activation in backward. + default "". + extra_layer_output_idx: int + the layer index to be exposed. + relative_attention_bias_args: dict, optional + use more efficient scalar bias-based relative multihead attention (Q*K^T + B) + implemented in cmb.basics.embedding.[T5/ALiBi]RelativeAttentionLogitBias + usage: relative_attention_bias_args={"type": t5/alibi} + additional method-specific arguments can be provided (see transformer_base.py) + time_reduction: int, optional + time reduction factor + default 4 + use_pt_scaled_dot_product_attention: whether to use pytorch scaled dot product attention + in training. + Default: False + nemo_conv_settings: dict, optional + A dictionary of settings for NeMo Subsampling. + default: None + usage: nemo_conv_settings= + { + "subsampling": + dw_striding/striding/dw_striding_conv1d/striding_conv1d, + "conv_channels": int, + "subsampling_conv_chunking_factor": int, + "is_causal": True/False + } + conv2d_extra_padding: str, optional + Add extra padding in conv2d subsampling layers. Choices are + (feat, feat_time, none, True) + Default: none + replication_pad_for_subsample_embedding: For batched-streaming decoding, use + "replication" padding for the cache at start of utterance.
+ Default: False + attention_group_size: int, optional + the number of groups to use for attention, default 1 (Multi-Head Attention), + 1 = typical Multi-Head Attention, + 1 < attention_group_size < attention_heads = Grouped-Query Attention + attention_group_size = attenion_heads = Multi-Query Attention + """ + + extra_multi_layer_output_idxs: List[int] + + def __init__( # pylint: disable-all + self, + input_size, + chunk_size, + left_chunk, + num_lang=None, + attention_dim=256, + attention_heads=4, + linear_units=2048, + num_blocks=6, + dropout_rate=0.1, + input_layer="nemo_conv", + causal=True, + batch_norm=False, + cnn_out=-1, + cnn_layer_norm=False, + ext_pw_out_channel=0, + ext_pw_kernel_size=1, + depthwise_seperable_out_channel=256, + depthwise_multiplier=1, + chunk_se=0, + kernel_size=3, + activation="relu", + conv_activation="relu", + conv_glu_type="sigmoid", + bias_in_glu=True, + linear_glu_in_convm=False, + attention_glu_type="swish", + export=False, + extra_layer_output_idx=-1, + extra_multi_layer_output_idxs=[], + activation_checkpointing="", + relative_attention_bias_args=None, + time_reduction=4, + use_pt_scaled_dot_product_attention=False, + nemo_conv_settings=None, + conv2d_extra_padding: Literal["feat", "feat_time", "none", True] = "none", + replication_pad_for_subsample_embedding=False, + attention_group_size=1, + encoder_embedding_config=None, + ): + super().__init__( + input_size, + chunk_size, + left_chunk, + attention_dim, + attention_heads, + input_layer, + cnn_out, + cnn_layer_norm, + time_reduction, + dropout_rate=dropout_rate, + relative_attention_bias_args=relative_attention_bias_args, + positional_dropout_rate=0.0, + nemo_conv_settings=nemo_conv_settings, + conv2d_extra_padding=conv2d_extra_padding, + attention_group_size=attention_group_size, + encoder_embedding_config=encoder_embedding_config, + ) + self.num_blocks = num_blocks + self.num_lang = num_lang + self.kernel_size = kernel_size + self.embed = embedding_checkpoint_wrapper(activation_checkpointing)(self.embed) + self.replication_pad_for_subsample_embedding: bool = replication_pad_for_subsample_embedding + assert self.num_heads % attention_group_size == 0, "attention_group_size must divide n_head" + self.num_heads_k = self.num_heads // attention_group_size + + self.encoders = repeat( + num_blocks, + lambda i: encoder_checkpoint_wrapper(activation_checkpointing, ConformerEncoderLayer, i)( + ConformerEncoderLayer( + d_model=attention_dim, + ext_pw_out_channel=ext_pw_out_channel, + depthwise_seperable_out_channel=depthwise_seperable_out_channel, + depthwise_multiplier=depthwise_multiplier, + n_head=attention_heads, + d_ffn=linear_units, + ext_pw_kernel_size=ext_pw_kernel_size, + kernel_size=kernel_size, + dropout_rate=dropout_rate, + causal=causal, + batch_norm=batch_norm, + activation=activation, + chunk_se=chunk_se, + chunk_size=chunk_size, + conv_activation=conv_activation, + conv_glu_type=conv_glu_type, + bias_in_glu=bias_in_glu, + linear_glu_in_convm=linear_glu_in_convm, + attention_glu_type=attention_glu_type, + activation_checkpointing=attn_checkpointing(activation_checkpointing, i), + export=export, + use_pt_scaled_dot_product_attention=use_pt_scaled_dot_product_attention, + attn_group_sizes=attention_group_size, + ) + ), + ) + self.extra_layer_output_idx = extra_layer_output_idx + self.extra_multi_layer_output_idxs = extra_multi_layer_output_idxs + # Make a zeros scalar we can use in get_initial_state to determine + # the device and the needed dtype: + self.register_buffer("dev_type", 
torch.zeros(()), persistent=False) + + def init_relative_attention_bias(self, input_tensor): + if self.relative_attention_bias_layer: + return self.relative_attention_bias_layer(input_tensor) + + def calculate_hs_mask(self, xs_pad, device, mask): + max_audio_length = xs_pad.shape[1] + batch_size = xs_pad.shape[0] + enc_streaming_mask = self._streaming_mask(max_audio_length, batch_size, self.chunk_size, self.left_chunk) + enc_streaming_mask = enc_streaming_mask.to(device) + if mask is None: + return enc_streaming_mask + + feature_lens = mask.sum(1) + padding_length = feature_lens + pad_mask = torch.arange(0, max_audio_length, device=device).expand( + padding_length.size(0), -1 + ) < padding_length.unsqueeze(1) + pad_mask = pad_mask.unsqueeze(1) + pad_mask = pad_mask & enc_streaming_mask + return pad_mask + + @torch.jit.ignore + def forward(self, xs_pad, masks): + """Conformer Forward function. + + Args: + xs_pad: torch.Tensor + input tensor + masks: torch.Tensor + post-embedding input lengths + """ + xs_pad = self.encoder_embedding(xs_pad) + input_tensor, pos_k, pos_v, hs_mask, masks = self.forward_embeddings(xs_pad, masks) + + unfolded = False + ori_bz, seq_len, D = input_tensor.shape + max_seq_len = 500 # maximum position for absolute positional encoding + if seq_len > max_seq_len: + # audio sequence is longer than max_seq_len, unfold it into chunks of max_seq_len + unfolded = True + # the unfold op will drop residual frames, pad it to the multiple of max_seq_len + if seq_len % max_seq_len > 0: + chunk_pad_size = max_seq_len - (seq_len % max_seq_len) + else: + chunk_pad_size = 0 + if chunk_pad_size > 0: + input_tensor_pad = F.pad(input_tensor, (0, 0, 0, chunk_pad_size), "constant", 0) + input_tensor = input_tensor_pad.to(input_tensor.device) + + input_tensor = unfold_tensor(input_tensor, max_seq_len) + if masks is not None: + # revise hs_mask here because the previous calculated hs_mask did not consider extra pad + subsampled_pad_mask = masks.squeeze(1) # [bz, subsampled_unmask_seq_len] + extra_padded_subsamlped_pad_mask = F.pad( + subsampled_pad_mask, (0, chunk_pad_size), "constant", False + ) # extra padding to the pad mask + extra_padded_subsamlped_pad_mask = extra_padded_subsamlped_pad_mask.unsqueeze(-1).float() + masks_unfold = unfold_tensor( + extra_padded_subsamlped_pad_mask, max_seq_len + ) # unfold the pad mask like we did to the input tensor + masks_unfold = masks_unfold.squeeze(-1).bool() # unfold op does not support bool tensor + else: + masks_unfold = None + hs_mask = self.calculate_hs_mask( + input_tensor, input_tensor.device, masks_unfold + ) # calculate hs_mask based on the unfolded pad mask + layer_emb = None + + relative_attention_bias = self.init_relative_attention_bias(input_tensor) + + _simplified_path = self.extra_layer_output_idx == -1 and relative_attention_bias is None + + if _simplified_path: + input_tensor, *_ = self.encoders(input_tensor, pos_k, pos_v, hs_mask) + else: + for i, layer in enumerate(self.encoders): + input_tensor, _, _, _ = layer( + input_tensor, + pos_k, + pos_v, + hs_mask, + relative_attention_bias=relative_attention_bias, + ) + + if i == self.extra_layer_output_idx: + layer_emb = input_tensor + if unfolded: + embed_dim = input_tensor.shape[-1] + input_tensor = input_tensor.reshape(ori_bz, -1, embed_dim) + # if we ever padded before unfolding, we need to remove the padding + if chunk_pad_size > 0: + input_tensor = input_tensor[:, :-chunk_pad_size, :] + return input_tensor, masks # , layer_emb + + def gradient_checkpointing_enable(self): + pass 
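Note on the chunked streaming attention used throughout this encoder: `_streaming_mask` derives chunk start indices with `np.arange(0, seq_len, chunk_size)` and passes them to `adaptive_enc_mask` (defined elsewhere in this patch) so that every frame can attend to its own chunk plus `left_chunk` chunks of history. The sketch below is a minimal, hypothetical re-implementation of that pattern for illustration only; `build_streaming_mask` is not a helper from this patch, it assumes regular chunking and it ignores the randomized-offset branch that `_streaming_mask` applies during training.

```python
import numpy as np
import torch

def build_streaming_mask(seq_len: int, chunk_size: int, left_window: int) -> torch.Tensor:
    """Illustrative stand-in for the adaptive_enc_mask behaviour assumed above (not the patched helper)."""
    # Chunk start indices, e.g. chunk_size=18 -> [0, 18, 36, ...] (as in _streaming_mask).
    chunk_start_idx = np.arange(0, seq_len, chunk_size)
    # Chunk id of every frame, obtained by bucketizing frame indices against the chunk starts.
    frame_chunk = torch.bucketize(
        torch.arange(seq_len), torch.as_tensor(chunk_start_idx), right=True
    ) - 1
    q_chunk = frame_chunk.unsqueeze(1)  # (seq_len, 1): chunk id of each query frame
    k_chunk = frame_chunk.unsqueeze(0)  # (1, seq_len): chunk id of each key frame
    # A query frame may attend to its own chunk and at most `left_window` chunks to the left.
    return (k_chunk <= q_chunk) & (k_chunk >= q_chunk - left_window)

print(build_streaming_mask(seq_len=8, chunk_size=3, left_window=1).int())
```

With `seq_len=8, chunk_size=3, left_window=1` this prints a block-banded boolean matrix: frames 0-2 attend only to chunk 0, frames 3-5 to chunks 0-1, and frames 6-7 to chunks 1-2, which is the pattern that `forward_embeddings` later combines with the padding mask.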
diff --git a/comps/llms/src/text-generation/patch/enhance-multimodal-patch/vision_siglip_navit.py b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/vision_siglip_navit.py new file mode 100644 index 000000000..f1855262d --- /dev/null +++ b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/vision_siglip_navit.py @@ -0,0 +1,1722 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Siglip model configuration.""" + +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json", +} + + +class SiglipTextConfig(PretrainedConfig): + r"""This is the configuration class to store the configuration of a [`SiglipTextModel`]. + + It is used to instantiate a + Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SiglipModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. + Example: + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipTextConfig() + >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "siglip_text_model" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See /~https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipVisionConfig(PretrainedConfig): + r"""This is the configuration class to store the configuration of a [`SiglipVisionModel`]. + + It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. 
+ intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + Example: + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipConfig(PretrainedConfig): + r"""[`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. + + It is used to + instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. 
+ Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + Example: + ```python + >>> from transformers import SiglipConfig, SiglipModel + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipConfig() + >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig + >>> from transformers import SiglipTextConfig, SiglipVisionConfig + >>> # Initializing a SiglipText and SiglipVision configuration + >>> config_text = SiglipTextConfig() + >>> config_vision = SiglipVisionConfig() + >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) + ``` + """ + + model_type = "siglip" + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.") + + self.text_config = SiglipTextConfig(**text_config) + self.vision_config = SiglipVisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs): + r"""Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision + model configuration. + + Returns: + [`SiglipConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" PyTorch Siglip model.""" + + +import math +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, + replace_return_docstrings, +) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) 
defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. + + The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + """Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. 
+ vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None 
else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SiglipAttention): + """Llama flash attention module. + + This module inherits from `LlamaAttention` as the weights of the module stays + untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False # Hack to make sure we don't use a causal mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # if past_key_value is not None: + # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, 
dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = ( + SiglipAttention(config) + if not getattr(config, "_flash_attn_2_enabled", False) + else SiglipFlashAttention2(config) + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models.""" + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.normal_(module.probe.data) + nn.init.normal_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.tensor(0.0) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SIGLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +class SiglipEncoder(nn.Module): + """Transformer encoder consisting of `config.num_hidden_layers` self attention layers. + + Each layer is a + [`SiglipEncoderLayer`]. + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions) + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, embed_dim) + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. + # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. 
+ pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = SiglipMultiheadAttentionPoolingHead(config) + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states 
if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self.config._flash_attn_2_enabled + else patch_attention_mask + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention( + query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask + )[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings( + """The vision model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipVisionModel(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: 
Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(SIGLIP_START_DOCSTRING) +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise ValueError( + "config.text_config is expected to be of type SiglipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.text_model = SiglipTextTransformer(text_config) + self.vision_model = SiglipVisionTransformer(vision_config) + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. + Examples: + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
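+        # The pooled output returned just below is the final-token (EOS) hidden state passed
+        # through `SiglipTextTransformer.head`; it is the same quantity that `SiglipModel.forward`
+        # L2-normalizes before computing the sigmoid matching logits.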
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. 
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SiglipOutput]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
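+        # Unlike CLIP's batch-softmax objective, SigLIP scores each image-text pair independently:
+        # logits_per_text = logit_scale.exp() * (text_embeds @ image_embeds.T) + logit_bias,
+        # and per-pair probabilities are obtained with a sigmoid (as in the docstring example above).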
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + text_embeds = text_outputs[1] + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + raise NotImplementedError("SigLIP loss to be implemented") + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs): + siglip_vision_config = { + "hidden_size": 1152, + "image_size": 448, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + } + + model_config = SiglipVisionConfig(**siglip_vision_config, _flash_attn_2_enabled=_flash_attn_2_enabled, **kwargs) + + vision_model = SiglipVisionModel(model_config).vision_model + + return vision_model diff --git a/comps/llms/src/text-generation/patch/enhance-multimodal-patch/vision_siglip_navit_lazy.py b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/vision_siglip_navit_lazy.py new file mode 100644 index 000000000..1a11cbcbc --- /dev/null +++ b/comps/llms/src/text-generation/patch/enhance-multimodal-patch/vision_siglip_navit_lazy.py @@ -0,0 +1,2020 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
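The `get_siglip_vision_model` helper defined at the end of the patched module above (just before the lazy-mode variant that follows) hard-codes a large SigLIP vision tower (hidden size 1152, 27 layers, 448x448 input with 14x14 patches) and returns only its `vision_model` submodule. A minimal smoke test of that helper might look like the sketch below; the import path is an assumption about how the patch file is exposed, and `_flash_attn_2_enabled=False` forces the eager `SiglipAttention` path so the check also runs on hosts without flash-attention kernels.

```python
import torch

# Assumed import path for the patched module; adjust to wherever the patch file is copied.
from vision_siglip_navit import get_siglip_vision_model

# Eager attention keeps the sketch runnable without the flash-attn package.
vision_tower = get_siglip_vision_model(_flash_attn_2_enabled=False)
vision_tower.eval()

# One dummy RGB image at the configured 448x448 resolution -> (448 / 14) ** 2 = 1024 patch tokens.
pixel_values = torch.randn(1, 3, 448, 448)
with torch.no_grad():
    outputs = vision_tower(pixel_values=pixel_values)

print(outputs.last_hidden_state.shape)  # torch.Size([1, 1024, 1152])
print(outputs.pooler_output.shape)      # torch.Size([1, 1152]), from the attention-pooling head
```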
+"""Siglip model configuration.""" + +import os +from typing import Union + +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +SIGLIP_PRETRAINED_CONFIG_ARCHIVE_MAP = { + "google/siglip-base-patch16-224": "https://huggingface.co/google/siglip-base-patch16-224/resolve/main/config.json", +} + + +class SiglipTextConfig(PretrainedConfig): + r"""This is the configuration class to store the configuration of a [`SiglipTextModel`]. + + It is used to instantiate a + Siglip text encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the text encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 32000): + Vocabulary size of the Siglip text model. Defines the number of different tokens that can be represented by + the `inputs_ids` passed when calling [`SiglipModel`]. + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. + num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + max_position_embeddings (`int`, *optional*, defaults to 64): + The maximum sequence length that this model might ever be used with. Typically set this to something large + just in case (e.g., 512 or 1024 or 2048). + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` `"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + pad_token_id (`int`, *optional*, defaults to 1): + The id of the padding token in the vocabulary. + bos_token_id (`int`, *optional*, defaults to 49406): + The id of the beginning-of-sequence token in the vocabulary. + eos_token_id (`int`, *optional*, defaults to 49407): + The id of the end-of-sequence token in the vocabulary. 
+ Example: + ```python + >>> from transformers import SiglipTextConfig, SiglipTextModel + >>> # Initializing a SiglipTextConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipTextConfig() + >>> # Initializing a SiglipTextModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipTextModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "siglip_text_model" + + def __init__( + self, + vocab_size=32000, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=64, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + # This differs from `CLIPTokenizer`'s default and from openai/siglip + # See /~https://github.com/huggingface/transformers/pull/24773#issuecomment-1632287538 + pad_token_id=1, + bos_token_id=49406, + eos_token_id=49407, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) + + self.vocab_size = vocab_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.max_position_embeddings = max_position_embeddings + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self.attention_dropout = attention_dropout + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the text config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["text_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipVisionConfig(PretrainedConfig): + r"""This is the configuration class to store the configuration of a [`SiglipVisionModel`]. + + It is used to instantiate a + Siglip vision encoder according to the specified arguments, defining the model architecture. Instantiating a + configuration with the defaults will yield a similar configuration to that of the vision encoder of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + hidden_size (`int`, *optional*, defaults to 768): + Dimensionality of the encoder layers and the pooler layer. + intermediate_size (`int`, *optional*, defaults to 3072): + Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder. + num_hidden_layers (`int`, *optional*, defaults to 12): + Number of hidden layers in the Transformer encoder. 
+ num_attention_heads (`int`, *optional*, defaults to 12): + Number of attention heads for each attention layer in the Transformer encoder. + num_channels (`int`, *optional*, defaults to 3): + Number of channels in the input images. + image_size (`int`, *optional*, defaults to 224): + The size (resolution) of each image. + patch_size (`int`, *optional*, defaults to 16): + The size (resolution) of each patch. + hidden_act (`str` or `function`, *optional*, defaults to `"gelu_pytorch_tanh"`): + The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`, + `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported. + layer_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the layer normalization layers. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. + Example: + ```python + >>> from transformers import SiglipVisionConfig, SiglipVisionModel + >>> # Initializing a SiglipVisionConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipVisionConfig() + >>> # Initializing a SiglipVisionModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipVisionModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + ``` + """ + + model_type = "siglip_vision_model" + + def __init__( + self, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + num_channels=3, + image_size=224, + patch_size=16, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + _flash_attn_2_enabled=True, + **kwargs, + ): + super().__init__(**kwargs) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_attention_heads + self.num_channels = num_channels + self.patch_size = patch_size + self.image_size = image_size + self.attention_dropout = attention_dropout + self.layer_norm_eps = layer_norm_eps + self.hidden_act = hidden_act + self._flash_attn_2_enabled = _flash_attn_2_enabled + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + cls._set_token_in_kwargs(kwargs) + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + + # get the vision config dict if we are loading from SiglipConfig + if config_dict.get("model_type") == "siglip": + config_dict = config_dict["vision_config"] + + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + +class SiglipConfig(PretrainedConfig): + r"""[`SiglipConfig`] is the configuration class to store the configuration of a [`SiglipModel`]. + + It is used to + instantiate a Siglip model according to the specified arguments, defining the text model and vision model configs. + Instantiating a configuration with the defaults will yield a similar configuration to that of the Siglip + [google/siglip-base-patch16-224](https://huggingface.co/google/siglip-base-patch16-224) architecture. 
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + text_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipTextConfig`]. + vision_config (`dict`, *optional*): + Dictionary of configuration options used to initialize [`SiglipVisionConfig`]. + kwargs (*optional*): + Dictionary of keyword arguments. + Example: + ```python + >>> from transformers import SiglipConfig, SiglipModel + >>> # Initializing a SiglipConfig with google/siglip-base-patch16-224 style configuration + >>> configuration = SiglipConfig() + >>> # Initializing a SiglipModel (with random weights) from the google/siglip-base-patch16-224 style configuration + >>> model = SiglipModel(configuration) + >>> # Accessing the model configuration + >>> configuration = model.config + >>> # We can also initialize a SiglipConfig from a SiglipTextConfig and a SiglipVisionConfig + >>> from transformers import SiglipTextConfig, SiglipVisionConfig + >>> # Initializing a SiglipText and SiglipVision configuration + >>> config_text = SiglipTextConfig() + >>> config_vision = SiglipVisionConfig() + >>> config = SiglipConfig.from_text_vision_configs(config_text, config_vision) + ``` + """ + + model_type = "siglip" + + def __init__(self, text_config=None, vision_config=None, **kwargs): + super().__init__(**kwargs) + + if text_config is None: + text_config = {} + logger.info("`text_config` is `None`. Initializing the `SiglipTextConfig` with default values.") + + if vision_config is None: + vision_config = {} + logger.info("`vision_config` is `None`. initializing the `SiglipVisionConfig` with default values.") + + self.text_config = SiglipTextConfig(**text_config) + self.vision_config = SiglipVisionConfig(**vision_config) + + self.initializer_factor = 1.0 + + @classmethod + def from_text_vision_configs(cls, text_config: SiglipTextConfig, vision_config: SiglipVisionConfig, **kwargs): + r"""Instantiate a [`SiglipConfig`] (or a derived class) from siglip text model configuration and siglip vision + model configuration. + + Returns: + [`SiglipConfig`]: An instance of a configuration object + """ + + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) + + +# coding=utf-8 +# Copyright 2024 Google AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
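As in the companion patch file above, the `_flash_attn_2_enabled` flag carried by these config classes is what the encoder layer later uses to choose between the eager `SiglipAttention` and `SiglipFlashAttention2`. The sketch below shows one way a Gaudi-friendly configuration might be assembled from them; the import name and the hyperparameter values are illustrative assumptions, not values mandated by this patch.

```python
# Assumed import path for this patch file.
from vision_siglip_navit_lazy import SiglipConfig, SiglipTextConfig, SiglipVisionConfig

# Keep the eager SiglipAttention path; flash-attn kernels target CUDA and are
# generally not available on Gaudi/HPU.
text_config = SiglipTextConfig(_flash_attn_2_enabled=False)
vision_config = SiglipVisionConfig(image_size=448, patch_size=14, _flash_attn_2_enabled=False)

# Compose the joint config exactly as the class method above does.
config = SiglipConfig.from_text_vision_configs(text_config, vision_config)
print(config.vision_config.image_size, config.vision_config.patch_size)  # 448 14
```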
+""" PyTorch Siglip model.""" + + +import math +import warnings +from dataclasses import dataclass +from typing import Any, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn.init import _calculate_fan_in_and_fan_out +from transformers.activations import ACT2FN +from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask +from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling +from transformers.modeling_utils import PreTrainedModel +from transformers.utils import ( + ModelOutput, + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + logging, + replace_return_docstrings, +) + +logger = logging.get_logger(__name__) + +_CHECKPOINT_FOR_DOC = "google/siglip-base-patch16-224" + +SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST = [ + "google/siglip-base-patch16-224", + # See all SigLIP models at https://huggingface.co/models?filter=siglip +] + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def _trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + if tensor.dtype in [torch.float16, torch.bfloat16]: + # The `erfinv_` op is not (yet?) defined in float16+cpu, bfloat16+gpu + og_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + tensor.erfinv_() + tensor = tensor.to(og_dtype) + else: + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + if tensor.dtype == torch.float16: + # The `clamp_` op is not (yet?) 
defined in float16+cpu + tensor = tensor.to(torch.float32) + tensor.clamp_(min=a, max=b) + tensor = tensor.to(torch.float16) + else: + tensor.clamp_(min=a, max=b) + + +def trunc_normal_tf_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0, a: float = -2.0, b: float = 2.0 +) -> torch.Tensor: + """Fills the input Tensor with values drawn from a truncated + normal distribution. + + The values are effectively drawn from the + normal distribution :math:`\\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \\leq \text{mean} \\leq b`. + NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the + bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0 + and the result is subsequently scaled and shifted by the mean and std args. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + with torch.no_grad(): + _trunc_normal_(tensor, 0, 1.0, a, b) + tensor.mul_(std).add_(mean) + + +def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"): + fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor) + if mode == "fan_in": + denom = fan_in + elif mode == "fan_out": + denom = fan_out + elif mode == "fan_avg": + denom = (fan_in + fan_out) / 2 + + variance = scale / denom + + if distribution == "truncated_normal": + # constant is stddev of standard normal truncated to (-2, 2) + trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978) + elif distribution == "normal": + with torch.no_grad(): + tensor.normal_(std=math.sqrt(variance)) + elif distribution == "uniform": + bound = math.sqrt(3 * variance) + with torch.no_grad(): + tensor.uniform_(-bound, bound) + else: + raise ValueError(f"invalid distribution {distribution}") + + +def lecun_normal_(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal") + + +def default_flax_embed_init(tensor): + variance_scaling_(tensor, mode="fan_in", distribution="normal") + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPVisionModelOutput with CLIP->Siglip +class SiglipVisionModelOutput(ModelOutput): + """Base class for vision model's outputs that also contains image embeddings of the pooling of the last hidden states. + + Args: + image_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The image embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + image_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPTextModelOutput with CLIP->Siglip +class SiglipTextModelOutput(ModelOutput): + """Base class for text model's outputs that also contains a pooling of the last hidden states. + + Args: + text_embeds (`torch.FloatTensor` of shape `(batch_size, output_dim)` *optional* returned when model is initialized with `with_projection=True`): + The text embeddings obtained by applying the projection layer to the pooler_output. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Sequence of hidden-states at the output of the last layer of the model. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings, if the model has an embedding layer, + + one for the output of each layer) of shape `(batch_size, sequence_length, hidden_size)`. + Hidden-states of the model at the output of each layer plus the optional initial embedding outputs. + attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. + Attentions weights after the attention softmax, used to compute the weighted average in the self-attention + heads. + """ + + text_embeds: Optional[torch.FloatTensor] = None + last_hidden_state: torch.FloatTensor = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + + +@dataclass +# Copied from transformers.models.clip.modeling_clip.CLIPOutput with CLIP->Siglip +class SiglipOutput(ModelOutput): + """ + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `return_loss` is `True`): + Contrastive loss for image-text similarity. + logits_per_image:(`torch.FloatTensor` of shape `(image_batch_size, text_batch_size)`): + The scaled dot product scores between `image_embeds` and `text_embeds`. This represents the image-text + similarity scores. + logits_per_text:(`torch.FloatTensor` of shape `(text_batch_size, image_batch_size)`): + The scaled dot product scores between `text_embeds` and `image_embeds`. This represents the text-image + similarity scores. + text_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The text embeddings obtained by applying the projection layer to the pooled output of [`SiglipTextModel`]. + image_embeds(`torch.FloatTensor` of shape `(batch_size, output_dim`): + The image embeddings obtained by applying the projection layer to the pooled output of [`SiglipVisionModel`]. + text_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipTextModel`]. 
+ vision_model_output(`BaseModelOutputWithPooling`): + The output of the [`SiglipVisionModel`]. + """ + + loss: Optional[torch.FloatTensor] = None + logits_per_image: torch.FloatTensor = None + logits_per_text: torch.FloatTensor = None + text_embeds: torch.FloatTensor = None + image_embeds: torch.FloatTensor = None + text_model_output: BaseModelOutputWithPooling = None + vision_model_output: BaseModelOutputWithPooling = None + + def to_tuple(self) -> Tuple[Any]: + return tuple( + self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() + for k in self.keys() + ) + + +class SiglipVisionEmbeddings(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.image_size = config.image_size + self.patch_size = config.patch_size + + self.patch_embedding = nn.Conv2d( + in_channels=config.num_channels, + out_channels=self.embed_dim, + kernel_size=self.patch_size, + stride=self.patch_size, + padding="valid", + ) + + self.num_patches_per_side = self.image_size // self.patch_size + self.num_patches = self.num_patches_per_side**2 + self.num_positions = self.num_patches + self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) + + def forward(self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor) -> torch.Tensor: + batch_size = pixel_values.size(0) + + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + + max_im_h, max_im_w = pixel_values.size(2), pixel_values.size(3) + max_nb_patches_h, max_nb_patches_w = max_im_h // self.patch_size, max_im_w // self.patch_size + boundaries = torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) + position_ids = torch.full( + size=( + batch_size, + max_nb_patches_h * max_nb_patches_w, + ), + fill_value=0, + ) + + for batch_idx, p_attn_mask in enumerate(patch_attention_mask): + nb_patches_h = p_attn_mask[:, 0].sum() + nb_patches_w = p_attn_mask[0].sum() + + fractional_coords_h = torch.arange(0, 1 - 1e-6, 1 / nb_patches_h) + fractional_coords_w = torch.arange(0, 1 - 1e-6, 1 / nb_patches_w) + + bucket_coords_h = torch.bucketize(fractional_coords_h, boundaries, right=True) + bucket_coords_w = torch.bucketize(fractional_coords_w, boundaries, right=True) + + pos_ids = (bucket_coords_h[:, None] * self.num_patches_per_side + bucket_coords_w).flatten() + position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids + + position_ids = position_ids.to(self.position_embedding.weight.device) + + embeddings = embeddings + self.position_embedding(position_ids) + return embeddings + + +# Copied from transformers.models.clip.modeling_clip.CLIPTextEmbeddings with CLIP->Siglip +class SiglipTextEmbeddings(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + embed_dim = config.hidden_size + + self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) + self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) + + # position_ids (1, len position emb) is contiguous in memory and exported when serialized + self.register_buffer( + "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False + ) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + seq_length = input_ids.shape[-1] if input_ids is not None 
else inputs_embeds.shape[-2] + + if position_ids is None: + position_ids = self.position_ids[:, :seq_length] + + if inputs_embeds is None: + inputs_embeds = self.token_embedding(input_ids) + + position_embeddings = self.position_embedding(position_ids) + embeddings = inputs_embeds + position_embeddings + + return embeddings + + +class SiglipAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper.""" + + # Copied from transformers.models.clip.modeling_clip.CLIPAttention.__init__ + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." + ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + + self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + """Input shape: Batch x Time x Channel""" + + batch_size, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + k_v_seq_len = key_states.shape[-2] + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale + + if attn_weights.size() != (batch_size, self.num_heads, q_len, k_v_seq_len): + raise ValueError( + f"Attention weights should be of size {(batch_size, self.num_heads, q_len, k_v_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (batch_size, 1, q_len, k_v_seq_len): + raise ValueError( + f"Attention mask should be of size {(batch_size, 1, q_len, k_v_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (batch_size, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(batch_size, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(batch_size, q_len, self.embed_dim) + + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class SiglipFlashAttention2(SiglipAttention): + """Llama flash attention module. + + This module inherits from `LlamaAttention` as the weights of the module stays + untouched. 
The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.is_causal = False # Hack to make sure we don't use a causal mask + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + # Flash attention requires the input to have the shape + # batch_size x seq_length x head_dim x hidden_dim + # therefore we just need to keep the original shape + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + # cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + # if past_key_value is not None: + # cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + # key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # TODO: These transpose are quite inefficient but Flash Attention requires the layout [batch_size, sequence_length, num_heads, head_dim]. We would need to refactor the KV cache + # to be able to avoid many of these transpose/reshape/view. + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.dropout if self.training else 0.0 + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in the correct dtype just to be sure everything works as expected. + # This might slowdown training & inference so it is recommended to not cast the LayerNorms + # in fp32. (LlamaRMSNorm handles it correctly) + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + "The input hidden states seems to be silently casted in float32, this might be related to the fact" + " you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." 
+ ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + attn_output = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate + ) + + attn_output = attn_output.reshape(bsz, q_len, self.embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights + + def _flash_attention_forward( + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`int`, *optional*): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + """ + + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + attn_output = flash_attn_func( + query_states, key_states, value_states, dropout, softmax_scale=softmax_scale, causal=causal + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape + + key_layer = index_first_axis( + key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + value_layer = index_first_axis( + value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k + ) + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, self.num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, 
dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.clip.modeling_clip.CLIPMLP with CLIP->Siglip +class SiglipMLP(nn.Module): + def __init__(self, config): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) + self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states = self.fc1(hidden_states) + hidden_states = self.activation_fn(hidden_states) + hidden_states = self.fc2(hidden_states) + return hidden_states + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->Siglip +class SiglipEncoderLayer(nn.Module): + def __init__(self, config: SiglipConfig): + super().__init__() + self.embed_dim = config.hidden_size + self.self_attn = ( + SiglipAttention(config) + if not getattr(config, "_flash_attn_2_enabled", False) + else SiglipFlashAttention2(config) + ) + self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + self.layer_norm2 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor]: + """ + Args: + hidden_states (`torch.FloatTensor`): + Input to the layer of shape `(batch, seq_len, embed_dim)`. + attention_mask (`torch.FloatTensor`): + Attention mask of shape `(batch, 1, q_len, k_v_seq_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*, defaults to `False`): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. 
+ """ + residual = hidden_states + + hidden_states = self.layer_norm1(hidden_states) + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +class SiglipPreTrainedModel(PreTrainedModel): + """An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models.""" + + config_class = SiglipConfig + base_model_prefix = "siglip" + supports_gradient_checkpointing = True + + def _init_weights(self, module): + """Initialize the weights.""" + + if isinstance(module, SiglipVisionEmbeddings): + width = ( + self.config.vision_config.hidden_size + if isinstance(self.config, SiglipConfig) + else self.config.hidden_size + ) + nn.init.normal_(module.position_embedding.weight, std=1 / np.sqrt(width)) + elif isinstance(module, nn.Embedding): + default_flax_embed_init(module.weight) + elif isinstance(module, SiglipAttention): + nn.init.normal_(module.q_proj.weight) + nn.init.normal_(module.k_proj.weight) + nn.init.normal_(module.v_proj.weight) + nn.init.normal_(module.out_proj.weight) + nn.init.zeros_(module.q_proj.bias) + nn.init.zeros_(module.k_proj.bias) + nn.init.zeros_(module.v_proj.bias) + nn.init.zeros_(module.out_proj.bias) + elif isinstance(module, SiglipMLP): + nn.init.normal_(module.fc1.weight) + nn.init.normal_(module.fc2.weight) + nn.init.normal_(module.fc1.bias, std=1e-6) + nn.init.normal_(module.fc2.bias, std=1e-6) + elif isinstance(module, SiglipMultiheadAttentionPoolingHead): + nn.init.normal_(module.probe.data) + nn.init.normal_(module.attention.in_proj_weight.data) + nn.init.zeros_(module.attention.in_proj_bias.data) + elif isinstance(module, SiglipModel): + logit_scale_init = torch.tensor(0.0) + module.logit_scale.data.fill_(logit_scale_init) + module.logit_bias.data.zero_() + elif isinstance(module, (nn.Linear, nn.Conv2d)): + lecun_normal_(module.weight) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +SIGLIP_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + Parameters: + config ([`SiglipConfig`]): Model configuration class with all the parameters of the model. + Initializing with a config file does not load the weights associated with the model, only the + configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + +SIGLIP_TEXT_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. 
See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_VISION_INPUTS_DOCSTRING = r""" + Args: + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + +SIGLIP_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.max_position_embeddings - 1]`. + [What are position IDs?](../glossary#position-ids) + pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using + [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details. + return_loss (`bool`, *optional*): + Whether or not to return the contrastive loss. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->Siglip +class SiglipEncoder(nn.Module): + """Transformer encoder consisting of `config.num_hidden_layers` self attention layers. + + Each layer is a + [`SiglipEncoderLayer`]. + Args: + config: SiglipConfig + """ + + def __init__(self, config: SiglipConfig): + super().__init__() + self.config = config + self.layers = nn.ModuleList([SiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.gradient_checkpointing = False + + # Ignore copy + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutput]: + r""" + Args: + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. + This is useful if you want more control over how to convert `input_ids` indices into associated vectors + than the model's internal embedding lookup matrix. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + [What are attention masks?](../glossary#attention-mask) + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors + for more detail. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
+ """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for encoder_layer in self.layers: + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + if self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + encoder_layer.__call__, + hidden_states, + attention_mask, + output_attentions, + ) + else: + layer_outputs = encoder_layer( + hidden_states, + attention_mask, + output_attentions=output_attentions, + ) + + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + if output_hidden_states: + encoder_states = encoder_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions) + + +class SiglipTextTransformer(nn.Module): + def __init__(self, config: SiglipTextConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + self.embeddings = SiglipTextEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + + self.head = nn.Linear(embed_dim, embed_dim) + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is None: + raise ValueError("You have to specify input_ids") + + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + + hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) + + # note: SigLIP's text model does not use a causal mask, unlike the original CLIP model. 
+ # expand attention_mask + if attention_mask is not None: + # [batch_size, seq_len] -> [batch_size, 1, tgt_seq_len, src_seq_len] + # attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype) + attention_mask = expand_2d_attention_mask(attention_mask) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.final_layer_norm(last_hidden_state) + + # Assuming "sticky" EOS tokenization, last token is always EOS. + pooled_output = last_hidden_state[:, -1, :] + pooled_output = self.head(pooled_output) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +@add_start_docstrings( + """The text model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class SiglipTextModel(SiglipPreTrainedModel): + config_class = SiglipTextConfig + + _no_split_modules = ["SiglipTextEmbeddings", "SiglipEncoderLayer"] + + def __init__(self, config: SiglipTextConfig): + super().__init__(config) + self.text_model = SiglipTextTransformer(config) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.text_model.embeddings.token_embedding + + def set_input_embeddings(self, value): + self.text_model.embeddings.token_embedding = value + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipTextConfig) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from transformers import AutoTokenizer, SiglipTextModel + >>> model = SiglipTextModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +class SiglipVisionTransformer(nn.Module): + def __init__(self, config: SiglipVisionConfig): + super().__init__() + self.config = config + embed_dim = config.hidden_size + + self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder(config) + self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.head = 
SiglipMultiheadAttentionPoolingHead(config) + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + batch_size = pixel_values.size(0) + if patch_attention_mask is None: + patch_attention_mask = torch.ones( + size=( + batch_size, + pixel_values.size(2) // self.config.patch_size, + pixel_values.size(3) // self.config.patch_size, + ), + dtype=torch.bool, + device=pixel_values.device, + ) + + hidden_states = self.embeddings(pixel_values=pixel_values, patch_attention_mask=patch_attention_mask) + + patch_attention_mask = patch_attention_mask.view(batch_size, -1) + # The call to `_upad_input` in `_flash_attention_forward` is expensive + # So when the `patch_attention_mask` is full of 1s (i.e. attending to the whole sequence), + # avoiding passing the attention_mask, which is equivalent to attending to the full sequence + if not torch.any(~patch_attention_mask): + attention_mask = None + else: + attention_mask = ( + _prepare_4d_attention_mask(patch_attention_mask, hidden_states.dtype) + if not self.config._flash_attn_2_enabled + else patch_attention_mask + ) + + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + last_hidden_state = encoder_outputs[0] + last_hidden_state = self.post_layernorm(last_hidden_state) + + pooled_output = self.head( + hidden_state=last_hidden_state, + attention_mask=patch_attention_mask, + ) + + if not return_dict: + return (last_hidden_state, pooled_output) + encoder_outputs[1:] + + return BaseModelOutputWithPooling( + last_hidden_state=last_hidden_state, + pooler_output=pooled_output, + hidden_states=encoder_outputs.hidden_states, + attentions=encoder_outputs.attentions, + ) + + +class SiglipMultiheadAttentionPoolingHead(nn.Module): + """Multihead Attention Pooling.""" + + def __init__(self, config: SiglipVisionConfig): + super().__init__() + + self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size)) + self.attention = torch.nn.MultiheadAttention(config.hidden_size, config.num_attention_heads, batch_first=True) + self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.mlp = SiglipMLP(config) + + def forward(self, hidden_state, attention_mask): + batch_size = hidden_state.shape[0] + probe = self.probe.repeat(batch_size, 1, 1) + + hidden_state = self.attention( + query=probe, key=hidden_state, value=hidden_state, key_padding_mask=~attention_mask + )[0] + + residual = hidden_state + hidden_state = self.layernorm(hidden_state) + hidden_state = residual + self.mlp(hidden_state) + + return hidden_state[:, 0] + + +@add_start_docstrings( + """The vision model from SigLIP without any head or projection on top.""", + SIGLIP_START_DOCSTRING, +) +class 
SiglipVisionModel(SiglipPreTrainedModel): + config_class = SiglipVisionConfig + main_input_name = "pixel_values" + + def __init__(self, config: SiglipVisionConfig): + super().__init__(config) + + self.vision_model = SiglipVisionTransformer(config) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self) -> nn.Module: + return self.vision_model.embeddings.patch_embedding + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=SiglipVisionConfig) + def forward( + self, + pixel_values, + patch_attention_mask: Optional[torch.BoolTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPooling]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, SiglipVisionModel + >>> model = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled features + ```""" + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + return self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + +@add_start_docstrings(SIGLIP_START_DOCSTRING) +class SiglipModel(SiglipPreTrainedModel): + config_class = SiglipConfig + + def __init__(self, config: SiglipConfig): + super().__init__(config) + + if not isinstance(config.text_config, SiglipTextConfig): + raise ValueError( + "config.text_config is expected to be of type SiglipTextConfig but is of type" + f" {type(config.text_config)}." + ) + + if not isinstance(config.vision_config, SiglipVisionConfig): + raise ValueError( + "config.vision_config is expected to be of type SiglipVisionConfig but is of type" + f" {type(config.vision_config)}." + ) + + text_config = config.text_config + vision_config = config.vision_config + + self.text_model = SiglipTextTransformer(text_config) + self.vision_model = SiglipVisionTransformer(vision_config) + + self.logit_scale = nn.Parameter(torch.randn(1)) + self.logit_bias = nn.Parameter(torch.randn(1)) + + # Initialize weights and apply final processing + self.post_init() + + @add_start_docstrings_to_model_forward(SIGLIP_TEXT_INPUTS_DOCSTRING) + def get_text_features( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by + applying the projection layer to the pooled output of [`SiglipTextModel`]. 
+ Examples: + ```python + >>> from transformers import AutoTokenizer, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> tokenizer = AutoTokenizer.from_pretrained("google/siglip-base-patch16-224") + >>> # important: make sure to set padding="max_length" as that's how the model was trained + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... text_features = model.get_text_features(**inputs) + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = text_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_VISION_INPUTS_DOCSTRING) + def get_image_features( + self, + pixel_values: Optional[torch.FloatTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> torch.FloatTensor: + r""" + Returns: + image_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The image embeddings obtained by + applying the projection layer to the pooled output of [`SiglipVisionModel`]. + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> inputs = processor(images=image, return_tensors="pt") + >>> with torch.no_grad(): + ... image_features = model.get_image_features(**inputs) + ```""" + # Use SiglipModel's config for some fields (if specified) instead of those of vision & text components. 
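+        # Note: unlike CLIP, there is no separate visual projection here; the pooled output of the
+        # attention-pooling head is returned directly as the image embedding.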
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + pooled_output = vision_outputs[1] + + return pooled_output + + @add_start_docstrings_to_model_forward(SIGLIP_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=SiglipOutput, config_class=SiglipConfig) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + return_loss: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SiglipOutput]: + r""" + Returns: + Examples: + ```python + >>> from PIL import Image + >>> import requests + >>> from transformers import AutoProcessor, AutoModel + >>> import torch + >>> model = AutoModel.from_pretrained("google/siglip-base-patch16-224") + >>> processor = AutoProcessor.from_pretrained("google/siglip-base-patch16-224") + >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg" + >>> image = Image.open(requests.get(url, stream=True).raw) + >>> texts = ["a photo of 2 cats", "a photo of 2 dogs"] + >>> # important: we pass `padding=max_length` since the model was trained with this + >>> inputs = processor(text=texts, images=image, padding="max_length", return_tensors="pt") + >>> with torch.no_grad(): + ... outputs = model(**inputs) + >>> logits_per_image = outputs.logits_per_image + >>> probs = torch.sigmoid(logits_per_image) # these are the probabilities + >>> print(f"{probs[0][0]:.1%} that image 0 is '{texts[0]}'") + 31.9% that image 0 is 'a photo of 2 cats' + ```""" + # Use SigLIP model's config for some fields (if specified) instead of those of vision & text components. 
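+        # The logits computed below are scaled cosine similarities plus a learned bias
+        # (`logit_scale`, `logit_bias`); unlike CLIP, probabilities come from an element-wise
+        # sigmoid over these logits rather than a softmax across the batch (see the example above).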
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + vision_outputs = self.vision_model( + pixel_values=pixel_values, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + text_outputs = self.text_model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + image_embeds = vision_outputs[1] + text_embeds = text_outputs[1] + + # normalized features + image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) + text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) + + # cosine similarity as logits + logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * self.logit_scale.exp() + self.logit_bias + logits_per_image = logits_per_text.t() + + loss = None + if return_loss: + raise NotImplementedError("SigLIP loss to be implemented") + + if not return_dict: + output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) + return ((loss,) + output) if loss is not None else output + + return SiglipOutput( + loss=loss, + logits_per_image=logits_per_image, + logits_per_text=logits_per_text, + text_embeds=text_embeds, + image_embeds=image_embeds, + text_model_output=text_outputs, + vision_model_output=vision_outputs, + ) + + +def get_siglip_vision_model(_flash_attn_2_enabled=True, **kwargs): + siglip_vision_config = { + "hidden_size": 1152, + "image_size": 448, + "intermediate_size": 4304, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_hidden_layers": 27, + "patch_size": 14, + } + + model_config = SiglipVisionConfig(**siglip_vision_config, _flash_attn_2_enabled=_flash_attn_2_enabled, **kwargs) + + vision_model = SiglipVisionModel(model_config).vision_model + + return vision_model + + +######################################################################## + +import sys + +import torch +import torch.nn as nn +from transformers.modeling_outputs import BaseModelOutput + +sys.path.append("/root/.cache/huggingface/modules/transformers_modules/phi-4-multimodal") + +from transformers.utils import logging +from vision_siglip_navit import SiglipAttention, SiglipConfig, SiglipEncoder, SiglipEncoderLayer + +logger = logging.get_logger(__name__) + +############################################################################### +# 1) Imports & checks for Gaudi fused kernels +############################################################################### +import habana_frameworks.torch.core as htcore + +try: + from habana_frameworks.torch.hpex.kernels import FusedSDPA + + _GAUDI_FUSED_SDPA_AVAILABLE = True +except ImportError: + FusedSDPA = None + _GAUDI_FUSED_SDPA_AVAILABLE = False + + +############################################################################### +# 2) A small module that wraps the FusedSDPA call +############################################################################### +class ModuleFusedSDPA(torch.nn.Module): + """Simple wrapper around the Gaudi fused scaled dot-product attention kernel.""" + + def __init__(self, fusedSDPA): + super().__init__() + self._fused_sdpa = fusedSDPA + + def forward(self, 
query, key, value, attn_mask, dropout_p, is_causal, scale): + # is_causal can remain False for Siglip since it’s not using causal masks + # scale can be set to None if you wish to let the kernel compute the default + return self._fused_sdpa.apply(query, key, value, attn_mask, dropout_p, is_causal, scale) + + +def expand_2d_attention_mask(mask_2d: torch.Tensor) -> torch.Tensor: + """Convert a 2D `batch_size x seq_len` mask into a 4D mask + `batch_size x 1 x seq_len x seq_len`. + + Zeros in `mask_2d` become large negative values + in the expanded mask (to block out attention). + """ + + # mask_2d is shape [batch_size, seq_len], containing 1 for "attend" and 0 for "mask out" + + batch_size, seq_len = mask_2d.shape + + # (1) Convert 1/0 to float and invert if needed + # Often we transform 1→0.0 (keep) and 0→-∞ (mask) + # or multiply by -1e4 to get "large negative" for masked positions + extended_mask = mask_2d[:, None, None, :].to(torch.float32) # shape [bsz, 1, 1, seq_len] + + # (2) Broadcast along the second dimension to produce [bsz, 1, seq_len, seq_len] + extended_mask = extended_mask.expand(batch_size, 1, seq_len, seq_len) + + # (3) Convert 1.0→0.0 and 0.0→-∞ if the attention code expects additive mask + extended_mask = (1.0 - extended_mask) * -1.0e4 + + return extended_mask + + +############################################################################### +# 3) New attention class that inherits from SiglipAttention and overrides forward +############################################################################### +from typing import Optional, Tuple + + +class GaudiSiglipAttention(SiglipAttention): + """A Gaudi-optimized SiglipAttention. + + Uses fused scaled-dot-product attention + if FusedSDPA is available and `use_flash_attention=True`. + Otherwise, falls back to the original SiglipAttention logic. + """ + + def __init__(self, config): + super().__init__(config) + # If the Gaudi fused kernel is available, wrap it; else None + self.fused_scaled_dot_product_attention = ModuleFusedSDPA(FusedSDPA) if _GAUDI_FUSED_SDPA_AVAILABLE else None + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = False, + use_flash_attention: bool = True, # <-- new optional argument + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Overridden forward method. + + Args: + hidden_states: [batch_size, seq_len, embed_dim]. + attention_mask: [batch_size, 1, seq_len, seq_len] if provided. + output_attentions: If True, returns attention probabilities (not supported by fused kernel). + use_flash_attention: If True and FusedSDPA is available, uses Habana’s fused kernel. + """ + # If user wants raw attention outputs, we have to fallback + if output_attentions and use_flash_attention: + logger.warning( + "GaudiSiglipAttention: output_attentions=True is not currently " + "supported by FusedSDPA. Falling back to the default PyTorch path." 
+            )
+            use_flash_attention = False
+
+        batch_size, q_len, _ = hidden_states.size()
+
+        # standard linear projections
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        # shape = [bsz, num_heads, seq_len, head_dim]
+        query_states = query_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(batch_size, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        use_fused_kernel = use_flash_attention and self.fused_scaled_dot_product_attention is not None
+
+        # The fused kernel is only used when the key length matches the query length (or is 1);
+        # otherwise fall back to the standard attention path below.
+        q_len = query_states.shape[2]
+        k_len = key_states.shape[2]
+        if k_len != q_len and k_len != 1:
+            use_fused_kernel = False
+
+        if use_fused_kernel:
+            # Gaudi fused kernel path
+            # Typically, we open an sdp_kernel context:
+            import habana_frameworks.torch.hpu as ht
+
+            with ht.sdp_kernel(enable_recompute=False):  # or True if you want to enable recompute
+                # Here scale can be self.scale, which is (1/sqrt(head_dim)).
+                attn_output = self.fused_scaled_dot_product_attention(
+                    query_states,
+                    key_states,
+                    value_states,
+                    attn_mask=attention_mask,  # shape: [bsz, 1, q_len, q_len]
+                    dropout_p=self.dropout if self.training else 0.0,
+                    is_causal=False,  # Siglip is not causal, can be True if you have a causal mask
+                    scale=self.scale,
+                )
+            attn_weights = None  # the fused kernel doesn't produce attention weights
+        else:
+            # Original fallback path
+            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scale
+
+            if attn_weights.size() != (batch_size, self.num_heads, q_len, q_len):
+                raise ValueError(
+                    f"Attention weights size mismatch. Expected {(batch_size, self.num_heads, q_len, q_len)}, "
+                    f"got {attn_weights.size()}"
+                )
+            if attention_mask is not None:
+                if attention_mask.size() != (batch_size, 1, q_len, q_len):
+                    raise ValueError(
+                        f"Attention mask size mismatch. 
Expected {(batch_size,1,q_len,q_len)}, " + f"got {attention_mask.size()}" + ) + attn_weights += attention_mask + + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + # [bsz, num_heads, seq_len, head_dim] -> [bsz, seq_len, embed_dim] + attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, q_len, self.embed_dim) + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +############################################################################### +# 4) A corresponding “EncoderLayer” that swaps in GaudiSiglipAttention +############################################################################### +class GaudiSiglipEncoderLayer(SiglipEncoderLayer): + """Exactly like SiglipEncoderLayer, but uses GaudiSiglipAttention for self_attn.""" + + def __init__(self, config: SiglipConfig): + super().__init__(config) + self.self_attn = GaudiSiglipAttention(config) # <--- override + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor, + output_attentions: Optional[bool] = False, + use_flash_attention: bool = True, # new + ) -> Tuple[torch.FloatTensor]: + """Same signature, but we add a `use_flash_attention` argument.""" + residual = hidden_states + hidden_states = self.layer_norm1(hidden_states) + + # Pass the new argument down to GaudiSiglipAttention + hidden_states, attn_weights = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + use_flash_attention=use_flash_attention, + ) + hidden_states = residual + hidden_states + + residual = hidden_states + hidden_states = self.layer_norm2(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + if output_attentions: + outputs += (attn_weights,) + + return outputs + + +############################################################################### +# 5) A “GaudiSiglipEncoder” that simply uses GaudiSiglipEncoderLayer +############################################################################### +class GaudiSiglipEncoder(SiglipEncoder): + """Inherits from SiglipEncoder, but each layer is a GaudiSiglipEncoderLayer.""" + + def __init__(self, config: SiglipConfig): + super().__init__(config) + # Re-initialize the list of layers with the Gaudi version + self.layers = nn.ModuleList([GaudiSiglipEncoderLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward( + self, + inputs_embeds, + attention_mask: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + use_flash_attention: bool = True, # new + ): + # identical logic to SiglipEncoder, just pass down `use_flash_attention` to each layer + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + encoder_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + hidden_states = inputs_embeds + for layer in self.layers: + if output_hidden_states: + encoder_states += (hidden_states,) + if 
self.gradient_checkpointing and self.training: + layer_outputs = self._gradient_checkpointing_func( + layer.__call__, + hidden_states, + attention_mask, + output_attentions, + use_flash_attention, + ) + else: + layer_outputs = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + use_flash_attention=use_flash_attention, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + # If you want to mark steps for HPU graph capturing: + htcore.mark_step() + + if output_hidden_states: + encoder_states += (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) + + return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions) + + +############### TODO: REMOVE #################### +SiglipAttention = GaudiSiglipAttention +SiglipEncoder = GaudiSiglipEncoder +SiglipEncoderLayer = GaudiSiglipEncoderLayer +############### TODO: REMOVE #################### diff --git a/comps/llms/src/text-generation/patch/optimum-habana-enhance.patch b/comps/llms/src/text-generation/patch/optimum-habana-enhance.patch new file mode 100644 index 000000000..72b6519de --- /dev/null +++ b/comps/llms/src/text-generation/patch/optimum-habana-enhance.patch @@ -0,0 +1,2629 @@ +From 3220191ba89a00fa04915d670c133ddd6bd6866d Mon Sep 17 00:00:00 2001 +From: leopck +Date: Wed, 26 Feb 2025 06:17:44 +0200 +Subject: [PATCH] Phi-4-mini-instruct + +Signed-off-by: leopck +--- + .../01-patch-transformer.sh | 6 + + .../phi-4-mini-instruct/02-run-sample.sh | 25 + + .../phi-4-mini-instruct/README.md | 9 + + .../phi-4-mini-instruct/patch/__init__.py | 67 + + .../patch/configuration_phi3.py | 226 +++ + .../patch/modeling_phi3.py | 1527 +++++++++++++++++ + optimum/habana/transformers/modeling_utils.py | 19 + + .../habana/transformers/models/__init__.py | 12 + + .../transformers/models/phi3/__init__.py | 6 + + .../transformers/models/phi3/modeling_phi3.py | 621 +++++++ + 10 files changed, 2518 insertions(+) + create mode 100755 examples/text-generation/phi-4-mini-instruct/01-patch-transformer.sh + create mode 100755 examples/text-generation/phi-4-mini-instruct/02-run-sample.sh + create mode 100644 examples/text-generation/phi-4-mini-instruct/README.md + create mode 100644 examples/text-generation/phi-4-mini-instruct/patch/__init__.py + create mode 100644 examples/text-generation/phi-4-mini-instruct/patch/configuration_phi3.py + create mode 100644 examples/text-generation/phi-4-mini-instruct/patch/modeling_phi3.py + create mode 100644 optimum/habana/transformers/models/phi3/__init__.py + create mode 100644 optimum/habana/transformers/models/phi3/modeling_phi3.py + +diff --git a/examples/text-generation/phi-4-mini-instruct/01-patch-transformer.sh b/examples/text-generation/phi-4-mini-instruct/01-patch-transformer.sh +new file mode 100755 +index 00000000..eb21f8af +--- /dev/null ++++ b/examples/text-generation/phi-4-mini-instruct/01-patch-transformer.sh +@@ -0,0 +1,6 @@ ++#!/bin/bash ++set -x ++echo -e "Patching phi3 into the transformer installed dist packages" ++cp patch/__init__.py "/usr/local/lib/python3.10/dist-packages/transformers/models/phi3/__init__.py" ++cp patch/configuration_phi3.py "/usr/local/lib/python3.10/dist-packages/transformers/models/phi3/configuration_phi3.py" ++cp patch/modeling_phi3.py "/usr/local/lib/python3.10/dist-packages/transformers/models/phi3/modeling_phi3.py" +\ No newline at end of file 
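The patch script above hard-codes the dist-packages location of the installed transformers package. A minimal sketch (not part of the patch, assuming the installed transformers release ships a `phi3` sub-package) for confirming the actual path before copying the patched files:

```python
# Not part of the patch: print where the installed transformers phi3 package lives,
# so the cp targets in 01-patch-transformer.sh can be adjusted if the path differs.
import os

import transformers.models.phi3 as phi3

print(os.path.dirname(phi3.__file__))
# e.g. /usr/local/lib/python3.10/dist-packages/transformers/models/phi3
```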
+diff --git a/examples/text-generation/phi-4-mini-instruct/02-run-sample.sh b/examples/text-generation/phi-4-mini-instruct/02-run-sample.sh
+new file mode 100755
+index 00000000..4c3dfa75
+--- /dev/null
++++ b/examples/text-generation/phi-4-mini-instruct/02-run-sample.sh
+@@ -0,0 +1,25 @@
++#!/bin/bash
++
++
++################################################################################
++# Patching step
++################################################################################
++echo -e "############################# WARNING #############################"
++echo -e "This script patches the installed transformers package; remove this step if it is not needed"
++echo -e "Only for Phi-4-mini-instruct (3.8B)"
++./01-patch-transformer.sh
++
++python run_generation.py \
++    --model_name_or_path "phi4/phi-4" \
++    --max_input_tokens 128 \
++    --max_new_tokens 128 \
++    --bf16 \
++    --use_hpu_graphs \
++    --batch_size 1 \
++    --attn_softmax_bf16 \
++    --limit_hpu_graphs \
++    --flash_attention_causal_mask \
++    --flash_attention_recompute \
++    --warmup 3 \
++    --n_iterations 1 \
++    --use_flash_attention
+\ No newline at end of file
+diff --git a/examples/text-generation/phi-4-mini-instruct/README.md b/examples/text-generation/phi-4-mini-instruct/README.md
+new file mode 100644
+index 00000000..8fa2b6aa
+--- /dev/null
++++ b/examples/text-generation/phi-4-mini-instruct/README.md
+@@ -0,0 +1,9 @@
++# Getting started
++
++Before running 01-patch-transformer.sh, locate the dist-packages path of your transformers installation (in this case "/usr/local/lib/python3.10/dist-packages/transformers/models/phi3/") and adjust the script if it differs; the script patches the installed Phi-3 modules with the latest Phi-4 model code and the `__init__.py` needed for importing.
++
++```sh
++./01-patch-transformer.sh
++# Before running the sample, change the model path inside 02-run-sample.sh to point to your Phi-4-mini-instruct model
++./02-run-sample.sh
++```
+\ No newline at end of file
+diff --git a/examples/text-generation/phi-4-mini-instruct/patch/__init__.py b/examples/text-generation/phi-4-mini-instruct/patch/__init__.py
+new file mode 100644
+index 00000000..83af185b
+--- /dev/null
++++ b/examples/text-generation/phi-4-mini-instruct/patch/__init__.py
+@@ -0,0 +1,67 @@
++# Copyright 2024 Microsoft and The HuggingFace Inc. team. All rights reserved.
++#
++# Licensed under the Apache License, Version 2.0 (the "License");
++# you may not use this file except in compliance with the License.
++# You may obtain a copy of the License at
++#
++# http://www.apache.org/licenses/LICENSE-2.0
++#
++# Unless required by applicable law or agreed to in writing, software
++# distributed under the License is distributed on an "AS IS" BASIS,
++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
++# See the License for the specific language governing permissions and
++# limitations under the License.
++ ++ ++from typing import TYPE_CHECKING ++ ++from ...utils import ( ++ OptionalDependencyNotAvailable, ++ _LazyModule, ++ is_sentencepiece_available, ++ is_tokenizers_available, ++ is_torch_available, ++) ++ ++ ++_import_structure = { ++ "configuration_phi3": ["Phi3Config"], ++} ++ ++try: ++ if not is_torch_available(): ++ raise OptionalDependencyNotAvailable() ++except OptionalDependencyNotAvailable: ++ pass ++else: ++ _import_structure["modeling_phi3"] = [ ++ "Phi3PreTrainedModel", ++ "Phi3Model", ++ "Phi3ForCausalLM", ++ "Phi3ForSequenceClassification", ++ "Phi3ForTokenClassification", ++ ] ++ ++ ++if TYPE_CHECKING: ++ from .configuration_phi3 import Phi3Config ++ ++ try: ++ if not is_torch_available(): ++ raise OptionalDependencyNotAvailable() ++ except OptionalDependencyNotAvailable: ++ pass ++ else: ++ from .modeling_phi3 import ( ++ Phi3ForCausalLM, ++ Phi3ForSequenceClassification, ++ Phi3ForTokenClassification, ++ Phi3Model, ++ Phi3PreTrainedModel, ++ ) ++ ++ ++else: ++ import sys ++ ++ sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) +\ No newline at end of file +diff --git a/examples/text-generation/phi-4-mini-instruct/patch/configuration_phi3.py b/examples/text-generation/phi-4-mini-instruct/patch/configuration_phi3.py +new file mode 100644 +index 00000000..6b45af10 +--- /dev/null ++++ b/examples/text-generation/phi-4-mini-instruct/patch/configuration_phi3.py +@@ -0,0 +1,226 @@ ++# coding=utf-8 ++# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++ ++"""Phi-3 model configuration""" ++ ++from transformers.configuration_utils import PretrainedConfig ++from transformers.utils import logging ++ ++ ++logger = logging.get_logger(__name__) ++ ++ ++class Phi3Config(PretrainedConfig): ++ r""" ++ This is the configuration class to store the configuration of a [`Phi3Model`]. It is used to instantiate a Phi-3 ++ model according to the specified arguments, defining the model architecture. Instantiating a configuration with the ++ defaults will yield a similar configuration to that of the ++ [microsoft/Phi-3-mini-4k-instruct](https://huggingface.co/microsoft/Phi-3-mini-4k-instruct). ++ ++ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the ++ documentation from [`PretrainedConfig`] for more information. ++ ++ Args: ++ vocab_size (`int`, *optional*, defaults to 32064): ++ Vocabulary size of the Phi-3 model. Defines the number of different tokens that can be represented by the ++ `inputs_ids` passed when calling [`Phi3Model`]. ++ hidden_size (`int`, *optional*, defaults to 3072): ++ Dimension of the hidden representations. ++ intermediate_size (`int`, *optional*, defaults to 8192): ++ Dimension of the MLP representations. ++ num_hidden_layers (`int`, *optional*, defaults to 32): ++ Number of hidden layers in the Transformer decoder. 
++ num_attention_heads (`int`, *optional*, defaults to 32): ++ Number of attention heads for each attention layer in the Transformer decoder. ++ num_key_value_heads (`int`, *optional*): ++ This is the number of key_value heads that should be used to implement Grouped Query Attention. If ++ `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if ++ `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When ++ converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed ++ by meanpooling all the original heads within that group. For more details checkout [this ++ paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to ++ `num_attention_heads`. ++ resid_pdrop (`float`, *optional*, defaults to 0.0): ++ Dropout probability for mlp outputs. ++ embd_pdrop (`int`, *optional*, defaults to 0.0): ++ The dropout ratio for the embeddings. ++ attention_dropout (`float`, *optional*, defaults to 0.0): ++ The dropout ratio after computing the attention scores. ++ hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): ++ The non-linear activation function (function or string) in the decoder. ++ max_position_embeddings (`int`, *optional*, defaults to 4096): ++ The maximum sequence length that this model might ever be used with. ++ original_max_position_embeddings (`int`, *optional*, defaults to 4096): ++ The maximum sequence length that this model was trained with. This is used to determine the size of the ++ original RoPE embeddings when using long scaling. ++ initializer_range (`float`, *optional*, defaults to 0.02): ++ The standard deviation of the truncated_normal_initializer for initializing all weight matrices. ++ rms_norm_eps (`float`, *optional*, defaults to 1e-05): ++ The epsilon value used for the RMSNorm. ++ use_cache (`bool`, *optional*, defaults to `True`): ++ Whether or not the model should return the last key/values attentions (not used by all models). Only ++ relevant if `config.is_decoder=True`. Whether to tie weight embeddings or not. ++ tie_word_embeddings (`bool`, *optional*, defaults to `False`): ++ Whether to tie weight embeddings ++ rope_theta (`float`, *optional*, defaults to 10000.0): ++ The base period of the RoPE embeddings. ++ rope_scaling (`dict`, *optional*): ++ The scaling strategy for the RoPE embeddings. If `None`, no scaling is applied. If a dictionary, it must ++ contain the following keys: `type`, `short_factor` and `long_factor`. The `type` must be `longrope` and ++ the `short_factor` and `long_factor` must be lists of numbers with the same length as the hidden size ++ divided by the number of attention heads divided by 2. ++ partial_rotary_factor (`float`, *optional*, defaults to 0.5): ++ Percentage of the query and keys which will have rotary embedding. ++ bos_token_id (`int`, *optional*, defaults to 1): ++ The id of the "beginning-of-sequence" token. ++ eos_token_id (`int`, *optional*, defaults to 32000): ++ The id of the "end-of-sequence" token. ++ pad_token_id (`int`, *optional*, defaults to 32000): ++ The id of the padding token. ++ sliding_window (`int`, *optional*): ++ Sliding window attention window size. If `None`, no sliding window is applied. 
++ ++ Example: ++ ++ ```python ++ >>> from transformers import Phi3Model, Phi3Config ++ ++ >>> # Initializing a Phi-3 style configuration ++ >>> configuration = Phi3Config.from_pretrained("microsoft/Phi-3-mini-4k-instruct") ++ ++ >>> # Initializing a model from the configuration ++ >>> model = Phi3Model(configuration) ++ ++ >>> # Accessing the model configuration ++ >>> configuration = model.config ++ ```""" ++ ++ model_type = "phi3" ++ keys_to_ignore_at_inference = ["past_key_values"] ++ ++ def __init__( ++ self, ++ vocab_size=200064, ++ hidden_size=3072, ++ intermediate_size=8192, ++ num_hidden_layers=32, ++ num_attention_heads=32, ++ num_key_value_heads=None, ++ resid_pdrop=0.0, ++ embd_pdrop=0.0, ++ attention_dropout=0.0, ++ hidden_act="silu", ++ max_position_embeddings=4096, ++ original_max_position_embeddings=4096, ++ initializer_range=0.02, ++ rms_norm_eps=1e-5, ++ use_cache=True, ++ tie_word_embeddings=False, ++ rope_theta=10000.0, ++ rope_scaling=None, ++ partial_rotary_factor=1, ++ bos_token_id=199999, ++ eos_token_id=199999, ++ pad_token_id=199999, ++ sliding_window=None, ++ **kwargs, ++ ): ++ self.vocab_size = vocab_size ++ self.hidden_size = hidden_size ++ self.intermediate_size = intermediate_size ++ self.num_hidden_layers = num_hidden_layers ++ self.num_attention_heads = num_attention_heads ++ ++ if num_key_value_heads is None: ++ num_key_value_heads = num_attention_heads ++ ++ self.num_key_value_heads = num_key_value_heads ++ self.resid_pdrop = resid_pdrop ++ self.embd_pdrop = embd_pdrop ++ self.attention_dropout = attention_dropout ++ self.hidden_act = hidden_act ++ self.max_position_embeddings = max_position_embeddings ++ self.original_max_position_embeddings = original_max_position_embeddings ++ self.initializer_range = initializer_range ++ self.rms_norm_eps = rms_norm_eps ++ self.use_cache = use_cache ++ self.rope_theta = rope_theta ++ self.rope_scaling = rope_scaling ++ self.partial_rotary_factor = partial_rotary_factor ++ self._rope_scaling_adjustment() ++ self._rope_scaling_validation() ++ self.sliding_window = sliding_window ++ ++ super().__init__( ++ bos_token_id=bos_token_id, ++ eos_token_id=eos_token_id, ++ pad_token_id=pad_token_id, ++ tie_word_embeddings=tie_word_embeddings, ++ **kwargs, ++ ) ++ ++ def _rope_scaling_adjustment(self): ++ """ ++ Adjust the `type` of the `rope_scaling` configuration for backward compatibility. ++ """ ++ if self.rope_scaling is None: ++ return ++ ++ rope_scaling_type = self.rope_scaling.get("type", None) ++ ++ # For backward compatibility if previous version used "su" or "yarn" ++ if rope_scaling_type is not None and rope_scaling_type in ["su", "yarn"]: ++ self.rope_scaling["type"] = "longrope" ++ ++ def _rope_scaling_validation(self): ++ """ ++ Validate the `rope_scaling` configuration. 
++ """ ++ if self.rope_scaling is None: ++ return ++ ++ if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 3: ++ raise ValueError( ++ "`rope_scaling` must be a dictionary with three fields, `type`, `short_factor` and `long_factor`, " ++ f"got {self.rope_scaling}" ++ ) ++ rope_scaling_type = self.rope_scaling.get("type", None) ++ rope_scaling_short_factor = self.rope_scaling.get("short_factor", None) ++ rope_scaling_long_factor = self.rope_scaling.get("long_factor", None) ++ if rope_scaling_type is None or rope_scaling_type not in ["longrope"]: ++ raise ValueError(f"`rope_scaling`'s type field must be one of ['longrope'], got {rope_scaling_type}") ++ if not ( ++ isinstance(rope_scaling_short_factor, list) ++ and all(isinstance(x, (int, float)) for x in rope_scaling_short_factor) ++ ): ++ raise ValueError( ++ f"`rope_scaling`'s short_factor field must be a list of numbers, got {rope_scaling_short_factor}" ++ ) ++ rotary_ndims = int(self.hidden_size // self.num_attention_heads * self.partial_rotary_factor) ++ if not len(rope_scaling_short_factor) == rotary_ndims // 2: ++ raise ValueError( ++ f"`rope_scaling`'s short_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_short_factor)}" ++ ) ++ if not ( ++ isinstance(rope_scaling_long_factor, list) ++ and all(isinstance(x, (int, float)) for x in rope_scaling_long_factor) ++ ): ++ raise ValueError( ++ f"`rope_scaling`'s long_factor field must be a list of numbers, got {rope_scaling_long_factor}" ++ ) ++ if not len(rope_scaling_long_factor) == rotary_ndims // 2: ++ raise ValueError( ++ f"`rope_scaling`'s long_factor field must have length {rotary_ndims // 2}, got {len(rope_scaling_long_factor)}" ++ ) +diff --git a/examples/text-generation/phi-4-mini-instruct/patch/modeling_phi3.py b/examples/text-generation/phi-4-mini-instruct/patch/modeling_phi3.py +new file mode 100644 +index 00000000..dd40bb0e +--- /dev/null ++++ b/examples/text-generation/phi-4-mini-instruct/patch/modeling_phi3.py +@@ -0,0 +1,1527 @@ ++# coding=utf-8 ++# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. 
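++
++# Vendored copy of the Hugging Face Phi-3 modeling code, paired with
++# configuration_phi3.py above, whose defaults (200064-token vocabulary,
++# longrope scaling, partial_rotary_factor=1) match the Phi-4-mini setup used in
++# this example. Per the README, 01-patch-transformer.sh copies these files over
++# the installed transformers/models/phi3 package.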
++ ++"""PyTorch Phi-3 model.""" ++ ++import math ++import warnings ++from typing import List, Optional, Tuple, Union ++ ++import torch ++import torch.utils.checkpoint ++from torch import nn ++from torch.nn import CrossEntropyLoss ++ ++from transformers.activations import ACT2FN ++from transformers.cache_utils import Cache, DynamicCache, SlidingWindowCache, StaticCache ++from transformers.generation import GenerationMixin ++from transformers.modeling_attn_mask_utils import AttentionMaskConverter ++from transformers.modeling_flash_attention_utils import _flash_attention_forward ++from transformers.modeling_outputs import ( ++ BaseModelOutputWithPast, ++ CausalLMOutputWithPast, ++ SequenceClassifierOutputWithPast, ++ TokenClassifierOutput, ++) ++from transformers.modeling_utils import PreTrainedModel ++from transformers.utils import ( ++ add_code_sample_docstrings, ++ add_start_docstrings, ++ add_start_docstrings_to_model_forward, ++ is_flash_attn_greater_or_equal_2_10, ++ logging, ++ replace_return_docstrings, ++) ++from .configuration_phi3 import Phi3Config ++ ++ ++logger = logging.get_logger(__name__) ++ ++_CHECKPOINT_FOR_DOC = "microsoft/Phi-3-mini-4k-instruct" ++_CONFIG_FOR_DOC = "Phi3Config" ++ ++ ++# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Phi3 ++class Phi3RMSNorm(nn.Module): ++ def __init__(self, hidden_size, eps=1e-6): ++ """ ++ Phi3RMSNorm is equivalent to T5LayerNorm ++ """ ++ super().__init__() ++ self.weight = nn.Parameter(torch.ones(hidden_size)) ++ self.variance_epsilon = eps ++ ++ def forward(self, hidden_states): ++ input_dtype = hidden_states.dtype ++ hidden_states = hidden_states.to(torch.float32) ++ variance = hidden_states.pow(2).mean(-1, keepdim=True) ++ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) ++ return self.weight * hidden_states.to(input_dtype) ++ ++ def extra_repr(self): ++ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" ++ ++ ++# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding with gemma->phi3, Gemma->Phi3 ++class Phi3RotaryEmbedding(nn.Module): ++ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): ++ super().__init__() ++ ++ self.dim = dim ++ self.max_position_embeddings = max_position_embeddings ++ self.base = base ++ ++ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim)) ++ self.register_buffer("inv_freq", tensor=inv_freq, persistent=False) ++ ++ @torch.no_grad() ++ def forward(self, x, position_ids, seq_len=None): ++ # x: [bs, num_attention_heads, seq_len, head_size] ++ self.inv_freq.to(x.device) ++ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) ++ position_ids_expanded = position_ids[:, None, :].float() ++ # Force float32 since bfloat16 loses precision on long contexts ++ # See /~https://github.com/huggingface/transformers/pull/29285 ++ device_type = x.device.type ++ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" ++ with torch.autocast(device_type=device_type, enabled=False): ++ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) ++ emb = torch.cat((freqs, freqs), dim=-1) ++ cos = emb.cos() ++ sin = emb.sin() ++ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) ++ ++ ++class Phi3SuScaledRotaryEmbedding(Phi3RotaryEmbedding): ++ def __init__(self, dim, config, device=None): ++ warnings.warn( ++ "The class Phi3SuScaledRotaryEmbedding is 
deprecated and will be removed in version 5 of Transformers. Please" ++ " use Phi3LongRoPEScaledRotaryEmbedding instead.", ++ FutureWarning, ++ ) ++ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) ++ ++ self.short_factor = config.rope_scaling["short_factor"] ++ self.long_factor = config.rope_scaling["long_factor"] ++ self.original_max_position_embeddings = config.original_max_position_embeddings ++ ++ @torch.no_grad() ++ def forward(self, x, position_ids, seq_len=None): ++ seq_len = torch.max(position_ids) + 1 ++ if seq_len > self.original_max_position_embeddings: ++ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) ++ else: ++ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) ++ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim ++ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) ++ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) ++ position_ids_expanded = position_ids[:, None, :].float() ++ # Force float32 since bfloat16 loses precision on long contexts ++ # See /~https://github.com/huggingface/transformers/pull/29285 ++ device_type = x.device.type ++ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" ++ with torch.autocast(device_type=device_type, enabled=False): ++ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) ++ emb = torch.cat((freqs, freqs), dim=-1) ++ scale = self.max_position_embeddings / self.original_max_position_embeddings ++ if scale <= 1.0: ++ scaling_factor = 1.0 ++ else: ++ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) ++ cos = emb.cos() * scaling_factor ++ sin = emb.sin() * scaling_factor ++ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) ++ ++ ++class Phi3YarnScaledRotaryEmbedding(Phi3RotaryEmbedding): ++ def __init__(self, dim, config, device=None): ++ warnings.warn( ++ "The class Phi3YarnScaledRotaryEmbedding is deprecated and will be removed in version 5 of Transformers", ++ FutureWarning, ++ ) ++ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) ++ ++ self.short_factor = config.rope_scaling["short_factor"] ++ self.long_factor = config.rope_scaling["long_factor"] ++ self.original_max_position_embeddings = config.original_max_position_embeddings ++ ++ @torch.no_grad() ++ def forward(self, x, position_ids, seq_len=None): ++ seq_len = torch.max(position_ids) + 1 ++ if seq_len > self.original_max_position_embeddings: ++ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) ++ else: ++ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) ++ ++ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim ++ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) ++ ++ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) ++ position_ids_expanded = position_ids[:, None, :].float() ++ ++ # Force float32 since bfloat16 loses precision on long contexts ++ # See /~https://github.com/huggingface/transformers/pull/29285 ++ device_type = x.device.type ++ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" ++ with torch.autocast(device_type=device_type, enabled=False): ++ freqs = (inv_freq_expanded.float() @ 
position_ids_expanded.float()).transpose(1, 2) ++ emb = torch.cat((freqs, freqs), dim=-1) ++ ++ scale = self.max_position_embeddings / self.original_max_position_embeddings ++ if scale <= 1.0: ++ scaling_factor = 1.0 ++ else: ++ scaling_factor = 0.1 * math.log(scale) + 1.0 ++ ++ cos = emb.cos() * scaling_factor ++ sin = emb.sin() * scaling_factor ++ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) ++ ++ ++class Phi3LongRoPEScaledRotaryEmbedding(Phi3RotaryEmbedding): ++ def __init__(self, dim, config, device=None): ++ super().__init__(dim, config.max_position_embeddings, config.rope_theta, device) ++ ++ self.short_factor = config.rope_scaling["short_factor"] ++ self.long_factor = config.rope_scaling["long_factor"] ++ self.original_max_position_embeddings = config.original_max_position_embeddings ++ ++ @torch.no_grad() ++ def forward(self, x, position_ids, seq_len=None): ++ seq_len = seq_len or torch.max(position_ids) + 1 ++ if seq_len > self.original_max_position_embeddings: ++ ext_factors = torch.tensor(self.long_factor, dtype=torch.float32, device=x.device) ++ else: ++ ext_factors = torch.tensor(self.short_factor, dtype=torch.float32, device=x.device) ++ ++ inv_freq_shape = torch.arange(0, self.dim, 2, dtype=torch.int64, device=x.device).float() / self.dim ++ self.inv_freq = 1.0 / (ext_factors * self.base**inv_freq_shape) ++ ++ inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) ++ position_ids_expanded = position_ids[:, None, :].float() ++ ++ # Force float32 since bfloat16 loses precision on long contexts ++ # See /~https://github.com/huggingface/transformers/pull/29285 ++ device_type = x.device.type ++ device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu" ++ with torch.autocast(device_type=device_type, enabled=False): ++ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) ++ emb = torch.cat((freqs, freqs), dim=-1) ++ ++ scale = self.max_position_embeddings / self.original_max_position_embeddings ++ if scale <= 1.0: ++ scaling_factor = 1.0 ++ else: ++ scaling_factor = math.sqrt(1 + math.log(scale) / math.log(self.original_max_position_embeddings)) ++ ++ cos = emb.cos() * scaling_factor ++ sin = emb.sin() * scaling_factor ++ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) ++ ++ ++# Copied from transformers.models.llama.modeling_llama.rotate_half ++def rotate_half(x): ++ """Rotates half the hidden dims of the input.""" ++ x1 = x[..., : x.shape[-1] // 2] ++ x2 = x[..., x.shape[-1] // 2 :] ++ return torch.cat((-x2, x1), dim=-1) ++ ++ ++def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): ++ """Applies Rotary Position Embedding to the query and key tensors. ++ ++ Args: ++ q (`torch.Tensor`): The query tensor. ++ k (`torch.Tensor`): The key tensor. ++ cos (`torch.Tensor`): The cosine part of the rotary embedding. ++ sin (`torch.Tensor`): The sine part of the rotary embedding. ++ position_ids (`torch.Tensor`, *optional*): ++ Deprecated and unused. ++ unsqueeze_dim (`int`, *optional*, defaults to 1): ++ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and ++ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note ++ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. 
Then, if q and ++ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes ++ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have ++ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. ++ Returns: ++ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. ++ """ ++ cos = cos.unsqueeze(unsqueeze_dim) ++ sin = sin.unsqueeze(unsqueeze_dim) ++ ++ rotary_dim = cos.shape[-1] ++ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:] ++ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:] ++ ++ q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1) ++ k_embed = torch.cat([(k_rot * cos) + (rotate_half(k_rot) * sin), k_pass], dim=-1) ++ return q_embed, k_embed ++ ++ ++class Phi3MLP(nn.Module): ++ def __init__(self, config): ++ super().__init__() ++ ++ self.config = config ++ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False) ++ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False) ++ ++ self.activation_fn = ACT2FN[config.hidden_act] ++ ++ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: ++ up_states = self.gate_up_proj(hidden_states) ++ ++ gate, up_states = up_states.chunk(2, dim=-1) ++ up_states = up_states * self.activation_fn(gate) ++ ++ return self.down_proj(up_states) ++ ++ ++# Copied from transformers.models.llama.modeling_llama.repeat_kv with llama->phi ++def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: ++ """ ++ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, ++ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) ++ """ ++ batch, num_key_value_heads, slen, head_dim = hidden_states.shape ++ if n_rep == 1: ++ return hidden_states ++ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) ++ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) ++ ++ ++class Phi3Attention(nn.Module): ++ """Multi-headed attention from 'Attention Is All You Need' paper""" ++ ++ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): ++ super().__init__() ++ self.config = config ++ self.layer_idx = layer_idx ++ if layer_idx is None: ++ logger.warning_once( ++ f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " ++ "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " ++ "when creating this class." 
++ ) ++ ++ self.attention_dropout = config.attention_dropout ++ self.hidden_size = config.hidden_size ++ self.num_heads = config.num_attention_heads ++ self.head_dim = self.hidden_size // self.num_heads ++ self.num_key_value_heads = config.num_key_value_heads ++ self.num_key_value_groups = self.num_heads // self.num_key_value_heads ++ self.max_position_embeddings = config.max_position_embeddings ++ self.original_max_position_embeddings = config.original_max_position_embeddings ++ self.rope_theta = config.rope_theta ++ self.rope_scaling = config.rope_scaling ++ self.rotary_ndims = int(self.head_dim * config.partial_rotary_factor) ++ self.is_causal = True ++ ++ if (self.head_dim * self.num_heads) != self.hidden_size: ++ raise ValueError( ++ f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" ++ f" and `num_heads`: {self.num_heads})." ++ ) ++ ++ op_size = self.num_heads * self.head_dim + 2 * (self.num_key_value_heads * self.head_dim) ++ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) ++ self.qkv_proj = nn.Linear(self.hidden_size, op_size, bias=False) ++ self._init_rope() ++ ++ def _init_rope(self): ++ if self.rope_scaling is None: ++ self.rotary_emb = Phi3RotaryEmbedding( ++ self.rotary_ndims, ++ max_position_embeddings=self.max_position_embeddings, ++ base=self.rope_theta, ++ ) ++ else: ++ scaling_type = self.config.rope_scaling["type"] ++ if scaling_type == "longrope": ++ self.rotary_emb = Phi3LongRoPEScaledRotaryEmbedding(self.rotary_ndims, self.config) ++ else: ++ raise ValueError(f"Unknown RoPE scaling type {scaling_type}") ++ ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ attention_mask: Optional[torch.Tensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_value: Optional[Cache] = None, ++ output_attentions: bool = False, ++ use_cache: bool = False, ++ cache_position: Optional[torch.LongTensor] = None, ++ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ++ logger.warning_once("You are not running the flash-attention implementation, expect numerical differences.") ++ ++ bsz, q_len, _ = hidden_states.size() ++ ++ qkv = self.qkv_proj(hidden_states) ++ query_pos = self.num_heads * self.head_dim ++ query_states = qkv[..., :query_pos] ++ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] ++ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] ++ ++ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) ++ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) ++ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) ++ ++ kv_seq_len = key_states.shape[-2] ++ if past_key_value is not None: ++ if self.layer_idx is None: ++ raise ValueError( ++ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " ++ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " ++ "with a layer index." 
++ ) ++ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) ++ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) ++ ++ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) ++ ++ if past_key_value is not None: ++ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models ++ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) ++ ++ # repeat k/v heads if n_kv_heads < n_heads ++ key_states = repeat_kv(key_states, self.num_key_value_groups) ++ value_states = repeat_kv(value_states, self.num_key_value_groups) ++ ++ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) ++ ++ if attention_mask is not None: ++ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] ++ attn_weights += causal_mask ++ ++ # upcast attention to fp32 ++ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(value_states.dtype) ++ attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) ++ ++ attn_output = torch.matmul(attn_weights, value_states) ++ ++ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): ++ raise ValueError( ++ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" ++ f" {attn_output.size()}" ++ ) ++ ++ attn_output = attn_output.transpose(1, 2).contiguous() ++ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) ++ ++ attn_output = self.o_proj(attn_output) ++ ++ if not output_attentions: ++ attn_weights = None ++ ++ return attn_output, attn_weights, past_key_value ++ ++ ++class Phi3FlashAttention2(Phi3Attention): ++ """ ++ Phi-3 flash attention module. This module inherits from `Phi3Attention` as the weights of the module stays ++ untouched. The only required change would be on the forward pass where it needs to correctly call the public API of ++ flash attention and deal with padding tokens in case the input contains any of them. ++ """ ++ ++ # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ ++ def __init__(self, *args, **kwargs): ++ super().__init__(*args, **kwargs) ++ ++ # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. ++ # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignment, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: /~https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. ++ # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). 
++ self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() ++ ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ attention_mask: Optional[torch.LongTensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_value: Optional[Cache] = None, ++ output_attentions: bool = False, ++ use_cache: bool = False, ++ cache_position: Optional[torch.LongTensor] = None, ++ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ++ # Phi3FlashAttention2 attention does not support output_attentions ++ ++ output_attentions = False ++ ++ bsz, q_len, _ = hidden_states.size() ++ ++ qkv = self.qkv_proj(hidden_states) ++ query_pos = self.num_heads * self.head_dim ++ query_states = qkv[..., :query_pos] ++ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] ++ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] ++ ++ # Flash attention requires the input to have the shape ++ # batch_size x seq_length x head_dim x hidden_dim ++ # therefore we just need to keep the original shape ++ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) ++ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) ++ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) ++ ++ kv_seq_len = key_states.shape[-2] ++ if past_key_value is not None: ++ if self.layer_idx is None: ++ raise ValueError( ++ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " ++ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " ++ "with a layer index." ++ ) ++ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) ++ ++ # Because the input can be padded, the absolute sequence length depends on the max position id. ++ rotary_seq_len = ( ++ max(kv_seq_len, position_ids[:, -1].max().item() + 1) if position_ids is not None else kv_seq_len ++ ) ++ ++ cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len, position_ids=position_ids) ++ ++ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) ++ ++ if past_key_value is not None: ++ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models ++ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) ++ ++ # repeat k/v heads if n_kv_heads < n_heads ++ key_states = repeat_kv(key_states, self.num_key_value_groups) ++ value_states = repeat_kv(value_states, self.num_key_value_groups) ++ ++ attn_dropout = self.attention_dropout if self.training else 0.0 ++ ++ # In PEFT, usually we cast the layer norms in float32 for training stability reasons ++ # therefore the input hidden states gets silently casted in float32. Hence, we need ++ # cast them back in the correct dtype just to be sure everything works as expected. ++ # This might slowdown training & inference so it is recommended to not cast the LayerNorms ++ # in fp32. 
++ ++ if query_states.dtype == torch.float32: ++ if torch.is_autocast_enabled(): ++ target_dtype = torch.get_autocast_gpu_dtype() ++ # Handle the case where the model is quantized ++ elif hasattr(self.config, "_pre_quantization_dtype"): ++ target_dtype = self.config._pre_quantization_dtype ++ else: ++ target_dtype = self.qkv_proj.weight.dtype ++ ++ logger.warning_once( ++ f"The input hidden states seems to be silently casted in float32, this might be related to" ++ f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" ++ f" {target_dtype}." ++ ) ++ ++ query_states = query_states.to(target_dtype) ++ key_states = key_states.to(target_dtype) ++ value_states = value_states.to(target_dtype) ++ ++ # Reashape to the expected shape for Flash Attention ++ query_states = query_states.transpose(1, 2) ++ key_states = key_states.transpose(1, 2) ++ value_states = value_states.transpose(1, 2) ++ ++ attn_output = _flash_attention_forward( ++ query_states, ++ key_states, ++ value_states, ++ attention_mask, ++ q_len, ++ position_ids=position_ids, ++ dropout=attn_dropout, ++ sliding_window=getattr(self.config, "sliding_window", None), ++ use_top_left_mask=self._flash_attn_uses_top_left_mask, ++ is_causal=self.is_causal, ++ ) ++ ++ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() ++ attn_output = self.o_proj(attn_output) ++ ++ if not output_attentions: ++ attn_weights = None ++ ++ return attn_output, attn_weights, past_key_value ++ ++ ++# copied from transformers.models.llama.modeling_llama.LlamaSdpaAttention with Llama->Phi3 ++# TODO @Arthur no longer copied from LLama after static cache ++class Phi3SdpaAttention(Phi3Attention): ++ """ ++ Phi3 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from ++ `Phi3Attention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to ++ SDPA API. ++ """ ++ ++ # Adapted from Phi3Attention.forward ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ attention_mask: Optional[torch.Tensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_value: Optional[Cache] = None, ++ output_attentions: bool = False, ++ use_cache: bool = False, ++ cache_position: Optional[torch.LongTensor] = None, ++ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ++ if output_attentions: ++ # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. ++ logger.warning_once( ++ "Phi3Model is using Phi3SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " ++ 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' 
++ ) ++ return super().forward( ++ hidden_states=hidden_states, ++ attention_mask=attention_mask, ++ position_ids=position_ids, ++ past_key_value=past_key_value, ++ output_attentions=output_attentions, ++ use_cache=use_cache, ++ ) ++ ++ bsz, q_len, _ = hidden_states.size() ++ ++ qkv = self.qkv_proj(hidden_states) ++ query_pos = self.num_heads * self.head_dim ++ query_states = qkv[..., :query_pos] ++ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] ++ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] ++ ++ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) ++ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) ++ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) ++ ++ kv_seq_len = key_states.shape[-2] ++ if past_key_value is not None: ++ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) ++ cos, sin = self.rotary_emb(value_states, position_ids, seq_len=kv_seq_len) ++ ++ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) ++ ++ if past_key_value is not None: ++ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models ++ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) ++ ++ key_states = repeat_kv(key_states, self.num_key_value_groups) ++ value_states = repeat_kv(value_states, self.num_key_value_groups) ++ ++ causal_mask = attention_mask ++ if attention_mask is not None: ++ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] ++ ++ # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, ++ # Reference: /~https://github.com/pytorch/pytorch/issues/112577. ++ if query_states.device.type == "cuda" and attention_mask is not None: ++ query_states = query_states.contiguous() ++ key_states = key_states.contiguous() ++ value_states = value_states.contiguous() ++ ++ # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment ++ # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling. ++ # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1. 
++ is_causal = True if causal_mask is None and q_len > 1 else False ++ ++ attn_output = torch.nn.functional.scaled_dot_product_attention( ++ query_states, ++ key_states, ++ value_states, ++ attn_mask=causal_mask, ++ dropout_p=self.attention_dropout if self.training else 0.0, ++ is_causal=is_causal, ++ ) ++ ++ attn_output = attn_output.transpose(1, 2).contiguous() ++ attn_output = attn_output.view(bsz, q_len, self.hidden_size) ++ ++ attn_output = self.o_proj(attn_output) ++ ++ return attn_output, None, past_key_value ++ ++ ++PHI3_ATTENTION_CLASSES = { ++ "eager": Phi3Attention, ++ "flash_attention_2": Phi3FlashAttention2, ++ "sdpa": Phi3SdpaAttention, ++} ++ ++ ++class Phi3DecoderLayer(nn.Module): ++ def __init__(self, config: Phi3Config, layer_idx: int): ++ super().__init__() ++ ++ self.config = config ++ self.self_attn = PHI3_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx=layer_idx) ++ ++ self.mlp = Phi3MLP(config) ++ self.input_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ ++ self.resid_attn_dropout = nn.Dropout(config.resid_pdrop) ++ self.resid_mlp_dropout = nn.Dropout(config.resid_pdrop) ++ self.post_attention_layernorm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ attention_mask: Optional[torch.Tensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_value: Optional[Tuple[torch.Tensor]] = None, ++ output_attentions: Optional[bool] = False, ++ use_cache: Optional[bool] = False, ++ cache_position: Optional[torch.LongTensor] = None, ++ **kwargs, ++ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ++ """ ++ Args: ++ hidden_states (`torch.FloatTensor`): ++ input to the layer of shape `(batch, seq_len, embed_dim)` ++ attention_mask (`torch.FloatTensor`, *optional*): attention mask of size ++ `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. ++ position_ids (`torch.LongTensor` of shape `({0})`, *optional*): ++ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ++ `[0, config.n_positions - 1]`. [What are position IDs?](../glossary#position-ids) ++ output_attentions (`bool`, *optional*): ++ Whether or not to return the attentions tensors of all attention layers. See `attentions` under ++ returned tensors for more detail. ++ use_cache (`bool`, *optional*): ++ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding ++ (see `past_key_values`). 
++ past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states ++ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): ++ Indices depicting the position of the input sequence tokens in the sequence ++ kwargs (`dict`, *optional*): ++ Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code ++ into the model ++ """ ++ ++ residual = hidden_states ++ ++ hidden_states = self.input_layernorm(hidden_states) ++ ++ # Self Attention ++ attn_outputs, self_attn_weights, present_key_value = self.self_attn( ++ hidden_states=hidden_states, ++ attention_mask=attention_mask, ++ position_ids=position_ids, ++ past_key_value=past_key_value, ++ output_attentions=output_attentions, ++ use_cache=use_cache, ++ cache_position=cache_position, ++ ) ++ ++ hidden_states = residual + self.resid_attn_dropout(attn_outputs) ++ ++ residual = hidden_states ++ hidden_states = self.post_attention_layernorm(hidden_states) ++ hidden_states = self.mlp(hidden_states) ++ hidden_states = residual + self.resid_mlp_dropout(hidden_states) ++ ++ outputs = (hidden_states,) ++ ++ if output_attentions: ++ outputs += (self_attn_weights,) ++ ++ if use_cache: ++ outputs += (present_key_value,) ++ ++ return outputs ++ ++ ++PHI3_START_DOCSTRING = r""" ++ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the ++ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads ++ etc.) ++ ++ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. ++ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage ++ and behavior. ++ ++ Parameters: ++ config ([`Phi3Config`]): ++ Model configuration class with all the parameters of the model. Initializing with a config file does not ++ load the weights associated with the model, only the configuration. Check out the ++ [`~PreTrainedModel.from_pretrained`] method to load the model weights. ++""" ++ ++ ++@add_start_docstrings( ++ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.", ++ PHI3_START_DOCSTRING, ++) ++class Phi3PreTrainedModel(PreTrainedModel): ++ config_class = Phi3Config ++ base_model_prefix = "model" ++ supports_gradient_checkpointing = True ++ _no_split_modules = ["Phi3DecoderLayer"] ++ _skip_keys_device_placement = "past_key_values" ++ _supports_flash_attn_2 = True ++ _supports_sdpa = True ++ _supports_cache_class = True ++ ++ _version = "0.0.5" ++ ++ def _init_weights(self, module): ++ std = self.config.initializer_range ++ if isinstance(module, nn.Linear): ++ module.weight.data.normal_(mean=0.0, std=std) ++ if module.bias is not None: ++ module.bias.data.zero_() ++ elif isinstance(module, nn.Embedding): ++ module.weight.data.normal_(mean=0.0, std=std) ++ if module.padding_idx is not None: ++ module.weight.data[module.padding_idx].zero_() ++ ++ ++PHI3_INPUTS_DOCSTRING = r""" ++ Args: ++ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): ++ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide ++ it. ++ ++ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and ++ [`PreTrainedTokenizer.__call__`] for details. 
++ ++ [What are input IDs?](../glossary#input-ids) ++ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): ++ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: ++ ++ - 1 for tokens that are **not masked**, ++ - 0 for tokens that are **masked**. ++ ++ [What are attention masks?](../glossary#attention-mask) ++ ++ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and ++ [`PreTrainedTokenizer.__call__`] for details. ++ ++ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see ++ `past_key_values`). ++ ++ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] ++ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more ++ information on the default strategy. ++ ++ - 1 indicates the head is **not masked**, ++ - 0 indicates the head is **masked**. ++ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): ++ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, ++ config.n_positions - 1]`. ++ ++ [What are position IDs?](../glossary#position-ids) ++ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*): ++ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention ++ blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values` ++ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`. ++ ++ Two formats are allowed: ++ - a [`~cache_utils.Cache`] instance, see our ++ [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache); ++ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of ++ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy ++ cache format. ++ ++ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the ++ legacy cache format will be returned. ++ ++ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't ++ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids` ++ of shape `(batch_size, sequence_length)`. ++ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): ++ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This ++ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the ++ model's internal embedding lookup matrix. ++ use_cache (`bool`, *optional*): ++ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see ++ `past_key_values`). ++ output_attentions (`bool`, *optional*): ++ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned ++ tensors for more detail. ++ output_hidden_states (`bool`, *optional*): ++ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for ++ more detail. ++ return_dict (`bool`, *optional*): ++ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. 
++ cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): ++ Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`, ++ this tensor is not affected by padding. It is used to update the cache in the correct position and to infer ++ the complete sequence length. ++""" ++ ++ ++@add_start_docstrings( ++ "The bare Phi-3 model outputting raw hidden-states without any specific head on top.", ++ PHI3_START_DOCSTRING, ++) ++class Phi3Model(Phi3PreTrainedModel): ++ """ ++ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Phi3DecoderLayer`] ++ ++ Args: ++ config: Phi3Config ++ """ ++ ++ def __init__(self, config: Phi3Config): ++ super().__init__(config) ++ self.padding_idx = config.pad_token_id ++ self.vocab_size = config.vocab_size ++ ++ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) ++ self.embed_dropout = nn.Dropout(config.embd_pdrop) ++ self.layers = nn.ModuleList( ++ [Phi3DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] ++ ) ++ self._attn_implementation = config._attn_implementation ++ self.norm = Phi3RMSNorm(config.hidden_size, eps=config.rms_norm_eps) ++ ++ self.gradient_checkpointing = False ++ # Initialize weights and apply final processing ++ self.post_init() ++ ++ def get_input_embeddings(self): ++ return self.embed_tokens ++ ++ def set_input_embeddings(self, value): ++ self.embed_tokens = value ++ ++ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) ++ def forward( ++ self, ++ input_ids: torch.LongTensor = None, ++ attention_mask: Optional[torch.Tensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_values: Optional[List[torch.FloatTensor]] = None, ++ inputs_embeds: Optional[torch.FloatTensor] = None, ++ use_cache: Optional[bool] = None, ++ output_attentions: Optional[bool] = None, ++ output_hidden_states: Optional[bool] = None, ++ return_dict: Optional[bool] = None, ++ cache_position: Optional[torch.LongTensor] = None, ++ ) -> Union[Tuple, BaseModelOutputWithPast]: ++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions ++ output_hidden_states = ( ++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ++ ) ++ use_cache = use_cache if use_cache is not None else self.config.use_cache ++ ++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ ++ if (input_ids is None) ^ (inputs_embeds is not None): ++ raise ValueError("You must specify exactly one of input_ids or inputs_embeds") ++ ++ if self.gradient_checkpointing and self.training: ++ if use_cache: ++ logger.warning_once( ++ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." ++ ) ++ use_cache = False ++ ++ # kept for BC (non `Cache` `past_key_values` inputs) ++ return_legacy_cache = False ++ if use_cache and not isinstance(past_key_values, Cache): ++ return_legacy_cache = True ++ if past_key_values is None: ++ past_key_values = DynamicCache() ++ else: ++ past_key_values = DynamicCache.from_legacy_cache(past_key_values) ++ logger.warning_once( ++ "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " ++ "will be removed in v4.47. 
Please convert your cache or use an appropriate `Cache` class " ++ "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" ++ ) ++ ++ if inputs_embeds is None: ++ inputs_embeds = self.embed_tokens(input_ids) ++ ++ if cache_position is None: ++ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 ++ cache_position = torch.arange( ++ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ++ ) ++ if position_ids is None: ++ position_ids = cache_position.unsqueeze(0) ++ ++ causal_mask = self._update_causal_mask( ++ attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions ++ ) ++ ++ hidden_states = inputs_embeds ++ ++ # decoder layers ++ all_hidden_states = () if output_hidden_states else None ++ all_self_attns = () if output_attentions else None ++ next_decoder_cache = None ++ ++ for decoder_layer in self.layers: ++ if output_hidden_states: ++ all_hidden_states += (hidden_states,) ++ ++ if self.gradient_checkpointing and self.training: ++ layer_outputs = self._gradient_checkpointing_func( ++ decoder_layer.__call__, ++ hidden_states, ++ causal_mask, ++ position_ids, ++ past_key_values, ++ output_attentions, ++ use_cache, ++ cache_position, ++ ) ++ else: ++ layer_outputs = decoder_layer( ++ hidden_states, ++ attention_mask=causal_mask, ++ position_ids=position_ids, ++ past_key_value=past_key_values, ++ output_attentions=output_attentions, ++ use_cache=use_cache, ++ cache_position=cache_position, ++ ) ++ ++ hidden_states = layer_outputs[0] ++ ++ if use_cache: ++ next_decoder_cache = layer_outputs[2 if output_attentions else 1] ++ ++ if output_attentions: ++ all_self_attns += (layer_outputs[1],) ++ ++ hidden_states = self.norm(hidden_states) ++ ++ # add hidden states from the last decoder layer ++ if output_hidden_states: ++ all_hidden_states += (hidden_states,) ++ ++ next_cache = next_decoder_cache if use_cache else None ++ if return_legacy_cache: ++ next_cache = next_cache.to_legacy_cache() ++ ++ if not return_dict: ++ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) ++ return BaseModelOutputWithPast( ++ last_hidden_state=hidden_states, ++ past_key_values=next_cache, ++ hidden_states=all_hidden_states, ++ attentions=all_self_attns, ++ ) ++ ++ def _update_causal_mask( ++ self, ++ attention_mask: torch.Tensor, ++ input_tensor: torch.Tensor, ++ cache_position: torch.Tensor, ++ past_key_values: Cache, ++ output_attentions: bool, ++ ): ++ if self.config._attn_implementation == "flash_attention_2": ++ if attention_mask is not None and 0.0 in attention_mask: ++ return attention_mask ++ return None ++ ++ # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in ++ # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail ++ # to infer the attention mask. 
++ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 ++ using_static_cache = isinstance(past_key_values, StaticCache) ++ using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache) ++ ++ # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward ++ if ( ++ self.config._attn_implementation == "sdpa" ++ and not (using_static_cache or using_sliding_window_cache) ++ and not output_attentions ++ ): ++ if AttentionMaskConverter._ignore_causal_mask_sdpa( ++ attention_mask, ++ inputs_embeds=input_tensor, ++ past_key_values_length=past_seen_tokens, ++ sliding_window=self.config.sliding_window, ++ is_training=self.training, ++ ): ++ return None ++ ++ dtype, device = input_tensor.dtype, input_tensor.device ++ min_dtype = torch.finfo(dtype).min ++ sequence_length = input_tensor.shape[1] ++ # SlidingWindowCache or StaticCache ++ if using_sliding_window_cache or using_static_cache: ++ target_length = past_key_values.get_max_cache_shape() ++ # DynamicCache or no cache ++ else: ++ target_length = ( ++ attention_mask.shape[-1] ++ if isinstance(attention_mask, torch.Tensor) ++ else past_seen_tokens + sequence_length + 1 ++ ) ++ ++ # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). ++ causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( ++ attention_mask, ++ sequence_length=sequence_length, ++ target_length=target_length, ++ dtype=dtype, ++ device=device, ++ cache_position=cache_position, ++ batch_size=input_tensor.shape[0], ++ config=self.config, ++ past_key_values=past_key_values, ++ ) ++ ++ if ( ++ self.config._attn_implementation == "sdpa" ++ and attention_mask is not None ++ and attention_mask.device.type == "cuda" ++ and not output_attentions ++ ): ++ # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when ++ # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. ++ # Details: /~https://github.com/pytorch/pytorch/issues/110213 ++ causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) ++ ++ return causal_mask ++ ++ @staticmethod ++ # Copied from transformers.models.mistral.modeling_mistral.MistralModel._prepare_4d_causal_attention_mask_with_cache_position with Mistral->Phi3 ++ def _prepare_4d_causal_attention_mask_with_cache_position( ++ attention_mask: torch.Tensor, ++ sequence_length: int, ++ target_length: int, ++ dtype: torch.dtype, ++ device: torch.device, ++ cache_position: torch.Tensor, ++ batch_size: int, ++ config: Phi3Config, ++ past_key_values: Cache, ++ ): ++ """ ++ Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape ++ `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. ++ ++ Args: ++ attention_mask (`torch.Tensor`): ++ A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape `(batch_size, 1, query_length, key_value_length)`. ++ sequence_length (`int`): ++ The sequence length being processed. ++ target_length (`int`): ++ The target length: when generating with static cache, the mask should be as long as the static cache, to account for the 0 padding, the part of the cache that is not filled yet. ++ dtype (`torch.dtype`): ++ The dtype to use for the 4D attention mask. ++ device (`torch.device`): ++ The device to place the 4D attention mask on. 
++ cache_position (`torch.Tensor`): ++ Indices depicting the position of the input sequence tokens in the sequence. ++ batch_size (`torch.Tensor`): ++ Batch size. ++ config (`Phi3Config`): ++ The model's configuration class ++ past_key_values (`Cache`): ++ The cache class that is being used currently to generate ++ """ ++ if attention_mask is not None and attention_mask.dim() == 4: ++ # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. ++ causal_mask = attention_mask ++ else: ++ min_dtype = torch.finfo(dtype).min ++ causal_mask = torch.full( ++ (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device ++ ) ++ diagonal_attend_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) ++ if config.sliding_window is not None: ++ # if we have sliding window, we should not attend to tokens beyond sliding window length, so we mask them out also ++ # the check is needed to verify is current checkpoint was trained with sliding window or not ++ if not isinstance(past_key_values, SlidingWindowCache) or sequence_length > target_length: ++ sliding_attend_mask = torch.arange(target_length, device=device) <= ( ++ cache_position.reshape(-1, 1) - config.sliding_window ++ ) ++ diagonal_attend_mask.bitwise_or_(sliding_attend_mask) ++ causal_mask *= diagonal_attend_mask ++ causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) ++ if attention_mask is not None: ++ causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit ++ if attention_mask.shape[-1] > target_length: ++ attention_mask = attention_mask[:, :target_length] ++ mask_length = attention_mask.shape[-1] ++ padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] ++ padding_mask = padding_mask == 0 ++ causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( ++ padding_mask, min_dtype ++ ) ++ return causal_mask ++ ++ ++class Phi3ForCausalLM(Phi3PreTrainedModel, GenerationMixin): ++ _tied_weights_keys = ["lm_head.weight"] ++ ++ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.__init__ with Llama->Phi3 ++ def __init__(self, config): ++ super().__init__(config) ++ self.model = Phi3Model(config) ++ self.vocab_size = config.vocab_size ++ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) ++ ++ # Initialize weights and apply final processing ++ self.post_init() ++ ++ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_input_embeddings ++ def get_input_embeddings(self): ++ return self.model.embed_tokens ++ ++ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_input_embeddings ++ def set_input_embeddings(self, value): ++ self.model.embed_tokens = value ++ ++ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_output_embeddings ++ def get_output_embeddings(self): ++ return self.lm_head ++ ++ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_output_embeddings ++ def set_output_embeddings(self, new_embeddings): ++ self.lm_head = new_embeddings ++ ++ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.set_decoder ++ def set_decoder(self, decoder): ++ self.model = decoder ++ ++ # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.get_decoder ++ def get_decoder(self): ++ return self.model ++ ++ # Ignore copy ++ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) ++ 
@replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) ++ def forward( ++ self, ++ input_ids: torch.LongTensor = None, ++ attention_mask: Optional[torch.Tensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_values: Optional[List[torch.FloatTensor]] = None, ++ inputs_embeds: Optional[torch.FloatTensor] = None, ++ labels: Optional[torch.LongTensor] = None, ++ use_cache: Optional[bool] = None, ++ output_attentions: Optional[bool] = None, ++ output_hidden_states: Optional[bool] = None, ++ return_dict: Optional[bool] = None, ++ cache_position: Optional[torch.LongTensor] = None, ++ num_logits_to_keep: int = 0, ++ **loss_kwargs, ++ ) -> Union[Tuple, CausalLMOutputWithPast]: ++ r""" ++ Args: ++ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): ++ Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., ++ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored ++ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. ++ ++ num_logits_to_keep (`int`, *optional*): ++ Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all ++ `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that ++ token can save memory, which becomes pretty significant for long sequences or large vocabulary size. ++ ++ Returns: ++ ++ Example: ++ ++ ```python ++ >>> from transformers import AutoTokenizer, Phi3ForCausalLM ++ ++ >>> model = Phi3ForCausalLM.from_pretrained("microsoft/phi-3-mini-4k-instruct") ++ >>> tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-3-mini-4k-instruct") ++ ++ >>> prompt = "This is an example script ." ++ >>> inputs = tokenizer(prompt, return_tensors="pt") ++ ++ >>> # Generate ++ >>> generate_ids = model.generate(inputs.input_ids, max_length=30) ++ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] ++ 'This is an example script .\n Certainly! Below is a sample script that demonstrates a simple task, such as calculating the sum' ++ ```""" ++ if ( ++ use_cache ++ and self.config.rope_scaling ++ and cache_position is not None ++ and cache_position[0] == self.config.original_max_position_embeddings ++ ): ++ logger.warning( ++ f"If you are not using the generate method, you may encounter nonsensical outputs after the {self.config.original_max_position_embeddings}th token, as the KV cache needs to be recomputed." 
++ ) ++ ++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions ++ output_hidden_states = ( ++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ++ ) ++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ ++ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) ++ outputs = self.model( ++ input_ids=input_ids, ++ attention_mask=attention_mask, ++ position_ids=position_ids, ++ past_key_values=past_key_values, ++ inputs_embeds=inputs_embeds, ++ use_cache=use_cache, ++ output_attentions=output_attentions, ++ output_hidden_states=output_hidden_states, ++ return_dict=return_dict, ++ ) ++ ++ hidden_states = outputs[0] ++ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss ++ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) ++ ++ loss = None ++ if labels is not None: ++ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs) ++ ++ if not return_dict: ++ output = (logits,) + outputs[1:] ++ return (loss,) + output if loss is not None else output ++ ++ return CausalLMOutputWithPast( ++ loss=loss, ++ logits=logits, ++ past_key_values=outputs.past_key_values, ++ hidden_states=outputs.hidden_states, ++ attentions=outputs.attentions, ++ ) ++ ++ def prepare_inputs_for_generation( ++ self, ++ input_ids, ++ past_key_values=None, ++ attention_mask=None, ++ inputs_embeds=None, ++ cache_position=None, ++ position_ids=None, ++ use_cache=True, ++ num_logits_to_keep=None, ++ **kwargs, ++ ): ++ # Overwritten -- this model may need to switch between short and long rope, invalidating the cache in the ++ # process ++ ++ # When the first time input length reached long and short factor switching point, enforce re-compute cache ++ # It will cause downside of slower at this single token position, however, better than current failure. ++ if ( ++ past_key_values ++ and self.config.rope_scaling ++ and input_ids.shape[1] >= self.config.original_max_position_embeddings + 1 ++ ): ++ past_length = cache_position[0] ++ if past_length <= self.config.original_max_position_embeddings: ++ past_key_values = None ++ ++ model_inputs = super().prepare_inputs_for_generation( ++ input_ids=input_ids, ++ past_key_values=past_key_values, ++ attention_mask=attention_mask, ++ inputs_embeds=inputs_embeds, ++ cache_position=cache_position, ++ position_ids=position_ids, ++ use_cache=use_cache, ++ num_logits_to_keep=num_logits_to_keep, ++ **kwargs, ++ ) ++ return model_inputs ++ ++ ++@add_start_docstrings( ++ """ ++ The [`Phi3Model`] with a sequence classification head on top (linear layer). ++ ++ [`Phi3ForSequenceClassification`] uses the last token in order to do the classification, as other causal models ++ (e.g. GPT-2) do. ++ ++ Since it does classification on the last token, it requires to know the position of the last token. If a ++ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If ++ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the ++ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in ++ each row of the batch). 
++ """, ++ PHI3_START_DOCSTRING, ++) ++# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with Llama->Phi3, LLAMA->PHI3, self.transformer->self.model, transformer_outputs->model_outputs ++class Phi3ForSequenceClassification(Phi3PreTrainedModel): ++ def __init__(self, config): ++ super().__init__(config) ++ self.num_labels = config.num_labels ++ self.model = Phi3Model(config) ++ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False) ++ ++ # Initialize weights and apply final processing ++ self.post_init() ++ ++ def get_input_embeddings(self): ++ return self.model.embed_tokens ++ ++ def set_input_embeddings(self, value): ++ self.model.embed_tokens = value ++ ++ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) ++ def forward( ++ self, ++ input_ids: Optional[torch.LongTensor] = None, ++ attention_mask: Optional[torch.Tensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, ++ inputs_embeds: Optional[torch.FloatTensor] = None, ++ labels: Optional[torch.LongTensor] = None, ++ use_cache: Optional[bool] = None, ++ output_attentions: Optional[bool] = None, ++ output_hidden_states: Optional[bool] = None, ++ return_dict: Optional[bool] = None, ++ ) -> Union[Tuple, SequenceClassifierOutputWithPast]: ++ r""" ++ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): ++ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., ++ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If ++ `config.num_labels > 1` a classification loss is computed (Cross-Entropy). ++ """ ++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ ++ model_outputs = self.model( ++ input_ids, ++ attention_mask=attention_mask, ++ position_ids=position_ids, ++ past_key_values=past_key_values, ++ inputs_embeds=inputs_embeds, ++ use_cache=use_cache, ++ output_attentions=output_attentions, ++ output_hidden_states=output_hidden_states, ++ return_dict=return_dict, ++ ) ++ hidden_states = model_outputs[0] ++ logits = self.score(hidden_states) ++ ++ if input_ids is not None: ++ batch_size = input_ids.shape[0] ++ else: ++ batch_size = inputs_embeds.shape[0] ++ ++ if self.config.pad_token_id is None and batch_size != 1: ++ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") ++ if self.config.pad_token_id is None: ++ sequence_lengths = -1 ++ else: ++ if input_ids is not None: ++ # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility ++ sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1 ++ sequence_lengths = sequence_lengths % input_ids.shape[-1] ++ sequence_lengths = sequence_lengths.to(logits.device) ++ else: ++ sequence_lengths = -1 ++ ++ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] ++ ++ loss = None ++ if labels is not None: ++ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config) ++ ++ if not return_dict: ++ output = (pooled_logits,) + model_outputs[1:] ++ return ((loss,) + output) if loss is not None else output ++ ++ return SequenceClassifierOutputWithPast( ++ loss=loss, ++ logits=pooled_logits, ++ past_key_values=model_outputs.past_key_values, ++ hidden_states=model_outputs.hidden_states, ++ attentions=model_outputs.attentions, ++ ) ++ ++ ++@add_start_docstrings( 
++ """ ++ [`Phi3Model`] with a token classification head on top (a linear layer on top of the hidden-states output) e.g. for ++ Named-Entity-Recognition (NER) tasks. ++ """, ++ PHI3_START_DOCSTRING, ++) ++# Copied from transformers.models.mpt.modeling_mpt.MptForTokenClassification with Mpt->Phi3,MPT->PHI3,self.transformer->self.model,transformer_outputs->model_outputs ++class Phi3ForTokenClassification(Phi3PreTrainedModel): ++ def __init__(self, config: Phi3Config): ++ super().__init__(config) ++ self.num_labels = config.num_labels ++ ++ self.model = Phi3Model(config) ++ if hasattr(config, "classifier_dropout") and config.classifier_dropout is not None: ++ classifier_dropout = config.classifier_dropout ++ elif hasattr(config, "hidden_dropout") and config.hidden_dropout is not None: ++ classifier_dropout = config.hidden_dropout ++ else: ++ classifier_dropout = 0.1 ++ self.dropout = nn.Dropout(classifier_dropout) ++ self.classifier = nn.Linear(config.hidden_size, config.num_labels) ++ ++ # Initialize weights and apply final processing ++ self.post_init() ++ ++ @add_start_docstrings_to_model_forward(PHI3_INPUTS_DOCSTRING) ++ @add_code_sample_docstrings( ++ checkpoint=_CHECKPOINT_FOR_DOC, ++ output_type=TokenClassifierOutput, ++ config_class=_CONFIG_FOR_DOC, ++ ) ++ def forward( ++ self, ++ input_ids: Optional[torch.LongTensor] = None, ++ past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, ++ attention_mask: Optional[torch.Tensor] = None, ++ inputs_embeds: Optional[torch.Tensor] = None, ++ labels: Optional[torch.Tensor] = None, ++ use_cache: Optional[bool] = None, ++ output_attentions: Optional[bool] = None, ++ output_hidden_states: Optional[bool] = None, ++ return_dict: Optional[bool] = None, ++ **deprecated_arguments, ++ ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]: ++ r""" ++ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): ++ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., ++ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If ++ `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
++ """ ++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ ++ model_outputs = self.model( ++ input_ids, ++ past_key_values=past_key_values, ++ attention_mask=attention_mask, ++ inputs_embeds=inputs_embeds, ++ use_cache=use_cache, ++ output_attentions=output_attentions, ++ output_hidden_states=output_hidden_states, ++ return_dict=return_dict, ++ ) ++ ++ hidden_states = model_outputs[0] ++ hidden_states = self.dropout(hidden_states) ++ logits = self.classifier(hidden_states) ++ ++ loss = None ++ if labels is not None: ++ # move labels to correct device to enable model parallelism ++ labels = labels.to(logits.device) ++ batch_size, seq_length = labels.shape ++ loss_fct = CrossEntropyLoss() ++ loss = loss_fct( ++ logits.view(batch_size * seq_length, self.num_labels), labels.view(batch_size * seq_length) ++ ) ++ ++ if not return_dict: ++ output = (logits,) + model_outputs[2:] ++ return ((loss,) + output) if loss is not None else output ++ ++ return TokenClassifierOutput( ++ loss=loss, ++ logits=logits, ++ hidden_states=model_outputs.hidden_states, ++ attentions=model_outputs.attentions, ++ ) +diff --git a/optimum/habana/transformers/modeling_utils.py b/optimum/habana/transformers/modeling_utils.py +index a486a7d9..f46cc2ee 100644 +--- a/optimum/habana/transformers/modeling_utils.py ++++ b/optimum/habana/transformers/modeling_utils.py +@@ -148,6 +148,14 @@ from .models import ( + GaudiPhiDecoderLayer, + GaudiPhiForCausalLM, + GaudiPhiModel, ++ GaudiPhi3ForCausalLM, ++ GaudiPhi3Attention, ++ GaudiPhi3DecoderLayer, ++ GaudiPhi3Model, ++ # GaudiPhiOImageEmbedding, ++ # GaudiSiglipAttention, ++ # GaudiSiglipEncoder, ++ # GaudiSiglipEncoderLayer, + GaudiQwen2Attention, + GaudiQwen2DecoderLayer, + GaudiQwen2ForCausalLM, +@@ -559,6 +567,17 @@ def adapt_transformers_to_gaudi(): + transformers.models.phi.modeling_phi.PhiDecoderLayer = GaudiPhiDecoderLayer + transformers.models.phi.modeling_phi.PhiModel = GaudiPhiModel + ++ # Optimization for phi3 on Gaudi ++ transformers.models.phi3.modeling_phi3.Phi3ForCausalLM = GaudiPhi3ForCausalLM ++ transformers.models.phi3.modeling_phi3.Phi3SdpaAttention = GaudiPhi3Attention ++ transformers.models.phi3.modeling_phi3.Phi3DecoderLayer = GaudiPhi3DecoderLayer ++ transformers.models.phi3.modeling_phi3.Phi3Model = GaudiPhi3Model ++ ++ # Optimization for phio on Gaudi ++ # transformers.models.phio.modeling_phio.GaudiSiglipAttention = GaudiSiglipAttention, ++ # transformers.models.phio.modeling_phio.GaudiSiglipEncoder = GaudiSiglipEncoder, ++ # transformers.models.phio.modeling_phio.GaudiSiglipEncoderLayer = GaudiSiglipEncoderLayer, ++ + # Optimization for gemma on Gaudi + transformers.models.gemma.modeling_gemma.GemmaForCausalLM = GaudiGemmaForCausalLM + transformers.models.gemma.modeling_gemma.GemmaMLP = GaudiGemmaMLP +diff --git a/optimum/habana/transformers/models/__init__.py b/optimum/habana/transformers/models/__init__.py +index 3ae3dea2..e5b8548e 100644 +--- a/optimum/habana/transformers/models/__init__.py ++++ b/optimum/habana/transformers/models/__init__.py +@@ -238,6 +238,18 @@ from .phi import ( + GaudiPhiForCausalLM, + GaudiPhiModel, + ) ++from .phi3 import ( ++ GaudiPhi3ForCausalLM, ++ GaudiPhi3Attention, ++ GaudiPhi3DecoderLayer, ++ GaudiPhi3Model, ++) ++# from .phio import ( ++# GaudiPhiOImageEmbedding, ++# ) ++# GaudiSiglipAttention, ++# GaudiSiglipEncoder, ++# GaudiSiglipEncoderLayer, + from .qwen2 import ( + GaudiQwen2Attention, + GaudiQwen2DecoderLayer, +diff --git 
a/optimum/habana/transformers/models/phi3/__init__.py b/optimum/habana/transformers/models/phi3/__init__.py +new file mode 100644 +index 00000000..c036b7b1 +--- /dev/null ++++ b/optimum/habana/transformers/models/phi3/__init__.py +@@ -0,0 +1,6 @@ ++from .modeling_phi3 import ( ++ GaudiPhi3ForCausalLM, ++ GaudiPhi3Attention, ++ GaudiPhi3DecoderLayer, ++ GaudiPhi3Model, ++) +diff --git a/optimum/habana/transformers/models/phi3/modeling_phi3.py b/optimum/habana/transformers/models/phi3/modeling_phi3.py +new file mode 100644 +index 00000000..07febea3 +--- /dev/null ++++ b/optimum/habana/transformers/models/phi3/modeling_phi3.py +@@ -0,0 +1,621 @@ ++# coding=utf-8 ++# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. ++# ++# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX ++# and OPT implementations in this library. It has been modified from its ++# original forms to accommodate minor architectural differences compared ++# to GPT-NeoX and OPT used by the Meta AI team that trained the model. ++# ++# Licensed under the Apache License, Version 2.0 (the "License"); ++# you may not use this file except in compliance with the License. ++# You may obtain a copy of the License at ++# ++# http://www.apache.org/licenses/LICENSE-2.0 ++# ++# Unless required by applicable law or agreed to in writing, software ++# distributed under the License is distributed on an "AS IS" BASIS, ++# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ++# See the License for the specific language governing permissions and ++# limitations under the License. ++"""PyTorch Phi model.""" ++ ++from typing import List, Optional, Tuple, Union ++ ++import torch ++from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast ++from transformers.models.phi3.modeling_phi3 import ( ++ Phi3ForCausalLM, ++ Phi3Attention, ++ Phi3ForCausalLM, ++ Phi3DecoderLayer, ++ Phi3Model, ++ apply_rotary_pos_emb, ++) ++from transformers.utils import logging ++ ++from ...modeling_attn_mask_utils import ( ++ _gaudi_prepare_4d_causal_attention_mask, ++) ++ ++from transformers.cache_utils import Cache, DynamicCache ++from transformers.models.phi3.configuration_phi3 import Phi3Config ++from ...modeling_attn_mask_utils import ( ++ _gaudi_prepare_4d_causal_attention_mask, ++) ++from ...modeling_rope_utils import GaudiRotaryEmbedding ++from ..modeling_all_models import KVCache, Matmul ++ ++ ++logger = logging.get_logger(__name__) ++ ++ ++def gaudi_phi_repeat_kv( ++ query_states: torch.Tensor, ++ key_states: torch.Tensor, ++ value_states: torch.Tensor, ++ attention_mask: torch.Tensor, ++ n_rep: int, ++): ++ """ ++ Copied from repeat_kv: /~https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/models/phi/modeling_phi.py ++ The only differences are: ++ - Append num_key_value_heads == 1 check as kv states can be broadcasted during matmuls so need to expand and reshape them. ++ - Add new args query_states, key_states, value_states and attention_mask and update the logic for expansion. 
++ The query states go from (batch, num_heads, seqlen, head_dim) to (batch, num_key_value_heads, n_rep, seqlen, head_dim) ++ The key/value states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads, 1, seqlen, head_dim) ++ """ ++ batch, num_key_value_heads, kv_len, head_dim = key_states.shape ++ if n_rep == 1 or num_key_value_heads == 1: ++ return query_states, key_states, value_states, attention_mask ++ ++ new_kv_shape = (batch, num_key_value_heads, 1, kv_len, head_dim) ++ key_states = key_states.reshape(new_kv_shape) ++ value_states = value_states.reshape(new_kv_shape) ++ ++ batch, _, q_len, head_dim = query_states.shape ++ new_q_shape = (batch, num_key_value_heads, n_rep, q_len, head_dim) ++ query_states = query_states.reshape(new_q_shape) ++ ++ if attention_mask is not None: ++ # Add groups dim and set to 1 ++ attention_mask = attention_mask.unsqueeze(1) ++ ++ return query_states, key_states, value_states, attention_mask ++ ++ ++def gaudi_eager_attention_forward( ++ module: torch.nn.Module, ++ query: torch.Tensor, ++ key: torch.Tensor, ++ value: torch.Tensor, ++ attention_mask: Optional[torch.Tensor], ++ scaling: float, ++ dropout: float = 0.0, ++ **kwargs, ++): ++ bsz, q_len = kwargs["input_shape"] ++ query_states, key_states, value_states, attention_mask = gaudi_phi_repeat_kv( ++ query, key, value, attention_mask, module.num_key_value_groups ++ ) ++ ++ attn_weights = module.matmul_qk(query_states, key_states.transpose(2, 3)) * scaling ++ if attention_mask is not None: ++ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] ++ attn_weights = attn_weights + causal_mask ++ ++ attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) ++ attn_weights = torch.nn.functional.dropout(attn_weights, p=dropout, training=module.training) ++ attn_output = module.matmul_av(attn_weights, value_states) ++ attn_output = attn_output.reshape(bsz, -1, q_len, module.head_dim) ++ ++ return attn_output, attn_weights ++ ++ ++class GaudiPhi3Attention(Phi3Attention): ++ def __init__(self, config: Phi3Config, layer_idx: Optional[int] = None): ++ super().__init__(config, layer_idx) ++ self.matmul_qk = Matmul() ++ self.matmul_av = Matmul() ++ self.k_cache = KVCache() ++ self.v_cache = KVCache() ++ self.inp_seq_len = -1 ++ self.rotary_emb = GaudiRotaryEmbedding(config=self.config) ++ self.num_key_value_heads = config.num_key_value_heads ++ ++ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): ++ cache_shape = (batch_size, self.num_key_value_heads, max_seq_len, self.head_dim) ++ device = self.k_proj.weight.device ++ dtype = self.config.torch_dtype ++ self.k_cache.allocate(inp_seq_len, dtype, device, cache_shape) ++ self.v_cache.allocate(inp_seq_len, dtype, device, cache_shape) ++ ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ attention_mask: Optional[torch.Tensor], ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_value: Optional[Cache] = None, ++ output_attentions: bool = False, ++ use_cache: bool = False, ++ cache_position: Optional[torch.LongTensor] = None, ++ token_idx: Optional[torch.Tensor] = None, ++ reuse_cache: Optional[bool] = False, ++ cache_idx: Optional[int] = None, ++ **kwargs, ++ ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: ++ """ ++ Copied from PhiAttention.forward: /~https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py ++ The only differences are: ++ - add new args token_idx ++ - optimize 
KV cache ++ - add new args reuse_cache ++ - add new args cache_idx ++ """ ++ bsz, q_len, _ = hidden_states.size() ++ input_shape = [bsz, q_len] ++ qkv = self.qkv_proj(hidden_states) ++ query_pos = self.num_heads * self.head_dim ++ query_states = qkv[..., :query_pos] ++ key_states = qkv[..., query_pos : query_pos + self.num_key_value_heads * self.head_dim] ++ value_states = qkv[..., query_pos + self.num_key_value_heads * self.head_dim :] ++ ++ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) ++ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) ++ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) ++ ++ if self.qk_layernorm: ++ query_states = self.q_layernorm(query_states) ++ key_states = self.k_layernorm(key_states) ++ ++ kv_seq_len = key_states.shape[-2] ++ if past_key_value is not None: ++ if self.layer_idx is None: ++ raise ValueError( ++ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " ++ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " ++ "with a layer index." ++ ) ++ kv_shape = ( ++ (past_key_value[0][-2] if reuse_cache else past_key_value[0].shape[-2]) ++ if isinstance(past_key_value, tuple) ++ else past_key_value.get_usable_length(kv_seq_len, self.layer_idx) ++ ) ++ if token_idx is not None: ++ kv_seq_len = kv_shape ++ else: ++ kv_seq_len += kv_shape ++ ++ cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) ++ ++ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) ++ ++ ++ if use_cache: ++ # reuse k, v, self_attention ++ if reuse_cache: ++ key_states = self.k_cache(key_states, 2, token_idx) ++ value_states = self.v_cache(value_states, 2, token_idx) ++ past_key_value = (self.k_cache.get_shape(), self.v_cache.get_shape()) ++ else: ++ if past_key_value is None: ++ past_key = torch.zeros(key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device) ++ past_value = torch.zeros( ++ key_states.shape, dtype=self.k_proj.weight.dtype, device=key_states.device ++ ) ++ past_key_value = (past_key, past_value) ++ key_states = self.k_cache.update(past_key_value[0], key_states, 2, token_idx, self.inp_seq_len) ++ value_states = self.v_cache.update(past_key_value[1], value_states, 2, token_idx, self.inp_seq_len) ++ if token_idx is None: ++ past_key_value = (key_states, value_states) ++ ++ if cache_idx is not None and q_len == 1: ++ key_states = key_states[:, :, :cache_idx, :] ++ value_states = value_states[:, :, :cache_idx, :] ++ if attention_mask is not None: ++ attention_mask = attention_mask[:, :, :, :cache_idx] ++ kv_seq_len = key_states.shape[-2] ++ else: ++ past_key_value = None ++ ++ attn_output, attn_weights = gaudi_eager_attention_forward( ++ self, ++ query_states, ++ key_states, ++ value_states, ++ attention_mask, ++ dropout=0.0 if not self.training else self.attention_dropout, ++ scaling=self.scaling, ++ input_shape=input_shape, ++ ) ++ ++ attn_output = attn_output.transpose(1, 2).contiguous() ++ attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) ++ ++ attn_output = self.o_proj(attn_output) ++ ++ if not output_attentions: ++ attn_weights = None ++ ++ return attn_output, attn_weights, past_key_value ++ ++ ++class GaudiPhi3DecoderLayer(Phi3DecoderLayer): ++ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): ++ self.self_attn.allocate_kv_cache(batch_size, 
max_seq_len, inp_seq_len) ++ ++ def forward( ++ self, ++ hidden_states: torch.Tensor, ++ attention_mask: Optional[torch.Tensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_value: Optional[Tuple[torch.Tensor]] = None, ++ output_attentions: Optional[bool] = False, ++ use_cache: Optional[bool] = False, ++ cache_position: Optional[torch.LongTensor] = None,token_idx: Optional[torch.Tensor] = None, ++ reuse_cache: Optional[bool] = False, ++ cache_idx: Optional[int] = None, ++ **kwargs, ++ ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: ++ """ ++ Copied from PhiDecoderLayer.forward: /~https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py ++ The only differences are: ++ - add new args token_idx ++ - add new args reuse_cache ++ - add new args cache_idx ++ """ ++ ++ residual = hidden_states ++ ++ hidden_states = self.input_layernorm(hidden_states) ++ ++ # Self Attention ++ attn_outputs, self_attn_weights, present_key_value = self.self_attn( ++ hidden_states=hidden_states, ++ attention_mask=attention_mask, ++ position_ids=position_ids, ++ past_key_value=past_key_value, ++ output_attentions=output_attentions, ++ use_cache=use_cache, ++ cache_position=cache_position, ++ ) ++ # FIXME: Fix this ++ # token_idx=token_idx, ++ # reuse_cache=reuse_cache, ++ # cache_idx=cache_idx, ++ hidden_states = residual + self.resid_attn_dropout(attn_outputs) ++ ++ residual = hidden_states ++ hidden_states = self.post_attention_layernorm(hidden_states) ++ hidden_states = self.mlp(hidden_states) ++ hidden_states = residual + self.resid_mlp_dropout(hidden_states) ++ ++ outputs = (hidden_states,) ++ ++ if output_attentions: ++ outputs += (self_attn_weights,) ++ ++ if use_cache: ++ outputs += (present_key_value,) ++ ++ return outputs ++ ++ ++ ++class GaudiPhi3Model(Phi3Model): ++ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): ++ for layer in self.layers: ++ layer.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) ++ ++ def forward( ++ self, ++ input_ids: torch.LongTensor = None, ++ attention_mask: Optional[torch.Tensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_values: Optional[List[torch.FloatTensor]] = None, ++ inputs_embeds: Optional[torch.FloatTensor] = None, ++ use_cache: Optional[bool] = None, ++ output_attentions: Optional[bool] = None, ++ output_hidden_states: Optional[bool] = None, ++ return_dict: Optional[bool] = None, ++ cache_position: Optional[torch.LongTensor] = None, ++ token_idx: Optional[torch.Tensor] = None, ++ reuse_cache: Optional[bool] = False, ++ cache_idx: Optional[int] = None, ++ ) -> Union[Tuple, BaseModelOutputWithPast]: ++ """ ++ Copied from PhiModel.forward: /~https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py ++ The only differences are: ++ - add new args token_idx ++ - add new args reuse_cache ++ - add new args cache_idx ++ """ ++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions ++ output_hidden_states = ( ++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ++ ) ++ use_cache = use_cache if use_cache is not None else self.config.use_cache ++ ++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ ++ # retrieve input_ids and inputs_embeds ++ if input_ids is not None and inputs_embeds is not None: ++ raise ValueError("You must specify exactly one of 
input_ids or inputs_embeds") ++ elif input_ids is not None: ++ batch_size, seq_length = input_ids.shape[:2] ++ elif inputs_embeds is not None: ++ batch_size, seq_length = inputs_embeds.shape[:2] ++ else: ++ raise ValueError("You have to specify either input_ids or inputs_embeds") ++ ++ if self.gradient_checkpointing and self.training and use_cache: ++ logger.warning_once( ++ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." ++ ) ++ use_cache = False ++ ++ use_legacy_cache = True ++ use_new_cache = False ++ past_seen_tokens = 0 ++ if past_key_values is not None and use_cache: ++ if reuse_cache: ++ past_seen_tokens = past_key_values[0][0][2] ++ else: ++ if use_new_cache: ++ use_legacy_cache = not isinstance(past_key_values, Cache) ++ if use_legacy_cache: ++ past_key_values = DynamicCache.from_legacy_cache(past_key_values) ++ past_seen_tokens = past_key_values.get_seq_length() ++ else: ++ # TODO: Need to fix token_idx ++ if past_key_values[0] is not None: ++ past_seen_tokens = past_key_values[0][0].shape[2] ++ ++ if inputs_embeds is None: ++ inputs_embeds = self.embed_tokens(input_ids) ++ ++ if cache_position is None: ++ past_seen_tokens = 0 ++ if past_key_values is not None: ++ if isinstance(past_key_values, Cache): ++ past_seen_tokens = past_key_values.get_seq_length() ++ else: ++ past_seen_tokens = past_key_values[0][0].shape[2] ++ cache_position = torch.arange( ++ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device ++ ) ++ if position_ids is None: ++ position_ids = cache_position.unsqueeze(0) ++ ++ # 4d mask is passed through the layers ++ attention_mask = _gaudi_prepare_4d_causal_attention_mask( ++ attention_mask, (batch_size, seq_length), inputs_embeds, past_seen_tokens ++ ) ++ ++ hidden_states = inputs_embeds ++ ++ # decoder layers ++ all_hidden_states = () if output_hidden_states else None ++ all_self_attns = () if output_attentions else None ++ next_decoder_cache = () if not use_new_cache else None ++ ++ for layer_idx, decoder_layer in enumerate(self.layers): ++ if output_hidden_states: ++ all_hidden_states += (hidden_states,) ++ ++ if self.gradient_checkpointing and self.training: ++ layer_outputs = self._gradient_checkpointing_func( ++ decoder_layer.__call__, ++ hidden_states, ++ attention_mask, ++ position_ids, ++ None if past_key_values is None else past_key_values[layer_idx], ++ output_attentions, ++ use_cache, ++ cache_position, ++ None, ++ ) ++ else: ++ ++ layer_outputs = decoder_layer( ++ hidden_states, ++ attention_mask=attention_mask, ++ position_ids=position_ids, ++ past_key_value=None if past_key_values is None else past_key_values[layer_idx], ++ output_attentions=output_attentions, ++ use_cache=use_cache, ++ cache_position=cache_position, ++ ) ++ # FIXME: Fix this ++ # token_idx=token_idx, ++ # reuse_cache=reuse_cache, ++ # cache_idx=cache_idx, ++ ++ hidden_states = layer_outputs[0] ++ ++ if use_cache: ++ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) ++ ++ if output_attentions: ++ all_self_attns += (layer_outputs[1],) ++ ++ hidden_states = self.norm(hidden_states) ++ ++ # add hidden states from the last decoder layer ++ if output_hidden_states: ++ all_hidden_states += (hidden_states,) ++ ++ next_cache = None ++ if use_cache: ++ next_cache = ( ++ next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache ++ ) ++ ++ if not return_dict: ++ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v 
is not None) ++ return BaseModelOutputWithPast( ++ last_hidden_state=hidden_states, ++ past_key_values=next_cache, ++ hidden_states=all_hidden_states, ++ attentions=all_self_attns, ++ ) ++ ++ ++logger = logging.get_logger(__name__) ++ ++class GaudiPhi3ForCausalLM(Phi3ForCausalLM): ++ def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len): ++ self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len) ++ ++ def forward( ++ self, ++ input_ids: torch.LongTensor = None, ++ attention_mask: Optional[torch.Tensor] = None, ++ position_ids: Optional[torch.LongTensor] = None, ++ past_key_values: Optional[List[torch.FloatTensor]] = None, ++ inputs_embeds: Optional[torch.FloatTensor] = None, ++ labels: Optional[torch.LongTensor] = None, ++ use_cache: Optional[bool] = None, ++ output_attentions: Optional[bool] = None, ++ output_hidden_states: Optional[bool] = None, ++ return_dict: Optional[bool] = None, ++ cache_position: Optional[torch.LongTensor] = None, ++ num_logits_to_keep: int = 0, ++ token_idx: Optional[torch.Tensor] = None, ++ reuse_cache: Optional[bool] = False, ++ trim_logits: Optional[bool] = False, ++ cache_idx: Optional[int] = None, ++ **kwargs, ++ ) -> Union[Tuple, CausalLMOutputWithPast]: ++ """ ++ Inherits from PhiForCausalLM: /~https://github.com/huggingface/transformers/blob/v4.37.1/src/transformers/models/phi/modeling_phi.py ++ The only differences are: ++ - add new args token_idx ++ - add new args reuse_cache ++ - add new args cache_idx ++ """ ++ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions ++ output_hidden_states = ( ++ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states ++ ) ++ return_dict = return_dict if return_dict is not None else self.config.use_return_dict ++ ++ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) ++ outputs = self.model( ++ input_ids=input_ids, ++ attention_mask=attention_mask, ++ position_ids=position_ids, ++ past_key_values=past_key_values, ++ inputs_embeds=inputs_embeds, ++ use_cache=use_cache, ++ output_attentions=output_attentions, ++ output_hidden_states=output_hidden_states, ++ return_dict=return_dict, ++ cache_position=cache_position, ++ token_idx=token_idx, ++ reuse_cache=reuse_cache, ++ cache_idx=cache_idx, ++ ) ++ ++ hidden_states = outputs[0] ++ _, seq_len, _ = hidden_states.shape ++ if seq_len > 1 and trim_logits and not self.training: ++ if token_idx is not None: ++ hidden_states = hidden_states.index_select(1, token_idx - 1) ++ else: ++ hidden_states = hidden_states[:, -1, :] ++ num_logits_to_keep = 0 ++ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss ++ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) ++ ++ loss = None ++ if labels is not None: ++ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) ++ ++ if not return_dict: ++ output = (logits,) + outputs[1:] ++ return (loss,) + output if loss is not None else output ++ ++ return CausalLMOutputWithPast( ++ loss=loss, ++ logits=logits, ++ past_key_values=outputs.past_key_values, ++ hidden_states=outputs.hidden_states, ++ attentions=outputs.attentions, ++ ) ++ ++ def prepare_inputs_for_generation( ++ self, ++ input_ids, ++ past_key_values=None, ++ attention_mask=None, ++ inputs_embeds=None, ++ cache_position=None, ++ position_ids=None, ++ use_cache=True, ++ num_logits_to_keep=0, ++ token_idx=None, ++ **kwargs, ++ ): ++ """ ++ The only 
differences are: ++ - add new args token_idx ++ - add token_idx into model_inputs ++ - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx ++ - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx ++ """ ++ reuse_cache = kwargs.get("reuse_cache") ++ # Omit tokens covered by past_key_values ++ if past_key_values is not None: ++ if token_idx is not None: ++ idx = token_idx + kwargs.get("inputs_embeds_offset", 0) - 1 ++ input_ids = torch.index_select(input_ids, 1, idx) ++ else: ++ if inputs_embeds is not None: # Exception 1 ++ input_ids = input_ids[:, -cache_position.shape[0] :] ++ elif ( ++ input_ids.shape[1] != cache_position.shape[0] ++ ): # Default case (the "else", a no op, is Exception 2) ++ input_ids = input_ids[:, cache_position] ++ elif reuse_cache and token_idx is not None: ++ # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass ++ input_ids = input_ids[:, :token_idx] ++ attention_mask = attention_mask[:, :token_idx] ++ ++ if attention_mask is not None and position_ids is None: ++ # create position_ids on the fly for batch generation ++ position_ids = attention_mask.long().cumsum(-1) - 1 ++ position_ids.masked_fill_(attention_mask == 0, 1) ++ if past_key_values: ++ if token_idx is not None: ++ position_ids = torch.index_select(position_ids, 1, token_idx - 1) ++ else: ++ position_ids = position_ids[:, -input_ids.shape[1] :] ++ # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture. 
++ position_ids = position_ids.clone(memory_format=torch.contiguous_format) ++ ++ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step ++ if inputs_embeds is not None and past_key_values is None: ++ model_inputs = {"inputs_embeds": inputs_embeds} ++ else: ++ model_inputs = { ++ "input_ids": input_ids.clone(memory_format=torch.contiguous_format) ++ } # `contiguous()` needed for compilation use cases ++ ++ if num_logits_to_keep is not None: ++ model_inputs["num_logits_to_keep"] = num_logits_to_keep ++ ++ model_inputs.update( ++ { ++ "position_ids": position_ids, ++ "cache_position": cache_position, ++ "past_key_values": past_key_values, ++ "use_cache": use_cache, ++ "attention_mask": attention_mask, ++ "token_idx": token_idx, ++ "reuse_cache": kwargs.get("reuse_cache"), ++ "trim_logits": kwargs.get("trim_logits"), ++ "cache_idx": kwargs.get("cache_idx"), ++ } ++ ) ++ ++ return model_inputs +-- +2.34.1 + diff --git a/tests/llms/test_llms_text-generation_native_enhance_multimodal_on_intel_hpu.sh b/tests/llms/test_llms_text-generation_native_enhance_multimodal_on_intel_hpu.sh new file mode 100644 index 000000000..bff73beef --- /dev/null +++ b/tests/llms/test_llms_text-generation_native_enhance_multimodal_on_intel_hpu.sh @@ -0,0 +1,103 @@ +#!/bin/bash +# Copyright (C) 2024 Intel Corporation +# SPDX-License-Identifier: Apache-2.0 + +set -x + +IMAGE_REPO=${IMAGE_REPO:-"opea"} +export REGISTRY=${IMAGE_REPO} +export TAG="comps" +echo "REGISTRY=IMAGE_REPO=${IMAGE_REPO}" +echo "TAG=${TAG}" + +WORKPATH=$(dirname "$PWD") +host_ip=$(hostname -I | awk '{print $1}') +LOG_PATH="$WORKPATH/tests" +service_name="textgen-native-gaudi-enhance-multimodal" + +function build_docker_images() { + cd $WORKPATH + docker build --no-cache -t ${REGISTRY:-opea}/llm-textgen-gaudi-enhance:${TAG:-latest} --build-arg https_proxy=$https_proxy --build-arg http_proxy=$http_proxy -f comps/llms/src/text-generation/Dockerfile.intel_hpu_enhance . + if [ $? -ne 0 ]; then + echo "opea/llm-textgen-gaudi-enhance built fail" + exit 1 + else + echo "opea/llm-textgen-gaudi-enhance built successful" + fi +} + +function start_service() { + export TEXTGEN_PORT=10512 #10500-10599 + export host_ip=${host_ip} + export LLM_MODEL_ID="/data/phi4/phi-4-multimodel" + export LOGFLAG=True + export DATA_PATH="/data" + + cd $WORKPATH/comps/llms/deployment/docker_compose + docker compose -f compose_text-generation.yaml up ${service_name} -d > ${LOG_PATH}/start_services_with_compose.log + + sleep 2m +} + +function validate_services() { + local URL="$1" + local EXPECTED_RESULT="$2" + local SERVICE_NAME="$3" + local DOCKER_NAME="$4" + local INPUT_DATA="$5" + + local HTTP_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X POST -d "$INPUT_DATA" -H 'Content-Type: application/json' "$URL") + + echo "===========================================" + + if [ "$HTTP_STATUS" -eq 200 ]; then + echo "[ $SERVICE_NAME ] HTTP status is 200. Checking content..." + + local CONTENT=$(curl -s -X POST -d "$INPUT_DATA" -H 'Content-Type: application/json' "$URL" | tee ${LOG_PATH}/${SERVICE_NAME}.log) + + if echo "$CONTENT" | grep -q "$EXPECTED_RESULT"; then + echo "[ $SERVICE_NAME ] Content is as expected." + else + echo "[ $SERVICE_NAME ] Content does not match the expected result: $CONTENT" + docker logs ${DOCKER_NAME} >> ${LOG_PATH}/${SERVICE_NAME}.log + exit 1 + fi + else + echo "[ $SERVICE_NAME ] HTTP status is not 200. 
Received status was $HTTP_STATUS"
+        docker logs ${DOCKER_NAME} >> ${LOG_PATH}/${SERVICE_NAME}.log
+        exit 1
+    fi
+    sleep 1s
+}
+
+function validate_microservices() {
+    URL="http://${host_ip}:${TEXTGEN_PORT}/v1/chat/completions"
+
+    # textgen
+    echo "Validate textgen with string messages input..."
+    validate_services \
+        "$URL" \
+        "text" \
+        "textgen-native-gaudi-enhance-multimodal" \
+        "textgen-native-gaudi-enhance-multimodal" \
+        '{"model": "Intel/neural-chat-7b-v3-3", "messages": "What is Deep Learning?", "max_tokens":17, "stream":false}'
+}
+
+function stop_docker() {
+    cd $WORKPATH/comps/llms/deployment/docker_compose
+    docker compose -f compose_text-generation.yaml down ${service_name} --remove-orphans
+}
+
+function main() {
+
+    stop_docker
+    build_docker_images
+    start_service
+    validate_microservices
+    stop_docker
+
+    echo y | docker system prune
+
+}
+
+main
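
For reference, the request issued by validate_microservices above can also be sent by hand once the textgen-native-gaudi-enhance-multimodal container is up. This is only a sketch that reuses the values exported in start_service (host_ip, TEXTGEN_PORT=10512) and the test payload verbatim; adjust host, port, and payload to the actual deployment:

# Manual spot-check of the deployed service; payload copied from validate_microservices.
export host_ip=$(hostname -I | awk '{print $1}')
curl -X POST "http://${host_ip}:10512/v1/chat/completions" \
    -H 'Content-Type: application/json' \
    -d '{"model": "Intel/neural-chat-7b-v3-3", "messages": "What is Deep Learning?", "max_tokens":17, "stream":false}'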