adapt llava pipeline to latest Transformers
- change the Transformers pipeline task from image-to-text to image-text-to-text (huggingface/transformers#34769)
- simplify the preprocessing logic, keep the handling for when the image is an empty string, and remove multi-image inference in a single run
- fix a few README errors
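
For context, the task rename changes how callers invoke the pipeline. A minimal sketch, assuming a recent transformers release and a llava-hf checkpoint (the model ID, image URL, and prompt string below are illustrative assumptions, not taken from this commit):

```python
from transformers import pipeline

# New task name (huggingface/transformers#34769); the old task was "image-to-text".
pipe = pipeline("image-text-to-text", model="llava-hf/llava-1.5-7b-hf")

# The renamed pipeline takes the prompt via `text=` rather than `prompt=`,
# and returns a list of dicts with a "generated_text" field.
out = pipe(
    "https://example.com/cat.png",  # hypothetical image URL
    text="USER: <image>\nWhat is this?\nASSISTANT:",
    generate_kwargs={"max_new_tokens": 100},
)
print(out[0]["generated_text"])
```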
Spycsh committed Feb 28, 2025
1 parent fb703f0 commit 876998c
Showing 4 changed files with 12 additions and 92 deletions.
@@ -22,7 +22,8 @@ COPY comps /home/user/comps
# Install requirements and optimum habana
RUN pip install --no-cache-dir --upgrade pip && \
pip install --no-cache-dir -r /home/user/comps/lvms/src/integrations/dependency/llava/requirements.txt && \
-    pip install --no-cache-dir optimum[habana]
+    pip install git+/~https://github.com/huggingface/optimum-habana.git@transformers_future && \
+    pip install --no-cache-dir --upgrade Jinja2

ENV PYTHONPATH=$PYTHONPATH:/home/user
USER user
7 changes: 2 additions & 5 deletions comps/lvms/src/integrations/dependency/llava/README.md
@@ -82,13 +82,10 @@ docker run -p 8399:8399 --runtime=habana -e HABANA_VISIBLE_DEVICES=all -e OMPI_M
# Use curl/python

# curl with an image and a prompt
http_proxy="" curl http://localhost:9399/v1/lvm -XPOST -d '{"image": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC", "prompt":"What is this?"}' -H 'Content-Type: application/json'

# curl with multiple images and a prompt (Note that depending on your MAX_IMAGES value, both images may not be sent to the LLaVA model)
http_proxy="" curl http://localhost:9399/v1/lvm -XPOST -d '{"image": ["iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNkYPhfz0AEYBxVSF+FAP5FDvcfRYWgAAAAAElFTkSuQmCC", "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mNk+M9Qz0AEYBxVSF+FAAhKDveksOjmAAAAAElFTkSuQmCC"], "prompt":"What is in these images?"}' -H 'Content-Type: application/json'
http_proxy="" curl http://localhost:8399/generate -XPOST -d '{"img_b64_str": "iVBORw0KGgoAAAANSUhEUgAAAAoAAAAKCAYAAACNMs+9AAAAFUlEQVR42mP8/5+hnoEIwDiqkL4KAcT9GO0U4BxoAAAAAElFTkSuQmCC", "prompt":"What is this?"}' -H 'Content-Type: application/json'

# curl with a prompt only (no image)
http_proxy="" curl http://localhost:9399/v1/lvm -XPOST -d '{"image": "", "prompt":"What is deep learning?"}' -H 'Content-Type: application/json'
http_proxy="" curl http://localhost:8399/generate -XPOST -d '{"img_b64_str": "", "prompt":"What is deep learning?"}' -H 'Content-Type: application/json'

# Test
python check_llava_server.py
92 changes: 7 additions & 85 deletions comps/lvms/src/integrations/dependency/llava/llava_server.py
@@ -15,7 +15,6 @@
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse, Response
from transformers import AutoProcessor, pipeline
- from transformers.image_utils import load_image

model_name_or_path = None
model_dtype = None
@@ -27,79 +26,6 @@
app = FastAPI()


- def pipeline_preprocess(self, image, prompt=None, timeout=None):
-     """
-     This replaces the preprocess function used by the image-to-text pipeline
-     (/~https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/image_to_text.py).
-     The original transformers image-to-text pipeline preprocess function requires that an image is passed in, and will
-     fail if the image parameter is null/empty. In order to support multimodal use cases with the same pipeline, this
-     preprocess function handles the case where there is no image with the prompt.
-     Also, the image-to-text pipeline typically treats multiple images passed in as a list as a batch (where it iterates
-     over the image inputs for generation). For that reason, the original pipeline_preprocess code would only get a
-     single image at a time. To support multiple images, the pipeline call is updated to send a list of lists for the
-     images (so that when iterated, we still get multiple images) and this pipeline_preprocess function has been updated
-     to handle a list of images in addition to single images.
-     """
-
-     if isinstance(image, list):
-         image = [load_image(i, timeout=timeout) for i in image]
-     elif image:
-         image = load_image(image, timeout=timeout)
-
-     if prompt is not None:
-         if not isinstance(prompt, str):
-             raise ValueError(
-                 f"Received an invalid text input, got - {type(prompt)} - but expected a single string. "
-                 "Note also that one single text can be provided for conditional image to text generation."
-             )
-
-         model_type = self.model.config.model_type
-
-         if model_type == "git":
-             if image:
-                 model_inputs = self.image_processor(images=image, return_tensors=self.framework)
-                 if self.framework == "pt":
-                     model_inputs = model_inputs.to(self.torch_dtype)
-             else:
-                 model_inputs = {}
-             input_ids = self.tokenizer(text=prompt, add_special_tokens=False).input_ids
-             input_ids = [self.tokenizer.cls_token_id] + input_ids
-             input_ids = torch.tensor(input_ids).unsqueeze(0)
-             model_inputs.update({"input_ids": input_ids})
-         elif model_type == "pix2struct":
-             model_inputs = self.image_processor(images=image, header_text=prompt, return_tensors=self.framework)
-             if self.framework == "pt":
-                 model_inputs = model_inputs.to(self.torch_dtype)
-
-         elif model_type != "vision-encoder-decoder":
-             if image:
-                 # vision-encoder-decoder does not support conditional generation
-                 model_inputs = self.image_processor(images=image, return_tensors=self.framework)
-
-                 if self.framework == "pt":
-                     model_inputs = model_inputs.to(self.torch_dtype)
-             else:
-                 model_inputs = {}
-
-             text_inputs = self.tokenizer(prompt, return_tensors=self.framework)
-             model_inputs.update(text_inputs)
-
-         else:
-             raise ValueError(f"Model type {model_type} does not support conditional text generation")
-
-     elif image:
-         model_inputs = self.image_processor(images=image, return_tensors=self.framework)
-         if self.framework == "pt":
-             model_inputs = model_inputs.to(self.torch_dtype)
-     else:
-         raise ValueError("Both image and prompt cannot be empty.")
-
-     if self.model.config.model_type == "git" and prompt is None:
-         model_inputs["input_ids"] = None
-
-     return model_inputs
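
Aside: this override was wired in via Python's descriptor protocol (see the removed generator.preprocess assignment later in this diff), which binds a free function to a live instance as a method. A minimal, self-contained sketch of that idiom with hypothetical names:

```python
class Pipe:
    def preprocess(self, data):
        return f"default: {data}"


def custom_preprocess(self, data):
    # Receives the instance as `self`, exactly like a normal method.
    return f"custom: {data}"


pipe = Pipe()
# __get__ binds the function to `pipe`, producing a bound method.
pipe.preprocess = custom_preprocess.__get__(pipe, type(pipe))
print(pipe.preprocess("x"))  # -> custom: x
```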


def process_image(image, max_len=1344, min_len=672):
if max(image.size) > max_len:
max_hw, min_hw = max(image.size), min(image.size)
@@ -122,11 +48,11 @@ async def health() -> Response:


@app.post("/generate")
- async def generate(request: Request) -> Response:  # FIXME batch_size=1 for now
+ async def generate(request: Request) -> Response:
print("LLaVA generation begin.")
request_dict = await request.json()
prompt = request_dict.pop("prompt")
- img_b64_str = request_dict.pop("img_b64_str")  # String or list of strings
+ img_b64_str = request_dict.pop("img_b64_str")  # Only accept string
max_new_tokens = request_dict.pop("max_new_tokens", 100)

# Determine the format of the role labels based on the model name
@@ -183,12 +109,9 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now

start = time.time()

- # Override the pipeline preprocessing
- generator.preprocess = pipeline_preprocess.__get__(generator, type(generator))
-
- result = generator([images], prompt=prompt, batch_size=1, generate_kwargs=generate_kwargs)
+ result = generator(images, text=prompt, batch_size=1, generate_kwargs=generate_kwargs)
end = time.time()
- result = result[0][0]["generated_text"].split(output_assistant_label.strip())[-1].strip()
+ result = result[0]["generated_text"].split(output_assistant_label.strip())[-1].strip()
print(f"LLaVA result = {result}, time = {(end-start) * 1000 }ms")
if images:
for i in images:
@@ -223,7 +146,7 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now
model_name_or_path = args.model_name_or_path

generator = pipeline(
"image-to-text",
"image-text-to-text",
model=args.model_name_or_path,
torch_dtype=model_dtype,
device=args.device,
@@ -266,11 +189,10 @@ async def generate(request: Request) -> Response: # FIXME batch_size=1 for now
},
]
text_prompt = processor.apply_chat_template(conversation)
-
for i in range(args.warmup):
-    generator(
+    res = generator(
        images,
-        prompt=text_prompt,
+        text=text_prompt,
        batch_size=1,
        generate_kwargs=generate_kwargs,
    )
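
The conversation passed to apply_chat_template is collapsed in this view; for llava-hf checkpoints it typically follows the typed chat-message structure sketched below (the message text is an assumption, not taken from this commit):

```python
# Hypothetical reconstruction of the collapsed `conversation` above.
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What is shown in this image?"},
        ],
    },
]
text_prompt = processor.apply_chat_template(conversation)
```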
2 changes: 1 addition & 1 deletion comps/lvms/src/integrations/dependency/llava/requirements.txt

@@ -6,7 +6,7 @@ langchain-core
opentelemetry-api
opentelemetry-exporter-otlp
opentelemetry-sdk
- optimum[habana]
+ transformers
prometheus-fastapi-instrumentator
pydantic==2.7.2
pydub