`pip install mriaug` to use a 3D image library that is ~50x faster and simpler than `torchio` by
- only using PyTorch → full GPU (+autograd) support 🔥
- being tiny: ~200 lines of code → little room for bugs 🐛

while offering ~20 different augmentations (incl. MRI-specific operations) 🩻
👶 Normal users should use `mriaug` via `niftiai`, a deep learning framework for 3D images, since it
- provides `aug_transforms3d`: a convenient function that compiles all `mriaug` augmentations!
- simplifies all the code needed for data loading, training, visualization... check it out here!
👴 Experienced users can build their own framework upon `mriaug` (use `niftiai/augment.py` as a cheat sheet).
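Building your own pipeline on top of augmentation functions mostly means deciding *when* to apply them. A minimal sketch, assuming you want each batch item flipped independently with probability `p`; `random_flip3d` is a hypothetical helper (not part of `mriaug`), and `torch.flip` stands in for `flip3d` so the snippet runs with PyTorch alone:

```python
import torch


def random_flip3d(x: torch.Tensor, p: float = 0.5, dim: int = 2) -> torch.Tensor:
    """Flip each batch item along spatial `dim` with probability `p`.

    Hypothetical helper: torch.flip stands in for mriaug's flip3d here.
    """
    # One Bernoulli draw per batch item, created on the same device as x
    do_flip = torch.rand(x.shape[0], device=x.device) < p
    # Reshape the (batch,) mask so it broadcasts over channel and spatial dims
    mask = do_flip.view(-1, *([1] * (x.dim() - 1)))
    return torch.where(mask, torch.flip(x, dims=[dim]), x)


x = torch.linspace(0, 1, 2 * 4**3).view(2, 1, 4, 4, 4)  # batch of two 4x4x4 volumes
print(random_flip3d(x, p=1.0)[0, 0, :, 0, 0])  # every item flipped
print(random_flip3d(x, p=0.0)[0, 0, :, 0, 0])  # no item flipped
```

Drawing the decision per batch item (instead of once per batch) keeps the augmentation diverse within a batch while still running as a single vectorized call on CPU or GPU.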
Let's create a 3D image tensor (with additional batch and channel dimensions) and apply `flip3d`:
```python
import torch
from mriaug import flip3d

shape = (1, 1, 4, 4, 4)
x = torch.linspace(0, 1, 4**3).view(*shape)
x_flipped = flip3d(x)
print(x[..., 0, 0])  # tensor([[[0.0000, 0.2540, 0.5079, 0.7619]]])
print(x_flipped[..., 0, 0])  # tensor([[[0.7619, 0.5079, 0.2540, 0.0000]]])
```
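Because the flip is plain PyTorch, the same operation works on the GPU and inside the autograd graph. A minimal sketch using only PyTorch, assuming (as the printed values of the `flip3d` example suggest) that `flip3d` with default arguments flips the first spatial axis (`dim=2`); `torch.flip` is used here so the check runs without `mriaug` installed:

```python
import torch

# Run on the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
values = torch.linspace(0, 1, 4**3, device=device, requires_grad=True)  # leaf tensor
x = values.view(1, 1, 4, 4, 4)  # same (batch, channel, 4, 4, 4) layout as above

# Flipping the first spatial axis (dim 2) yields 0.7619, 0.5079, 0.2540, 0.0000 along x[..., 0, 0]
x_flipped = torch.flip(x, dims=[2])
print(x_flipped[..., 0, 0])

# Gradients flow back through the flip: d(sum)/d(value) is 1 for every voxel
x_flipped.sum().backward()
print(values.grad.unique())
```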
Explore the gallery to understand the usage and effect of all ~20 augmentations!
The popular libraries `torchio` and `MONAI` (which utilizes `torchio`) often use ITK (CPU only) like this

PyTorch tensor → NumPy array → NiBabel image → ITK operation (C/C++) → NumPy array → PyTorch tensor

to augment a PyTorch tensor. That's complicated and does not use the GPU, which is needed for neural net training anyway.
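If the tensor to be augmented is part of an autograd graph (the case `mriaug`'s autograd support targets), the first hop of that pipeline already breaks it. A minimal sketch (needs NumPy installed next to PyTorch) of converting such a tensor to a NumPy array and back:

```python
import torch

# e.g. the output of an upstream differentiable operation
x = torch.rand(1, 1, 8, 8, 8, requires_grad=True)

# Hop 1 of the CPU pipeline (PyTorch tensor -> NumPy array) refuses grad-tracking tensors
try:
    x.numpy()
    converted = True
except RuntimeError:
    converted = False
print(converted)  # False

# Detaching makes the conversion work, but the round trip is cut off from autograd
# (and a CUDA tensor would additionally need a .cpu() copy first)
arr = x.detach().cpu().numpy()
y = torch.from_numpy(arr)
print(y.requires_grad)  # False: no gradient can flow back to x
```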
Instead, `mriaug` directly uses PyTorch (CPU & GPU support), resulting in
- ~50x fewer lines of code: `torchio`: ~10,000 LOC, `mriaug`: ~200 LOC
- ~50x speedup on GPU 🔥 (7.5x to 1813x depending on the transformation, see the table below; run `speed.py` to reproduce) 💨
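Staying in PyTorch means an augmentation is just tensor operations that run wherever the tensor lives. A minimal sketch (not `mriaug`'s implementation and not `speed.py`) of an affine augmentation built from `torch.nn.functional.affine_grid` and `grid_sample`, timed on CPU and, if available, GPU. `random_affine3d` and `time_it` are hypothetical helpers, and the volume is 64³ rather than 256³ so it runs quickly:

```python
import time

import torch
import torch.nn.functional as F


def random_affine3d(x: torch.Tensor, max_shift: float = 0.1) -> torch.Tensor:
    """Hypothetical affine augmentation: identity plus a random translation.

    Shifts are in the normalized [-1, 1] coordinates used by affine_grid.
    """
    n = x.shape[0]
    theta = torch.eye(3, 4, device=x.device).repeat(n, 1, 1)  # (N, 3, 4) affine matrices
    theta[:, :, 3] = (torch.rand(n, 3, device=x.device) * 2 - 1) * max_shift
    grid = F.affine_grid(theta, x.shape, align_corners=False)  # (N, D, H, W, 3)
    return F.grid_sample(x, grid, align_corners=False)  # trilinear resampling


def time_it(fn, x: torch.Tensor, repeats: int = 5) -> float:
    """Median wall-clock seconds of fn(x), waiting for queued GPU kernels."""
    fn(x)  # warm-up (CUDA context creation, kernel selection)
    times = []
    for _ in range(repeats):
        if x.is_cuda:
            torch.cuda.synchronize()
        start = time.perf_counter()
        fn(x)
        if x.is_cuda:
            torch.cuda.synchronize()  # GPU calls return before the kernels finish
        times.append(time.perf_counter() - start)
    return sorted(times)[len(times) // 2]


x = torch.rand(1, 1, 64, 64, 64)
print(f"CPU: {time_it(random_affine3d, x):.4f} s")
if torch.cuda.is_available():
    print(f"GPU: {time_it(random_affine3d, x.cuda()):.4f} s")
```

Because CUDA kernels are launched asynchronously, timing GPU code without `torch.cuda.synchronize()` would mostly measure launch overhead rather than the computation itself.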
Runtimes on a 256³ image in seconds (AMD 5950X CPU, NVIDIA RTX 3090 GPU):
Transformation | torchio | mriaug on CPU | mriaug on GPU | Speedup (mriaug on GPU vs. torchio) |
---|---|---|---|---|
Flip | 0.014 | 0.012 | 0.002 | 7.5x |
Affine | 0.297 | 0.608 | 0.011 | 27.9x |
Warp | 0.951 | 0.850 | 0.009 | 103.3x |
Bias Field | 3.258 | 0.081 | 0.002 | 1813.0x |
Noise | 0.117 | 0.105 | 0.001 | 230.4x |
Downsample | 0.282 | 0.013 | 0.000 | 592.3x |
Ghosting | 0.241 | 0.170 | 0.003 | 78.3x |
Spike | 0.265 | 0.172 | 0.003 | 88.8x |
Motion | 0.696 | 0.540 | 0.009 | 78.6x |
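The speedup column matches the torchio runtime divided by the mriaug GPU runtime (e.g. Affine: 0.297 / 0.011 ≈ 27.0 vs. 27.9x reported). Recomputing it from the three-decimal values only roughly reproduces the column, since GPU times of a few milliseconds lose most of their precision when rounded (Downsample's GPU time even rounds to 0.000). A small check on a few rows copied from the table:

```python
# Runtimes in seconds copied from the table: (torchio, mriaug on GPU, reported speedup)
table = {
    "Flip": (0.014, 0.002, 7.5),
    "Warp": (0.951, 0.009, 103.3),
    "Bias Field": (3.258, 0.002, 1813.0),
    "Downsample": (0.282, 0.000, 592.3),
}

for name, (t_torchio, t_gpu, reported) in table.items():
    # Speedup = torchio runtime / mriaug GPU runtime; rounded inputs give only a rough check
    ratio = t_torchio / t_gpu if t_gpu > 0 else float("inf")
    print(f"{name:<11} recomputed {ratio:8.1f}x   reported {reported:7.1f}x")
```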
Let's load an example 3D image `x`, show it with `niftiview` (used to create all images below), define some arguments:
```python
import torch

size = (160, 196, 160)
zoom = torch.tensor([[-.2, 0, 0]])
rotate = torch.tensor([[0, .1, 0]])
translate = torch.tensor([[0, 0, .2]])
shear = torch.tensor([[0, .05, 0]])
```
and run all augmentations (see `runall.py`):