fixed cryptic contiguity problems
IlyasMoutawwakil committed Jan 16, 2025
1 parent 821c997 commit cc2e124
Showing 4 changed files with 45 additions and 81 deletions.
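
The "cryptic" failures this commit addresses usually stem from IO binding consuming only a tensor's data_ptr(), shape and dtype: if the tensor is a non-contiguous view, the raw pointer does not describe a dense buffer of that shape, and ONNX Runtime can silently read scrambled values. A minimal illustration of the effect in plain PyTorch, independent of optimum and shown here only for context:

import torch

# A transposed (or sliced) tensor is a view: same storage, different strides.
base = torch.arange(12, dtype=torch.float32).reshape(3, 4)
view = base.t()                                   # shape (4, 3), not contiguous

print(view.is_contiguous())                       # False
print(view.shape, view.stride())                  # torch.Size([4, 3]) (1, 4)

# IO binding only receives (data_ptr, shape, dtype); strides are invisible to it,
# so interpreting view.data_ptr() as a dense (4, 3) buffer yields wrong values.
dense = view.contiguous()                         # copies into a fresh row-major buffer
print(dense.is_contiguous())                      # True
print(dense.data_ptr() != view.data_ptr())        # True: a new allocation backs it

The diff below moves exactly this normalization (contiguity, plus a dtype cast) into _prepare_io_binding, so every bound input is guaranteed to be a dense buffer of the declared shape and element type.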
6 changes: 2 additions & 4 deletions optimum/onnxruntime/modeling_decoder.py
@@ -209,7 +209,6 @@ def forward(
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
labels: Optional[torch.LongTensor] = None,
use_cache_branch: bool = None,
**kwargs,
) -> CausalLMOutputWithPast:
@@ -233,18 +232,17 @@ def forward(
)

# Create position_ids on the fly for batch generation
if attention_mask is not None and position_ids is None and "position_ids" in self.input_names:
if "position_ids" in self.input_names and position_ids is None and attention_mask is not None:
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

model_inputs = {
"input_ids": input_ids.contiguous() if use_torch else input_ids,
"input_ids": input_ids,
"position_ids": position_ids,
"attention_mask": attention_mask,
"use_cache_branch": use_cache_branch,
"labels": labels,
}

if past_key_values is not None:
20 changes: 10 additions & 10 deletions optimum/onnxruntime/modeling_ort.py
@@ -814,21 +814,21 @@ def _prepare_io_binding(

input_name_to_shape = {}
for input_name in self.input_names.keys():
tensor = model_inputs[input_name].contiguous()
input_name_to_shape[input_name] = tensor.shape
model_inputs[input_name] = model_inputs[input_name].contiguous()
input_name_to_shape[input_name] = model_inputs[input_name].shape

data_ptr = tensor.data_ptr()
if "past" in input_name and data_ptr == 0:
# During first generation, sequence_length can be 0 when use_cache=True, which results in data_ptr to also be 0.
# To keep compatibility with IO binding, we pass the data pointer of input_ids instead. This will have no impact because past_key_values will not be used during the first generation.
data_ptr = next(iter(model_inputs.values())).data_ptr()
expected_dtype = TypeHelper.ort_type_to_torch_type(self.input_dtypes[input_name])
if model_inputs[input_name].dtype != expected_dtype:
model_inputs[input_name] = model_inputs[input_name].to(expected_dtype)

data_ptr = model_inputs[input_name].data_ptr()

io_binding.bind_input(
input_name,
tensor.device.type,
self.device.type,
IOBindingHelper.get_device_index(self.device),
name_to_np_type[input_name],
tuple(tensor.shape),
TypeHelper.ort_type_to_numpy_type(self.input_dtypes[input_name]),
model_inputs[input_name].shape,
data_ptr,
)
dimensions = {}
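
The rewritten loop above makes each input contiguous and casts it to the dtype declared by the ONNX graph before taking its pointer. A condensed, hedged sketch of that pattern follows; the helper name and the TORCH_TO_NUMPY table are illustrative stand-ins (optimum uses its own TypeHelper), and io_binding would be the object returned by an onnxruntime InferenceSession's io_binding() method:

import numpy as np
import torch

# Hypothetical stand-in for optimum's TypeHelper mapping (illustration only).
TORCH_TO_NUMPY = {torch.float32: np.float32, torch.int64: np.int64, torch.bool: np.bool_}

def bind_torch_inputs(io_binding, model_inputs, expected_dtypes, device_type="cpu", device_id=0):
    """Hedged sketch of the pattern the revised loop follows; not optimum's actual helper."""
    for name, expected in expected_dtypes.items():
        tensor = model_inputs[name].contiguous()   # IO binding reads raw memory, so it must be dense
        if tensor.dtype != expected:               # e.g. cast a float attention_mask to int64
            tensor = tensor.to(expected)
        model_inputs[name] = tensor                # keep a reference so the buffer outlives the call
        io_binding.bind_input(
            name,                                  # graph input name
            device_type,                           # "cpu" or "cuda"
            device_id,
            TORCH_TO_NUMPY[tensor.dtype],          # element type as a numpy dtype
            tuple(tensor.shape),
            tensor.data_ptr(),                     # pointer to the contiguous buffer
        )

The element type handed to bind_input must describe the buffer the pointer refers to, which is why the cast happens before data_ptr() is taken.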
16 changes: 4 additions & 12 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -23,7 +23,6 @@
from tempfile import TemporaryDirectory
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
@@ -450,7 +449,7 @@ def forward(
"attention_mask": attention_mask,
}

if self.parent_model.device.type == "cuda" and self.parent_model.use_io_binding:
if self.parent_model.use_io_binding:
io_binding, output_shapes, output_buffers = self._prepare_io_binding(self.session, model_inputs)

io_binding.synchronize_inputs()
@@ -1513,26 +1512,18 @@ def forward(
labels: Optional[torch.LongTensor] = None,
**kwargs,
) -> Seq2SeqLMOutput:
# Encode if needed : first prediction pass
# Encode if needed (training, first prediction pass)
if encoder_outputs is None:
encoder_outputs = self.encoder(
flattened_patches=flattened_patches,
attention_mask=attention_mask,
)

# TODO: for some reason the attention_mask for pix2struct is a float in transformers and not an int64. This messes up with the exporter
# hardcodes int64 input dtype for the attention mask. This workaround is quite ugly, it should be fixed rather in the ONNX exporter.
if isinstance(attention_mask, torch.Tensor):
attention_mask = attention_mask.to(torch.int64)
else:
attention_mask = attention_mask.astype(np.int64)

model = (
self.decoder
if past_key_values is None or not self.use_cache or self.use_merged
if self.use_merged or not self.use_cache or past_key_values is None
else self.decoder_with_past
)

decoder_outputs = model(
input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
@@ -1546,6 +1537,7 @@ def forward(
loss=decoder_outputs.get("loss", None),
logits=decoder_outputs.logits,
past_key_values=decoder_outputs.past_key_values,
encoder_last_hidden_state=encoder_outputs.last_hidden_state,
)

def prepare_inputs_for_generation(
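
The reordered condition in the forward hunk above selects which ONNX decoder session to run: the plain decoder serves the merged case, the no-cache case, and the first step (when no past exists yet), while only later steps with a separate past-decoder go to decoder_with_past. A small, hedged restatement of that dispatch (pick_decoder is an illustrative name, not repository code):

def pick_decoder(use_merged: bool, use_cache: bool, has_past: bool) -> str:
    # A merged decoder handles both branches itself; without a cache, or on the
    # first step when no past key/values exist, the plain decoder is used.
    if use_merged or not use_cache or not has_past:
        return "decoder"
    return "decoder_with_past"

assert pick_decoder(use_merged=True,  use_cache=True, has_past=True)  == "decoder"
assert pick_decoder(use_merged=False, use_cache=True, has_past=False) == "decoder"
assert pick_decoder(use_merged=False, use_cache=True, has_past=True)  == "decoder_with_past"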
84 changes: 29 additions & 55 deletions tests/onnxruntime/test_modeling.py
@@ -5489,9 +5489,6 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")

if use_cache is False:
self.skipTest("skip")

model_args = {
"test_name": test_name,
"model_arch": model_arch,
@@ -5507,56 +5504,37 @@
if use_merged is False:
model_path = Path(self.onnx_model_dirs[test_name], ONNX_DECODER_NAME)
self.assertFalse(has_onnx_input(model_path, "use_cache_branch"))
self.assertEqual(onnx_model.use_merged, False)
self.assertFalse(onnx_model.use_merged)
else:
model_path = Path(self.onnx_model_dirs[test_name], ONNX_DECODER_MERGED_NAME)
self.assertTrue(has_onnx_input(model_path, "use_cache_branch"))
self.assertEqual(onnx_model.use_merged, True)
self.assertTrue(onnx_model.use_merged)

self.assertIsInstance(onnx_model.decoder, ORTDecoderForSeq2Seq)
if onnx_model.use_cache is True and onnx_model.use_merged is False:
if use_cache is True and use_merged is False:
self.assertIsInstance(onnx_model.decoder_with_past, ORTDecoderForSeq2Seq)
if onnx_model.use_cache is True and onnx_model.use_merged is True:
if use_cache is True and use_merged is True:
self.assertTrue(onnx_model.decoder_with_past is None)

self.assertIsInstance(onnx_model.config, PretrainedConfig)

set_seed(SEED)
transformers_model = Pix2StructForConditionalGeneration.from_pretrained(model_id)

preprocessor = get_preprocessor(model_id)
questions = [
"Who am I?",
"What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud and this is long long very long and super long my dear",
]

transformers_model = Pix2StructForConditionalGeneration.from_pretrained(model_id)
preprocessor = get_preprocessor(model_id)

inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=questions, padding=True, return_tensors="pt")
del inputs["decoder_attention_mask"]
del inputs["decoder_input_ids"]

decoder_start_token_id = transformers_model.config.decoder_start_token_id
decoder_inputs = {
"decoder_input_ids": torch.ones((2, 1), dtype=torch.long) * decoder_start_token_id,
"decoder_attention_mask": torch.ones((2, 1), dtype=torch.int64),
}

with torch.no_grad():
transformers_outputs = transformers_model(**inputs, **decoder_inputs)
transformers_outputs = transformers_model(**inputs)

for input_type in ["pt", "np"]:
inputs = preprocessor(
images=[self.IMAGE, self.IMAGE], text=questions, padding=True, return_tensors=input_type
)
del inputs["decoder_attention_mask"]
del inputs["decoder_input_ids"]

if input_type == "np":
decoder_inputs = {
"decoder_input_ids": np.ones((2, 1), dtype=np.int64) * decoder_start_token_id,
"decoder_attention_mask": np.ones((2, 1), dtype=np.int64),
}

onnx_outputs = onnx_model(**inputs, **decoder_inputs)
onnx_outputs = onnx_model(**inputs)

self.assertTrue("logits" in onnx_outputs)
self.assertIsInstance(onnx_outputs.logits, self.TENSOR_ALIAS_TO_TYPE[input_type])
@@ -5568,42 +5546,39 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@pytest.mark.cuda_ep_test # mark as GPU test as well to run the without/with cache timing test on the slow tests
def test_compare_with_and_without_past_key_values(self, model_arch: str):
if model_arch == "m2m_100":
return # TODO: this test is failing for m2m_100

model_args = {"test_name": model_arch + "_False", "model_arch": model_arch, "use_cache": False}
self._setup(model_args)
model_args = {"test_name": model_arch + "_True", "model_arch": model_arch, "use_cache": True}
self._setup(model_args)

model_with_pkv = ORTModelForPix2Struct.from_pretrained(
self.onnx_model_dirs[model_arch + "_True"], use_cache=True
)
model_without_pkv = ORTModelForPix2Struct.from_pretrained(
self.onnx_model_dirs[model_arch + "_False"], use_cache=False
)

model_id = MODEL_NAMES[model_arch]
preprocessor = get_preprocessor(model_id)

question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
inputs = preprocessor(images=self.IMAGE, text=question, return_tensors="pt")
del inputs["decoder_attention_mask"]
del inputs["decoder_input_ids"]

model_with_pkv = ORTModelForPix2Struct.from_pretrained(
self.onnx_model_dirs[model_arch + "_True"], use_cache=True
)

outputs_model_with_pkv = model_with_pkv.generate(
**inputs, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1
)

model_without_pkv = ORTModelForPix2Struct.from_pretrained(
self.onnx_model_dirs[model_arch + "_False"], use_cache=False
)
outputs_model_without_pkv = model_without_pkv.generate(
**inputs, min_new_tokens=self.GENERATION_LENGTH, max_new_tokens=self.GENERATION_LENGTH, num_beams=1
)

torch.testing.assert_close(outputs_model_with_pkv, outputs_model_without_pkv, rtol=self.RTOL, atol=self.ATOL)
self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH + 1)
self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH + 1)

torch.testing.assert_close(outputs_model_with_pkv, outputs_model_without_pkv, rtol=self.RTOL, atol=self.ATOL)

@parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, model_arch: str, use_cache: bool):
model_args = {
@@ -5626,8 +5601,6 @@ def test_compare_merged_and_not_merged_models_outputs(self, test_name: str, mode

question = "What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash cloud"
inputs = preprocessor(images=self.IMAGE, text=question, return_tensors="pt")
del inputs["decoder_attention_mask"]
del inputs["decoder_input_ids"]

model_not_merged_dir = self.onnx_model_dirs[test_name + "_False"]
model_merged_dir = self.onnx_model_dirs[test_name + "_True"]
@@ -5676,24 +5649,29 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
self.assertTrue(io_model.use_io_binding)

preprocessor = get_preprocessor(model_id)
decoder_start_token_id = onnx_model.config.decoder_start_token_id
question = ["What does the label 15 represent? (1) lava (2) core (3) tunnel (4) ash", "Who are you?"]
inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt").to(
"cuda"
)
inputs["decoder_input_ids"] = torch.full((2, 1), decoder_start_token_id, dtype=torch.int64).to("cuda")
inputs["decoder_attention_mask"] = torch.ones((2, 1), dtype=torch.int64).to("cuda")

onnx_outputs = onnx_model(**inputs)
io_outputs = io_model(**inputs)

self.assertTrue("logits" in io_outputs)
self.assertTrue("encoder_last_hidden_state" in io_outputs)

self.assertIsInstance(io_outputs.logits, torch.Tensor)
self.assertIsInstance(io_outputs.encoder_last_hidden_state, torch.Tensor)

torch.testing.assert_close(
onnx_outputs.logits, io_outputs.logits, atol=self.ATOL, rtol=self.RTOL, equal_nan=True
onnx_outputs.encoder_last_hidden_state,
io_outputs.encoder_last_hidden_state,
atol=self.ATOL,
rtol=self.RTOL,
)

torch.testing.assert_close(onnx_outputs.logits, io_outputs.logits, atol=self.ATOL, rtol=self.RTOL)

gc.collect()

@parameterized.expand(
@@ -5711,9 +5689,6 @@ def test_compare_to_io_binding(self, test_name: str, model_arch: str, use_cache:
def test_compare_generation_to_io_binding(
self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool, num_beams: int
):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are uncompatible")

model_args = {
"test_name": test_name,
"model_arch": model_arch,
@@ -5738,13 +5713,12 @@ def test_compare_generation_to_io_binding(
inputs = preprocessor(images=[self.IMAGE, self.IMAGE], text=question, padding=True, return_tensors="pt").to(
"cuda"
)
del inputs["decoder_attention_mask"]
del inputs["decoder_input_ids"]

onnx_outputs = onnx_model.generate(**inputs, num_beams=num_beams)
io_outputs = io_model.generate(**inputs, num_beams=num_beams)

# compare tensor outputs
print("diff", onnx_outputs - io_outputs)
torch.testing.assert_close(onnx_outputs, io_outputs, atol=self.ATOL, rtol=self.RTOL)

gc.collect()
