Skip to content

Commit

Permalink
Parse result for llava_hf 1.6
Browse files Browse the repository at this point in the history
  • Loading branch information
kcz358 committed May 7, 2024
1 parent 3e56b4f commit 7847dc4
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion lmms_eval/models/llava_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __init__(
eval_logger.info("Not sure whether you use 1.5 or 1.6. Use 1.5 by default. This might cause bugs if you are actually using 1.6")
self._model = LlavaForConditionalGeneration.from_pretrained(pretrained, revision=revision, torch_dtype=dtype, device_map=self.device_map, trust_remote_code=trust_remote_code, attn_implementation=attn_implementation)

self.pretrained = pretrained
self._image_processor = AutoProcessor.from_pretrained(pretrained, revision=revision, trust_remote_code=trust_remote_code)
# Pad from left for batched generation: https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
self._image_processor.tokenizer.padding_side = "left"
Expand Down Expand Up @@ -317,7 +318,12 @@ def _collate(x):
eval_logger.error(f"Error {e} in generating")
cont = ""
text_outputs = self.tokenizer.batch_decode(cont, skip_special_tokens=True)[0]
text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
if "1.5" in self.pretrained:
text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()
elif "1.6" in self.pretrained:
text_outputs = text_outputs.split("[/INST]")[-1].strip()
else:
text_outputs = text_outputs.split("ASSISTANT:")[-1].strip()

if self.accelerator.is_main_process and doc_id[0] % 100 == 0:
eval_logger.info(f"Generated text for doc ID {doc_id[0]}:\n\n{text_outputs}\n")
Expand Down

0 comments on commit 7847dc4

Please sign in to comment.