Intel® Extension for OpenXLA* includes PJRT plugin implementation, which seamlessly runs JAX models on Intel GPU. The PJRT API simplified the integration, which allowed the Intel GPU plugin to be developed separately and quickly integrated into JAX. Refer to OpenXLA PJRT Plugin RFC for more details.
Please check README#requirements for the requirements of hardware and software.
The following table tracks intel-extension-for-openxla versions and compatible versions of jax
and jaxlib
. The compatibility between jax
and jaxlib
is maintained through JAX. This version restriction will be relaxed over time as the plugin API matures.
intel-extension-for-openxla | jaxlib | jax |
---|---|---|
0.5.0 | 0.4.30 | >= 0.4.30, <= 0.4.31 |
0.4.0 | 0.4.26 | >= 0.4.26, <= 0.4.27 |
0.3.0 | 0.4.24 | >= 0.4.24, <= 0.4.27 |
0.2.1 | 0.4.20 | >= 0.4.20, <= 0.4.26 |
0.2.0 | 0.4.20 | >= 0.4.20, <= 0.4.26 |
0.1.0 | 0.4.13 | >= 0.4.13, <= 0.4.14 |
conda is recommanded as the virtual running environment.
conda create -n jax-ioex python=3.10
conda activate jax-ioex
pip install -U pip
pip install intel-extension-for-openxla
# Install jax, jaxlib and flax dependency
pip install -r https://raw.githubusercontent.com/intel/intel-extension-for-openxla/main/test/requirements.txt
Please refer to requirements.txt for the version dependency of jax
, jaxlib
and flax
.
python -c "import jax; print(jax.devices())"
Reference result:
[xpu(id=0), xpu(id=1)]
curl -L -O "/~https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash Miniforge3-$(uname)-$(uname -m).sh
Please follow Install to prepare the basic environment first.
pip install transformers==4.47 diffusers==0.31.0 datasets==4.9.7 msgpack==1.1.0
Source OneAPI env
source /opt/intel/oneapi/compiler/2025.0/env/vars.sh
source /opt/intel/oneapi/mkl/2025.0/env/vars.sh
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/intel/oneapi/umf/latest/lib
NOTE: The path of OneAPI env script is based on the OneAPI installed path.
Go to example/stable_diffusion for detail about this demo.
Command | Model | Output Image Resolution |
---|---|---|
python jax_stable.py |
CompVis/stable-diffusion-v1-4 | 512x512 |
python jax_stable.py -m stabilityai/stable-diffusion-2 |
stabilityai/stable-diffusion-2 | 768x768 |
python jax_stable.py -m stabilityai/stable-diffusion-2-1 |
stabilityai/stable-diffusion-2-1 | 768x768 |
Average Latency per image is: x.xxx s
Average Throughput per second is: x.xxx steps
To submit questions, feature requests, and bug reports about the intel-extension-for-openxla plugin, visit the GitHub intel-extension-for-openxla issues page. You can also view GitHub JAX Issues with the label "Intel GPU plugin".
-
If there is an error 'No visible XPU devices', print
jax.local_devices()
to check which device is running. Setexport OCL_ICD_ENABLE_TRACE=1
to check if there are driver error messages. The following code opens more debug log for JAX app.import logging logging.basicConfig(level = logging.DEBUG)
-
If there is an error 'version GLIBCXX_3.4.30' not found, upgrade libstdc++ to the latest, for example for conda
conda install libstdcxx-ng==12.2.0 -c conda-forge