Skip to content

Commit

Permalink
Add to ONNX reproduction logs (#2565)
Browse files Browse the repository at this point in the history
  • Loading branch information
valamuri2020 authored Aug 17, 2024
1 parent 8dab4d6 commit 46b6834
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 15 deletions.
8 changes: 7 additions & 1 deletion docs/onnx-conversion.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# End to End ONNX Conversion for SPLADE++ Ensemble Distil
This MD file will describe steps to convert particular PyTorch models (i.e., [SPLADE++](https://doi.org/10.1145/3477495.3531857)) to ONNX models and options to further optimize compute graph for Transformer-based models. For more details on how does ONNX Conversion work and how to optimize the compute graph, please refer to [ONNX Tutorials](/~https://github.com/onnx/tutorials#services).

The SPLADE model takes a text input and generates sparse token-level representations as output, where each token is assigned a weight, enabling efficient information retrieval. A more in depth explantation can be found [here](https://www.pinecone.io/learn/splade/).

All scripts are available for reference under in the following directory:
```
src/main/python/onnx
Expand Down Expand Up @@ -331,4 +333,8 @@ cd src/main/python/onnx/models
cp splade-cocondenser-ensembledistil-optimized.onnx splade-cocondenser-ensembledistil-vocab.txt ~/.cache/anserini/encoders/
```

Second, now run the end to end regression as seen in the previously mentioned documentation with the generated ONNX model.
Second, now run the end to end regression as seen in the previously mentioned documentation with the generated ONNX model.


### Reproduction Log
+ Results reproduced by [@valamuri2020](/~https://github.com/valamuri2020) on 2024-08-06 (commit [`6178b40`](/~https://github.com/castorini/anserini/commit/6178b407fc791d62f81e751313771165c6e2c743))
19 changes: 5 additions & 14 deletions src/main/python/onnx/convert_hf_model_to_onnx.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import torch
from transformers import AutoTokenizer, AutoModel
import onnx
import onnxruntime
import argparse
import os

import onnx
import onnxruntime
import torch
from transformers import AutoModel, AutoTokenizer

# device
device = "cuda" if torch.cuda.is_available() else "cpu" # make sure torch is compiled with cuda if you have a cuda device

Expand All @@ -21,14 +22,6 @@ def get_model_output_names(model, test_input):
else:
return [f'output_{i}' for i in range(len(outputs))]

def get_dynamic_axes(input_names, output_names):
dynamic_axes = {}
for name in input_names:
dynamic_axes[name] = {0: 'batch_size', 1: 'sequence'}
for name in output_names:
dynamic_axes[name] = {0: 'batch_size', 1: 'sequence'}
return dynamic_axes

def convert_model_to_onnx(text, model, tokenizer, onnx_path, vocab_path, device):
print(model) # this prints the model structure for better understanding (optional)
model.eval()
Expand All @@ -38,7 +31,6 @@ def convert_model_to_onnx(text, model, tokenizer, onnx_path, vocab_path, device)
test_input = {k: v.to(device) for k, v in test_input.items()}

output_names = get_model_output_names(model, test_input)
dynamic_axes = get_dynamic_axes(input_names, output_names)

model_type = model.config.model_type
num_heads = model.config.num_attention_heads
Expand All @@ -50,7 +42,6 @@ def convert_model_to_onnx(text, model, tokenizer, onnx_path, vocab_path, device)
onnx_path,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
do_constant_folding=True,
opset_version=14
)
Expand Down

0 comments on commit 46b6834

Please sign in to comment.