Commit
feat: initial L1 compute
b8zhong committed Jan 18, 2025
1 parent 369dc99 commit 27fcb65
Showing 2 changed files with 72 additions and 12 deletions.
28 changes: 22 additions & 6 deletions docs/onnx-conversion.md
@@ -27,7 +27,7 @@ onnxruntime 1.20.1
## Converting from PyTorch models to ONNX models
The following sections describe how to convert the SPLADE++ model to an ONNX model. The steps are as follows:

-### Run the End to End PyTorch to ONNX Conversion
+### Run the End to End PyTorch to ONNX Conversion with Validation
Loading and running the model is handled with argparse in the following script:

```
@@ -39,15 +39,31 @@
All that needs to be provided is the model_name as seen on Hugging Face. In our example:
```
naver/splade-cocondenser-ensembledistil
```

-To run the script and produce the onnx model, run the following sequence of commands:
+To run the script and produce the onnx model with validation, run the following sequence of commands:
```bash
# Begin by going to the appropriate directory
cd src/main/python/onnx
-# Now run the script
-python convert_hf_model_to_onnx.py --model_name naver/splade-cocondenser-ensembledistil
+# Now run the script with validation
+python convert_hf_model_to_onnx.py --model_name naver/splade-cocondenser-ensembledistil --text "what is AI?"
```

-So what actually happens under the hood? The following sections will discuss the key parts of the above script:
+The script will now:
+1. Convert the PyTorch model to ONNX format
+2. Run inference on both models with the test input ("what is AI?" by default)
+3. Compute and report the mean absolute (L1) difference between the PyTorch and ONNX outputs
+4. Validate that the difference is below an acceptable threshold (1e-4)
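For reference, steps 1 and 2 boil down to a `torch.onnx.export` call followed by an ONNX Runtime inference session. A minimal standalone sketch (the output path `model.onnx` and the export axes here are illustrative assumptions, not the script's actual settings):

```python
import torch
import onnxruntime
from transformers import AutoModel, AutoTokenizer

model_name = "naver/splade-cocondenser-ensembledistil"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).eval()

inputs = tokenizer("what is AI?", return_tensors="pt")

# Step 1: trace the model and export it to ONNX.
torch.onnx.export(
    model,
    (inputs["input_ids"], inputs["attention_mask"]),
    "model.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["last_hidden_state"],
    dynamic_axes={"input_ids": {0: "batch", 1: "seq"},
                  "attention_mask": {0: "batch", 1: "seq"}},
)

# Step 2: run the exported graph, feeding only the inputs it declares.
session = onnxruntime.InferenceSession("model.onnx")
ort_inputs = {i.name: inputs[i.name].numpy() for i in session.get_inputs()}
ort_outputs = session.run(None, ort_inputs)
```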

+Example output:
+```
+Some weights of BertModel were not initialized from the model checkpoint at naver/splade-cocondenser-ensembledistil and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
+You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+ONNX model checked successfully
+L1 difference between PyTorch and ONNX outputs: 3.5234e-07
+ONNX conversion validated successfully!
+```

+> Note: For SPLADE models, the validation applies ReLU activation to both the PyTorch and ONNX outputs before computing the L1 difference, since SPLADE uses ReLU in its architecture. This ensures the comparison reflects the activations the model actually produces.
+> If the L1 difference exceeds the threshold, a warning is displayed indicating potential conversion issues.
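Concretely, the ReLU-then-L1 check in steps 3 and 4 amounts to something like the following sketch, where `pt_out` and `onnx_out` are stand-in arrays rather than real model outputs:

```python
import numpy as np

# Stand-in outputs; in the real script these come from the PyTorch and ONNX runs.
pt_out = np.array([[0.2, -0.1, 0.7]])
onnx_out = np.array([[0.2, -0.1, 0.7000003]])

# SPLADE applies ReLU, so clamp negatives before comparing.
pt_act = np.maximum(pt_out, 0)
onnx_act = np.maximum(onnx_out, 0)

# Mean absolute (L1) difference, checked against the 1e-4 threshold.
l1_diff = np.mean(np.abs(pt_act - onnx_act))
assert l1_diff < 1e-4, f"L1 difference {l1_diff} exceeds threshold"
```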

### Getting Output Specification from the Model

56 changes: 50 additions & 6 deletions src/main/python/onnx/convert_hf_model_to_onnx.py
@@ -5,6 +5,8 @@
import onnxruntime
import torch
from transformers import AutoModel, AutoTokenizer
+import numpy as np
+import logging

# device
device = "cuda" if torch.cuda.is_available() else "cpu" # make sure torch is compiled with cuda if you have a cuda device
@@ -23,7 +25,7 @@ def get_model_output_names(model, test_input):
return [f'output_{i}' for i in range(len(outputs))]

def convert_model_to_onnx(text, model, tokenizer, onnx_path, vocab_path, device):
-    print(model) # this prints the model structure for better understanding (optional)
+    logging.info(model) # this prints the model structure for better understanding (optional)
model.eval()

test_input = tokenizer(text, return_tensors="pt")
@@ -55,26 +57,68 @@ def convert_model_to_onnx(text, model, tokenizer, onnx_path, vocab_path, device):
meta.key, meta.value = 'hidden_size', str(hidden_size)

onnx.save(onnx_model, onnx_path) # including the metadata
print(f"Model converted and saved to {onnx_path}")
logging.info(f"Model converted and saved to {onnx_path}")

onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)
print("ONNX model checked successfully")
logging.info("ONNX model checked successfully")

vocab = tokenizer.get_vocab()
with open(vocab_path, 'w', encoding='utf-8') as f:
for token, index in sorted(vocab.items(), key=lambda x: x[1]):
f.write(f"{token}\n")
print(f"Vocabulary saved to {vocab_path}")
logging.info(f"Vocabulary saved to {vocab_path}")

# small inference session for testing
ort_session = onnxruntime.InferenceSession(onnx_path)
ort_inputs = {k: v.cpu().numpy() for k, v in test_input.items()}
ort_outputs = ort_session.run(None, ort_inputs)
print("ONNX model output shape:", ort_outputs[0].shape)
print("ONNX model test run successful")
logging.info(f"ONNX model output shape: {ort_outputs[0].shape}")
logging.info("ONNX model test run successful")

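+    # Validate the export by comparing PyTorch and ONNX outputs on the test input.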
+    l1_diff = validate_onnx_conversion(model, onnx_path, test_input, tokenizer)
+    validation_threshold = 1e-4
+    if l1_diff > validation_threshold:
+        logging.warning(f"Warning: L1 difference ({l1_diff}) exceeds threshold ({validation_threshold})")
+        logging.warning("ONNX conversion may not be accurate!")
+        logging.warning("This could be due to missing ReLU activation in the comparison.")
+    else:
+        logging.info("ONNX conversion validated successfully!")
+
+def validate_onnx_conversion(pytorch_model, onnx_model_path, test_input, tokenizer):
+    """Validates ONNX model outputs against PyTorch model outputs."""
+
+    # Get PyTorch output with exact same processing
+    with torch.no_grad():
+        pytorch_outputs = pytorch_model(**test_input)
+
+    ort_session = onnxruntime.InferenceSession(onnx_model_path)
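+    # Feed only the inputs that the exported ONNX graph actually declares.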
+    onnx_inputs = {name: test_input[name].cpu().numpy() for name in test_input if name in [i.name for i in ort_session.get_inputs()]}
+    onnx_outputs = ort_session.run(None, onnx_inputs)
+
+    # Convert output to numpy for comparison
+    if isinstance(pytorch_outputs, torch.Tensor):
+        pytorch_outputs = pytorch_outputs.cpu().numpy()
+    elif hasattr(pytorch_outputs, 'last_hidden_state'):
+        pytorch_outputs = pytorch_outputs.last_hidden_state.cpu().numpy()
+    elif hasattr(pytorch_outputs, 'logits'):
+        pytorch_outputs = pytorch_outputs.logits.cpu().numpy()

+    # Apply ReLU to both outputs before comparison (for SPLADE models)
+    pytorch_outputs = np.maximum(pytorch_outputs, 0)
+    onnx_outputs = np.maximum(onnx_outputs[0], 0)
+
+    l1_diff = np.mean(np.abs(pytorch_outputs - onnx_outputs))
+    logging.info(f"L1 difference between PyTorch and ONNX outputs: {l1_diff}")
+    return l1_diff
+
if __name__ == "__main__":
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(message)s'
+    )
+
parser = argparse.ArgumentParser(description="Convert Hugging Face model to ONNX")
parser.add_argument("--model_name", type=str, help="Name or path of the Hugging Face model")
parser.add_argument("--text", type=str, default="what is AI?", help="Test input text for the model")
