Initial bits of documentation on writing a kernel

danieldk committed Feb 24, 2025 · commit d3647dd · `docs/writing-kernels.md` (new file, 173 additions)
# Writing Hub kernels with kernel-builder

## Introduction

Hub kernels differ from traditional Python kernel packages in that they
can be loaded easily outside `PYTHONPATH` and that multiple versions of the
same kernel can be loaded in the same Python process. For more information
about these requirements, see the
[hf-kernels](/~https://github.com/huggingface/hf-kernels) documentation.

`kernel-builder` provides a build system that produces kernels that are
compatible with the Kernel Hub. It takes care of:

- Building kernels for all supported PyTorch configurations (the pre-C++11
  and C++11 libstdc++ ABIs and different CUDA versions).
- Compatibility with old glibc and libstdc++ versions, so that kernels also
work on older Linux distributions.
- Registering Torch ops, such that multiple versions of the same kernel can
  be loaded without namespace conflicts.

`kernel-builder` builds are configured through a `build.toml` file.
`build.toml` is a simple format that does not require intricate knowledge
of CMake or setuptools.

This page describes the directory layout of a kernel-builder project, the
format of the `build.toml` file, and some additional Python glue that
`kernel-builder` provides.

## Kernel project layout

Kernel projects generally follow this layout:

```text
example
├── build.toml
├── kernel_a
├── kernel_b
└── torch-ext
    ├── torch_bindings.cpp
    ├── torch_bindings.h
    └── example
        └── __init__.py
```

- The `build.toml` file is the build configuration.
- One or more top-level directories containing kernels (`kernel_a` and `kernel_b` here).
- The `torch-ext` directory contains:
- `torch_bindings.h`: contains declarations for kernel entry points
(from `kernel_a` and `kernel_b`).
- `torch_bindings.cpp`: registers the entry points as Torch ops.
- `torch-ext/example`: contains any Python wrapping the kernel needs. At the
  bare minimum, it should contain an `__init__.py` file.
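The layout above can be scaffolded with a few lines of standard-library Python. This is an illustrative sketch, not part of `kernel-builder`; the project name `example` comes from the tree above:

```python
# Sketch: scaffold the kernel project layout described above.
from pathlib import Path

root = Path("example")

# Create the kernel source directories and the Torch extension directory.
for directory in ("kernel_a", "kernel_b", "torch-ext/example"):
    (root / directory).mkdir(parents=True, exist_ok=True)

# Create the (empty) files that kernel-builder expects to find.
for file in (
    "build.toml",
    "torch-ext/torch_bindings.cpp",
    "torch-ext/torch_bindings.h",
    "torch-ext/example/__init__.py",
):
    (root / file).touch()
```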

## `build.toml`

`build.toml` tells `kernel-builder` what to build and how. Here is an example
for an extension named `example` containing two kernels:

```toml
[general]
name = "example"

[torch]
src = [
  "torch-ext/torch_bindings.cpp",
  "torch-ext/torch_bindings.h"
]

[kernel.activation]
cuda-capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
src = [
  "kernel_a/kernel.cu",
  "kernel_a/kernel.h",
  "kernel_b/kernel.cu",
  "kernel_b/kernel.h",
]
depends = [ "torch" ]
```

### `general`

`name` defines the name of the kernel. The Python code for a Torch extension
must be stored in `torch-ext/<name>`.

### `torch`

This section describes the Torch extension. In the future, there may be
similar sections for other frameworks. This section can contain the
following options:

- `src` (required): a list of source files and headers.
- `pyext` (optional): the list of extensions for Python files. Default:
`["py", "pyi"]`.
- `include` (optional): include directories relative to the project root.
Default: `[]`.

### `kernel.<name>`

Specification of a kernel with the name `<name>`. This section can contain
the following options:

- `cuda-capabilities` (required): a list of CUDA capabilities that the
kernel should be compiled for.
- `depends` (required): a list of dependencies. The supported dependencies
  are listed in [`deps.nix`](../lib/deps.nix).
- `src` (required): a list of source files and headers.
- `include` (optional): include directories relative to the project root.
Default: `[]`.

## Torch bindings

### Defining bindings

Torch bindings are defined in C++. Kernels commonly use two files:

- `torch_bindings.h` containing function declarations.
- `torch_bindings.cpp` registering the functions as Torch ops.

For instance, a simple kernel could have the following declaration in
`torch_bindings.h`:

```cpp
#pragma once

#include <torch/torch.h>

void silu_and_mul(torch::Tensor &out, torch::Tensor const &input);
```
This function can then be registered as a Torch op in `torch_bindings.cpp`:

```cpp
#include <torch/library.h>

#include "registration.h"
#include "torch_bindings.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
  ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
```

This snippet uses macros from `registration.h` to register the function.
`registration.h` is generated by `kernel-builder` itself. A function
is registered through the `def` and `impl` methods of `ops`. `def` specifies
the function signature following the [function schema](/~https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/README.md#func).
`impl` associates the function name with the C/C++ function and
the applicable device.
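This document does not spell out what `silu_and_mul` computes. In similar kernel collections (e.g. vLLM-style activation kernels) it applies SiLU to the first half of the last dimension and multiplies elementwise by the second half; a pure-Python sketch of that assumed semantics, useful as a reference when testing the CUDA op:

```python
# Reference sketch of the assumed silu_and_mul semantics (not from this
# document): out[i] = silu(x[i]) * x[i + d], where d = len(x) // 2.
import math

def silu(x: float) -> float:
    # SiLU (a.k.a. swish): x * sigmoid(x).
    return x / (1.0 + math.exp(-x))

def silu_and_mul_ref(x: list[float]) -> list[float]:
    d = len(x) // 2
    return [silu(x[i]) * x[i + d] for i in range(d)]

out = silu_and_mul_ref([0.0, 1.0, 3.0, 2.0])
```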

## Using the functions from Python

The bindings are typically wrapped in Python code in `torch-ext/<name>`.
The native code is exposed under the `torch.ops` namespace. However,
a unique suffix is added to the extension's name to ensure that
different versions of the same extension can be loaded at the same time.
As a result, the extension is registered as
`torch.ops.<name>_<unique_suffix>`.

To deal with this uniqueness, `kernel-builder` generates a Python module
named `_ops` that contains an alias for the name. This can be used to
refer to the correct `torch.ops` module. For example:

```python
import torch

from ._ops import ops

def silu_and_mul(out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    ops.silu_and_mul(out, x)
    return out
```
