From f9cddd2cf5730bb330dc417ba461a684ea678444 Mon Sep 17 00:00:00 2001
From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com>
Date: Fri, 22 Mar 2024 14:44:34 -0700
Subject: [PATCH] Remove early stopping from LLaMA end-to-end benchmarking (#20033)

### Description
This PR removes early stopping from the end-to-end LLaMA-2 benchmark script.

### Motivation and Context
This allows models to always generate the requested number of new tokens.

---
 .../python/tools/transformers/models/llama/benchmark_e2e.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
index 4d0d2e68e8983..47b7f35cbdd7c 100644
--- a/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
+++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_e2e.py
@@ -400,11 +400,7 @@ def main():
             sampling_times.append(sampling_end_time - sampling_start_time)
 
             all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
-
-            # Return early if all batch entries have reached EOS token id
             current_length += 1
-            if torch.all(has_eos) or current_length > max_length:
-                break
 
             # Update inputs for next inference run
             inputs["input_ids"] = tokens_to_add
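
For context, below is a minimal sketch of what the benchmark's token-generation loop looks like after this change. It is simplified, not the actual file contents: the function name `generate_fixed_length` is hypothetical, and the assumption that `model(**inputs)` returns next-token logits stands in for the real ONNX Runtime inference call. Variable names such as `inputs`, `tokens_to_add`, `all_token_ids`, and `current_length` follow the diff above.

```python
import torch


def generate_fixed_length(model, inputs, all_token_ids, max_length):
    """Sketch of a decoding loop that always runs the requested number of steps.

    With the early-stopping branch removed, the loop no longer breaks when every
    batch entry has produced an EOS token, so each benchmark run measures a fixed
    number of decoding iterations and timings stay comparable across runs.
    """
    current_length = all_token_ids.shape[-1]
    while current_length < max_length:
        # Assumed: the model call returns logits of shape [batch, seq, vocab].
        logits = model(**inputs)
        tokens_to_add = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)

        all_token_ids = torch.cat([all_token_ids, tokens_to_add], dim=-1)
        current_length += 1

        # No early exit on EOS here; update inputs for the next inference run.
        inputs["input_ids"] = tokens_to_add
    return all_token_ids
```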