Add ReLU example and use it in the docs #69

Merged: 1 commit, Feb 25, 2025
5 changes: 5 additions & 0 deletions .github/workflows/build_kernel.yaml
@@ -33,6 +33,11 @@ jobs:
      - name: Copy cutlass GEMM kernel
        run: cp -rL examples/cutlass-gemm/result cutlass-gemm-kernel

      - name: Build relu kernel
        run: ( cd examples/relu && nix build .\#redistributable.torch25-cxx98-cu121-x86_64-linux )
      - name: Copy relu kernel
        run: cp -rL examples/relu/result relu-kernel

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Build Docker image
90 changes: 66 additions & 24 deletions docs/writing-kernels.md
@@ -30,57 +30,55 @@ of CMake or setuptools.

This page describes the directory layout of a kernel-builder project, the
format of the `build.toml` file, and some additional Python glue that
`kernel-builder` provides. We will use a [simple ReLU kernel](../examples/relu)
as the running example.

## Kernel project layout

Kernel projects follow this general directory layout:

```text
relu
├── build.toml
├── relu_kernel
│   └── relu.cu
└── torch-ext
    ├── torch_binding.cpp
    ├── torch_binding.h
    └── relu
        └── __init__.py
```

In this example we can find:

- The build configuration in `build.toml`.
- One or more top-level directories containing kernels (`relu_kernel`).
- The `torch-ext` directory, which contains:
  - `torch_binding.h`: contains declarations for the kernel entry points
    (from `relu_kernel`).
  - `torch_binding.cpp`: registers the entry points as Torch ops.
  - `torch-ext/relu`: contains any Python wrapping the kernel needs. At the
    bare minimum, it should contain an `__init__.py` file.

## `build.toml`

`build.toml` tells `kernel-builder` what to build and how. It looks as
follows for the `relu` kernel:

```toml
[general]
name = "example"
name = "relu"

[torch]
src = [
"torch-ext/torch_bindings.cpp",
"torch-ext/torch_bindings.h"
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h"
]

[kernel.activation]
cuda-capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
src = [
"kernel_a/kernel.cu",
"kernel_a/kernel.h",
"kernel_b/kernel.cu",
"kernel_b/kernel.h",
"relu_kernel/relu.cu",
]
depends = [ "torch" ]
```
@@ -115,6 +113,10 @@ the following options:
- `include` (optional): include directories relative to the project root.
Default: `[]`.

Multiple `kernel.<name>` sections can be defined in the same `build.toml`;
see [`kernels-community/quantization`](https://huggingface.co/kernels-community/quantization/)
for a real-world example with several kernel sections.
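
As an illustration only (the kernel and directory names below are hypothetical
and not taken from the `relu` example), a `build.toml` with two kernel sections
and an `include` directory could look roughly like this:

```toml
[general]
name = "example_ops"

[torch]
src = [
  "torch-ext/torch_binding.cpp",
  "torch-ext/torch_binding.h"
]

# First kernel: its sources live in a top-level activation_kernel/ directory.
[kernel.activation]
cuda-capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.9", "9.0" ]
src = [
  "activation_kernel/activation.cu",
]
depends = [ "torch" ]

# Second kernel: picks up shared headers from a common/ directory via `include`.
[kernel.layer_norm]
cuda-capabilities = [ "7.0", "7.5", "8.0", "8.6", "8.9", "9.0" ]
include = [ "common" ]
src = [
  "layer_norm_kernel/layer_norm.cu",
]
depends = [ "torch" ]
```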

## Torch bindings

### Defining bindings
@@ -124,28 +126,65 @@ Torch bindings are defined in C++; kernels commonly use two files:
- `torch_binding.h` containing function declarations.
- `torch_binding.cpp` registering the functions as Torch ops.

For instance, the `relu` kernel has the following declaration in
`torch_binding.h`:

```cpp
#pragma once

#include <torch/torch.h>

void relu(torch::Tensor &out, torch::Tensor const &input);
```

This declares the `relu` function, which is implemented in `relu_kernel/relu.cu`:

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <cmath>

__global__ void relu_kernel(float *__restrict__ out,
                            float const *__restrict__ input,
                            const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    auto x = input[token_idx * d + idx];
    out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
  }
}

void relu(torch::Tensor &out,
          torch::Tensor const &input)
{
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
                  out.scalar_type() == at::ScalarType::Float,
              "relu_kernel only supports float32");

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / d;
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
                                          input.data_ptr<float>(), d);
}
```

This function is then registered as a Torch op in `torch_binding.cpp`:

```cpp
#include <torch/library.h>

#include "registration.h"
#include "torch_bindings.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("silu_and_mul(Tensor! out, Tensor input) -> ()");
ops.impl("silu_and_mul", torch::kCUDA, &silu_and_mul);
ops.def("relu(Tensor! out, Tensor input) -> ()");
ops.impl("relu", torch::kCUDA, &relu);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
Expand All @@ -172,10 +211,13 @@ named `_ops` that contains an alias for the name. This can be used to
refer to the correct `torch.ops` module. For example:

```python
from typing import Optional
import torch
from ._ops import ops

def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(x)
    ops.relu(out, x)
    return out
```
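
Once the extension has been built and is importable (the example's tests import
it simply as `relu`), the wrapper behaves like a regular PyTorch function. A
minimal usage sketch, assuming a CUDA device is available:

```python
import torch
import torch.nn.functional as F

import relu  # the built extension, as imported in examples/relu/tests

x = torch.randn(16, 1024, dtype=torch.float32, device="cuda")

# Let the wrapper allocate the output tensor...
y = relu.relu(x)

# ...or write into a preallocated tensor instead.
out = torch.empty_like(x)
relu.relu(x, out=out)

# Both paths should match PyTorch's reference ReLU.
torch.testing.assert_close(y, F.relu(x))
torch.testing.assert_close(out, F.relu(x))
```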
15 changes: 15 additions & 0 deletions examples/relu/build.toml
@@ -0,0 +1,15 @@
[general]
name = "relu"

[torch]
src = [
"torch-ext/torch_binding.cpp",
"torch-ext/torch_binding.h"
]

[kernel.activation]
cuda-capabilities = [ "7.0", "7.2", "7.5", "8.0", "8.6", "8.7", "8.9", "9.0" ]
src = [
"relu_kernel/relu.cu",
]
depends = [ "torch" ]
14 changes: 14 additions & 0 deletions examples/relu/flake.nix
@@ -0,0 +1,14 @@
{
description = "Flake for ReLU kernel";

inputs = {
kernel-builder.url = "path:../..";
};

outputs =
{
self,
kernel-builder,
}:
kernel-builder.lib.genFlakeOutputs ./.;
}
32 changes: 32 additions & 0 deletions examples/relu/relu_kernel/relu.cu
@@ -0,0 +1,32 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>

#include <cmath>

__global__ void relu_kernel(float *__restrict__ out,
                            float const *__restrict__ input,
                            const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    auto x = input[token_idx * d + idx];
    out[token_idx * d + idx] = x > 0.0f ? x : 0.0f;
  }
}

void relu(torch::Tensor &out,
          torch::Tensor const &input)
{
  TORCH_CHECK(input.scalar_type() == at::ScalarType::Float &&
                  out.scalar_type() == at::ScalarType::Float,
              "relu_kernel only supports float32");

  int d = input.size(-1);
  int64_t num_tokens = input.numel() / d;
  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  relu_kernel<<<grid, block, 0, stream>>>(out.data_ptr<float>(),
                                          input.data_ptr<float>(), d);
}
Empty file added examples/relu/tests/__init__.py
9 changes: 9 additions & 0 deletions examples/relu/tests/test_relu.py
@@ -0,0 +1,9 @@
import torch
import torch.nn.functional as F

import relu


def test_relu():
    x = torch.randn(1024, 1024, dtype=torch.float32, device="cuda")
    torch.testing.assert_allclose(F.relu(x), relu.relu(x))
12 changes: 12 additions & 0 deletions examples/relu/torch-ext/relu/__init__.py
@@ -0,0 +1,12 @@
from typing import Optional

import torch

from ._ops import ops


def relu(x: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    if out is None:
        out = torch.empty_like(x)
    ops.relu(out, x)
    return out
11 changes: 11 additions & 0 deletions examples/relu/torch-ext/torch_binding.cpp
@@ -0,0 +1,11 @@
#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops.def("relu(Tensor! out, Tensor input) -> ()");
ops.impl("relu", torch::kCUDA, &relu);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
5 changes: 5 additions & 0 deletions examples/relu/torch-ext/torch_binding.h
@@ -0,0 +1,5 @@
#pragma once

#include <torch/torch.h>

void relu(torch::Tensor &out, torch::Tensor const &input);