TessellateIPU Library

Features | Installation guide | Quickstart | Documentation | Projects

🔴 ⚠️ Non-official Graphcore Product ⚠️ 🔴

TessellateIPU is a library bringing low-level Poplar IPU programming to Python ML frameworks (JAX at the moment, and PyTorch in the near future).

The package is maintained by the Graphcore Research team. Expect bugs and sharp edges! Please let us know what you think!

Features

TessellateIPU brings low-level Poplar IPU programming to Python, while being fully compatible with ML framework standard APIs. The main features are:

  • Control tile mapping of arrays using tile_put_replicated or tile_put_sharded
  • Support of standard JAX LAX operations at tile level using tile_map (see operations supported)
  • Easy integration of custom IPU C++ vertices (see vertex example)
  • Access to low-level IPU hardware functionalities such as cycle count and random seed set/get
  • Full compatibility with other backends

The TessellateIPU API allows easy and efficient implementation of algorithms on IPUs, while keeping compatibility with other backends (CPU, GPU, TPU). For more details on the API, please refer to the TessellateIPU documentation, or try it on IPU Paperspace Gradient.

Installation guide

This package requires JAX IPU experimental (available for Python 3.8 and Poplar SDK versions 3.1 or 3.2). For Poplar SDK 3.2:

pip install jax==0.3.16+ipu jaxlib==0.3.15+ipu.sdk320 -f https://graphcore-research.github.io/jax-experimental/wheels.html

Please change sdk320 into sdk310 if using Poplar SDK 3.1.

As a pure Python repo, TessellateIPU can then be directly installed from GitHub using pip:

pip install git+https://github.com/graphcore-research/tessellate-ipu.git@main

Note: main can be replaced with any tag (v0.1, ...) or commit hash in order to install a specific version.

Local pip install is also supported after cloning the GitHub repository:

git clone git@github.com:graphcore-research/tessellate-ipu.git
pip install ./tessellate-ipu

Minimal example

The following is a simple example showing how to set the tile mapping of JAX arrays, and run a JAX LAX operation on these tiles.

import numpy as np
import jax
from tessellate_ipu import tile_put_sharded, tile_map

# Which IPU tiles do we want to use?
tiles = (0, 1, 3)

@jax.jit
def compute_fn(data0, data1):
    # Tile sharding arrays along the first axis.
    input0 = tile_put_sharded(data0, tiles)
    input1 = tile_put_sharded(data1, tiles)
    # Map a JAX LAX primitive on tiles.
    output = tile_map(jax.lax.add_p, input0, input1)
    return output

data = np.random.rand(len(tiles), 2, 3).astype(np.float32)
output = compute_fn(data, 3 * data)

print("Output:", output)
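
To make the expected result concrete without IPU hardware, the sketch below emulates the example in plain NumPy. It assumes, based on the comments in the example, that tile_put_sharded places one slice of the first axis on each tile and that tile_map applies the primitive independently on every tile. The dictionaries keyed by tile index are purely illustrative and are not part of the TessellateIPU API.

```python
import numpy as np

# Same tile selection as in the example above.
tiles = (0, 1, 3)

rng = np.random.default_rng(0)
data = rng.random((len(tiles), 2, 3)).astype(np.float32)

# Assumed sharding semantics: slice i of the first axis goes to tiles[i].
shards0 = {tile: data[i] for i, tile in enumerate(tiles)}
shards1 = {tile: 3 * data[i] for i, tile in enumerate(tiles)}

# Assumed tile_map semantics: the add is applied independently on each tile.
per_tile = {tile: shards0[tile] + shards1[tile] for tile in tiles}

# Stack the per-tile results back along the first axis.
expected = np.stack([per_tile[tile] for tile in tiles])

print("Expected shape:", expected.shape)
```

Under these assumptions the result has the same shape as the inputs and equals data + 3 * data, which gives a simple reference to compare the IPU output against.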

Useful environment variables and flags

JAX IPU experimental flags, using from jax.config import config:

| Flag | Description |
| --- | --- |
| config.FLAGS.jax_platform_name = 'ipu'/'cpu' | Configure default JAX backend. Useful for CPU initialization. |
| config.FLAGS.jax_ipu_use_model = True | Use IPU model emulator. |
| config.FLAGS.jax_ipu_model_num_tiles = 8 | Set the number of tiles in the IPU model. |
| config.FLAGS.jax_ipu_device_count = 2 | Set the number of IPUs visible in JAX. Can be any local IPU available. |
| config.FLAGS.jax_ipu_visible_devices = '0,1' | Set the specific collection of local IPUs to be visible in JAX. |

Alternatively, like other JAX flags, these can be set using environment variables (for example JAX_IPU_USE_MODEL and JAX_IPU_MODEL_NUM_TILES).
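
As a minimal sketch of the environment-variable route, the snippet below sets the two variables named above from Python. It assumes that the flags are read from the environment when JAX is first imported, so the variables are set before any `import jax`, and that "true" is an accepted spelling for a boolean flag. The `ipu_model_env` dictionary is just an illustrative name.

```python
import os

# Assumption: these must be in place before the first `import jax`,
# because flag defaults are typically read from the environment at import.
ipu_model_env = {
    "JAX_IPU_USE_MODEL": "true",     # counterpart of jax_ipu_use_model
    "JAX_IPU_MODEL_NUM_TILES": "8",  # counterpart of jax_ipu_model_num_tiles
}
os.environ.update(ipu_model_env)

# import jax  # import only after the variables above are set
```

Exporting the same variables in the shell before launching Python should have the same effect and avoids any ordering concern inside the script.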

PopVision environment variables:

  • Generate a PopVision Graph Analyser profile: POPLAR_ENGINE_OPTIONS='{"autoReport.all":"true", "debug.allowOutOfMemory":"true"}'
  • Generate a PopVision System Analyser profile: PVTI_OPTIONS='{"enable":"true", "directory":"./reports"}'
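
Both option strings above are JSON objects, and shell quoting makes them easy to get wrong. The sketch below builds them with json.dumps and sets them from Python instead. It assumes the variables are read when the Poplar engine is created, so they should be set before the program is compiled and run. The names `pvti_options` and `engine_options` are illustrative only.

```python
import json
import os

# Option values copied from the list above.
pvti_options = {"enable": "true", "directory": "./reports"}
engine_options = {"autoReport.all": "true", "debug.allowOutOfMemory": "true"}

# json.dumps guarantees well-formed JSON (double quotes, no trailing commas).
os.environ["PVTI_OPTIONS"] = json.dumps(pvti_options)
os.environ["POPLAR_ENGINE_OPTIONS"] = json.dumps(engine_options)

print(os.environ["PVTI_OPTIONS"])
print(os.environ["POPLAR_ENGINE_OPTIONS"])
```

Keeping the options as dictionaries also makes it straightforward to change the report directory or toggle a single option per run.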

Documentation

Projects using TessellateIPU

  • PySCF IPU: Molecular quantum chemistry simulation on Graphcore IPUs.

License

Copyright (c) 2023 Graphcore Ltd. The project is licensed under the Apache License 2.0.

TessellateIPU is implemented using C++ custom operations. These have the following C++ libraries as dependencies, statically compiled into a shared library:

| Component | Description | License |
| --- | --- | --- |
| fastbase64 | Base64 fast decoder library | Simplified BSD (FreeBSD) License |
| fmt | A modern C++ formatting library | MIT license |
| half | IEEE-754 conformant half-precision library | MIT license |
| json | JSON for modern C++ | MIT license |
| nanobind | Tiny C++/Python bindings | BSD 3-Clause License |