Official PyTorch implementation of "MobileMamba: Lightweight Multi-Receptive Visual Mamba Network".
Haoyang He¹*, Jiangning Zhang²*, Yuxuan Cai³, Hongxu Chen¹, Xiaobin Hu²,
Zhenye Gan², Yabiao Wang², Chengjie Wang², Yunsheng Wu², Lei Xie¹†
¹College of Control Science and Engineering, Zhejiang University; ²Youtu Lab, Tencent; ³Huazhong University of Science and Technology
Abstract: Previous research on lightweight models has primarily focused on CNN- and Transformer-based designs. CNNs, with their local receptive fields, struggle to capture long-range dependencies, while Transformers, despite their global modeling capabilities, are limited by quadratic computational complexity in high-resolution scenarios. Recently, state-space models have gained popularity in the visual domain due to their linear computational complexity. Despite their low FLOPs, however, current lightweight Mamba-based models exhibit suboptimal throughput. In this work, we propose the MobileMamba framework, which balances efficiency and performance. We design a three-stage network that significantly improves inference speed. At a fine-grained level, we introduce the Multi-Receptive Field Feature Interaction (MRFFI) module, which comprises the Long-Range Wavelet Transform-Enhanced Mamba (WTE-Mamba), Efficient Multi-Kernel Depthwise Convolution (MK-DeConv), and Eliminate Redundant Identity components. This module integrates multi-receptive field information and enhances high-frequency detail extraction. Additionally, we employ training and testing strategies to further improve performance and efficiency. MobileMamba achieves up to 83.6% Top-1 accuracy on ImageNet-1K, surpassing existing state-of-the-art methods, and is up to 21× faster than LocalVim on GPU. Extensive experiments on high-resolution downstream tasks demonstrate that MobileMamba surpasses current efficient models, achieving an optimal balance between speed and accuracy.
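The abstract contrasts the quadratic cost of attention with the linear cost of state-space models. The back-of-the-envelope sketch below shows how that gap widens with input resolution. It assumes a ViT-style 16x16 patch tokenization and a 224-pixel reference resolution, both chosen purely for illustration (they are not MobileMamba's actual stem), and it models only the token-count term.

```python
def num_tokens(resolution: int, patch: int = 16) -> int:
    """Tokens for a square image cut into non-overlapping patch x patch tiles.

    The 16-pixel patch is an illustrative assumption, not MobileMamba's stem.
    """
    side = resolution // patch
    return side * side


def relative_cost(resolution: int, base: int = 224, patch: int = 16):
    """Cost at `resolution` relative to `base`, as (quadratic, linear).

    Only the token-count term is modeled: attention ~ N**2, a selective scan ~ N.
    Channel widths and constant factors are ignored.
    """
    ratio = num_tokens(resolution, patch) / num_tokens(base, patch)
    return ratio ** 2, ratio


if __name__ == "__main__":
    for res in (224, 512, 1024):
        quad, lin = relative_cost(res)
        print(f"{res:>4}px: N={num_tokens(res):>4}, attention x{quad:.1f}, scan x{lin:.1f}")
```

Going from 224 to 1024 pixels raises the token count from 196 to 4096, so the linear term grows about 20.9x while the quadratic term grows about 436.7x. This is only the scaling argument, not a measured throughput comparison.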
Top: Visualization of the Effective Receptive Fields (ERF) for different architectures. Bottom: Performance vs. FLOPs with recent CNN/Transformer/Mamba-based methods.
Accuracy vs. Speed with Mamba-based methods.
Image Classification for ImageNet-1K:
Model | FLOPs | #Params | Resolution | Top-1 | Cfg | Log | Model |
---|---|---|---|---|---|---|---|
MobileMamba-T2 | 255M | 8.8M | 192 x 192 | 73.6 | cfg | log | model |
MobileMamba-T2† | 255M | 8.8M | 192 x 192 | 76.9 | cfg | log | model |
MobileMamba-T4 | 413M | 14.2M | 192 x 192 | 76.1 | cfg | log | model |
MobileMamba-T4† | 413M | 14.2M | 192 x 192 | 78.9 | cfg | log | model |
MobileMamba-S6 | 652M | 15.0M | 224 x 224 | 78.0 | cfg | log | model |
MobileMamba-S6† | 652M | 15.0M | 224 x 224 | 80.7 | cfg | log | model |
MobileMamba-B1 | 1080M | 17.1M | 256 x 256 | 79.9 | cfg | log | model |
MobileMamba-B1† | 1080M | 17.1M | 256 x 256 | 82.2 | cfg | log | model |
MobileMamba-B2 | 2427M | 17.1M | 384 x 384 | 81.6 | cfg | log | model |
MobileMamba-B2† | 2427M | 17.1M | 384 x 384 | 83.3 | cfg | log | model |
MobileMamba-B4 | 4313M | 17.1M | 512 x 512 | 82.5 | cfg | log | model |
MobileMamba-B4† | 4313M | 17.1M | 512 x 512 | 83.6 | cfg | log | model |
Object Detection and Instance Segmentation Based on Mask R-CNN for COCO2017:
Backbone | APb | APb50 | APb75 | APbS | APbM | APbL | APm | APm50 | APm75 | APmS | APmM | APmL | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
MobileMamba-B1 | 40.6 | 61.8 | 43.8 | 22.4 | 43.5 | 55.9 | 37.4 | 58.9 | 39.9 | 17.1 | 39.9 | 56.4 | 38.0M | 178G | cfg | log | model |
Object Detection Based on RetinaNet for COCO2017:
Backbone | AP | AP50 | AP75 | APS | APM | APL | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|---|---|---|
MobileMamba-B1 | 39.6 | 59.8 | 42.4 | 21.5 | 43.4 | 53.9 | 27.1M | 151G | cfg | log | model |
Object Detection Based on SSDLite for COCO2017:
Backbone | AP | AP50 | AP75 | APS | APM | APL | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|---|---|---|
MobileMamba-B1 | 24.0 | 39.5 | 24.0 | 3.1 | 23.4 | 46.9 | 18.0M | 1.7G | cfg | log | model |
MobileMamba-B1-r512 | 29.5 | 47.7 | 30.4 | 8.9 | 35.0 | 47.0 | 18.0M | 4.4G | cfg | log | model |
Semantic Segmentation Based on Semantic FPN for ADE20k:
Backbone | aAcc | mIoU | mAcc | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|
MobileMamba-B4 | 79.9 | 42.5 | 53.7 | 19.8M | 5.6G | cfg | log | model |
Backbone | aAcc | mIoU | mAcc | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|
MobileMamba-B4 | 76.3 | 36.6 | 47.1 | 23.4M | 4.7G | cfg | log | model |
Backbone | aAcc | mIoU | mAcc | #Params | FLOPs | Cfg | Log | Model |
---|---|---|---|---|---|---|---|---|
MobileMamba-B4 | 76.2 | 36.9 | 47.9 | 20.5M | 4.5G | cfg | log | model |
The model weights and log files for all classification and downstream tasks are available for download via weights.
pip3 install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
pip3 install timm==0.9.16 tensorboardX einops torchprofile fvcore==0.1.5.post20221221 triton==2.1.0
cd model/lib_mamba/kernels/selective_scan && pip install . && cd ../../../..
git clone https://github.com/NVIDIA/apex && cd apex && pip3 install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./  # optional
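Because the install commands above pin exact versions, it can help to confirm what actually ended up in the environment before building anything else. The following is a minimal standard-library sketch; `check_versions` and `PINNED` are hypothetical names, and the sketch assumes the `+cu118` local suffix that PyTorch wheels from the CUDA index report should be ignored when comparing.

```python
from importlib import metadata

# Versions pinned by the pip commands above (hypothetical helper table).
PINNED = {
    "torch": "2.1.2",
    "torchvision": "0.16.2",
    "torchaudio": "2.1.2",
    "timm": "0.9.16",
    "fvcore": "0.1.5.post20221221",
    "triton": "2.1.0",
}


def check_versions(pins):
    """Return {package: (expected, installed_or_None)} for every mismatch.

    A local version suffix such as "+cu118" is stripped before comparing.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            mismatches[name] = (expected, None)
            continue
        if installed.split("+")[0] != expected:
            mismatches[name] = (expected, installed)
    return mismatches


if __name__ == "__main__":
    problems = check_versions(PINNED)
    if not problems:
        print("All pinned packages match.")
    for name, (want, have) in sorted(problems.items()):
        print(f"{name}: expected {want}, found {have or 'not installed'}")
```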
Download and extract the ImageNet-1K dataset into the following directory structure:
imagenet
├── train
│   ├── n01440764
│   │   ├── n01440764_10026.JPEG
│   │   ├── ...
│   ├── ...
├── train.txt (optional)
├── val
│   ├── n01440764
│   │   ├── ILSVRC2012_val_00000293.JPEG
│   │   ├── ...
│   ├── ...
└── val.txt (optional)
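A quick sanity check of this layout can save a failed multi-GPU run. The sketch below uses only the Python standard library; `summarize_split` and `check_imagenet` are hypothetical helpers, and the sketch assumes image files end in `.JPEG`, `.jpg` or `.png` and that `val` uses per-class folders as shown in the tree. A complete ImageNet-1K copy should report 1000 class folders per split, with 1,281,167 training images and 50,000 validation images.

```python
from pathlib import Path

# Extensions assumed for ImageNet image files (compared case-insensitively).
IMAGE_EXTS = {".jpeg", ".jpg", ".png"}


def summarize_split(root, split):
    """Count class folders and image files under <root>/<split>/<class>/."""
    split_dir = Path(root) / split
    if not split_dir.is_dir():
        raise FileNotFoundError(f"missing directory: {split_dir}")
    classes = [p for p in split_dir.iterdir() if p.is_dir()]
    n_images = sum(
        1
        for cls in classes
        for f in cls.iterdir()
        if f.is_file() and f.suffix.lower() in IMAGE_EXTS
    )
    return len(classes), n_images


def check_imagenet(root):
    """Return {split: (num_classes, num_images)} for the train and val splits."""
    return {split: summarize_split(root, split) for split in ("train", "val")}


if __name__ == "__main__":
    import sys

    root = sys.argv[1] if len(sys.argv) > 1 else "imagenet"
    try:
        for split, (n_cls, n_img) in check_imagenet(root).items():
            print(f"{split}: {n_cls} class folders, {n_img} images")
    except FileNotFoundError as err:
        print(err)
```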
Test with 8 GPUs in one node:
MobileMamba-T2
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t2 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_T2/mobilemamba_t2.pth
This should give Top-1: 73.638 (Top-5: 91.422)
MobileMamba-T2†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t2s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_T2s/mobilemamba_t2s.pth
This should give Top-1: 76.934 (Top-5: 93.100)
MobileMamba-T4
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t4 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_T4/mobilemamba_t4.pth
This should give Top-1: 76.086 (Top-5: 92.772)
MobileMamba-T4†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t4s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_T4s/mobilemamba_t4s.pth
This should give Top-1: 78.914 (Top-5: 94.160)
MobileMamba-S6
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_s6 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_S6/mobilemamba_s6.pth
This should give Top-1: 78.002 (Top-5: 93.992)
MobileMamba-S6†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_s6s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_S6s/mobilemamba_s6s.pth
This should give Top-1: 80.742 (Top-5: 95.182)
MobileMamba-B1
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b1 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B1/mobilemamba_b1.pth
This should give Top-1: 79.948 (Top-5: 94.924)
MobileMamba-B1†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b1s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B1s/mobilemamba_b1s.pth
This should give Top-1: 82.234 (Top-5: 95.872)
MobileMamba-B2
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b2 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B2/mobilemamba_b2.pth
This should give Top-1: 81.624 (Top-5: 95.890)
MobileMamba-B2†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b2s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B2s/mobilemamba_b2s.pth
This should give Top-1: 83.260 (Top-5: 96.438)
MobileMamba-B4
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b4 -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B4/mobilemamba_b4.pth
This should give Top-1: 82.496 (Top-5: 96.252)
MobileMamba-B4†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b4s -m test model.model_kwargs.checkpoint_path=weights/MobileMamba_B4s/mobilemamba_b4s.pth
This should give Top-1: 83.644 (Top-5: 96.606)
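The expected numbers above also show how much each † checkpoint gains over its plain counterpart. The short script below simply recomputes that difference from the Top-1 values listed in this section; `EXPECTED_TOP1` and `dagger_gains` are illustrative names, and no data beyond these listed values is used.

```python
# (plain Top-1, dagger Top-1) copied from the "This should give" lines above.
EXPECTED_TOP1 = {
    "T2": (73.638, 76.934),
    "T4": (76.086, 78.914),
    "S6": (78.002, 80.742),
    "B1": (79.948, 82.234),
    "B2": (81.624, 83.260),
    "B4": (82.496, 83.644),
}


def dagger_gains(table):
    """Return {variant: Top-1 gain of the dagger checkpoint over the plain one}."""
    return {name: round(dag - base, 3) for name, (base, dag) in table.items()}


if __name__ == "__main__":
    for name, gain in dagger_gains(EXPECTED_TOP1).items():
        print(f"MobileMamba-{name}: +{gain:.3f} Top-1 (dagger vs. plain)")
```

Running it shows the gain narrowing from about 3.3 points for T2 to about 1.1 points for B4.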
Train with 8 GPUs in one node:
MobileMamba-T2
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t2 -m train
MobileMamba-T2†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t2s -m train
MobileMamba-T4
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t4 -m train
MobileMamba-T4†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_t4s -m train
MobileMamba-S6
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_s6 -m train
MobileMamba-S6†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_s6s -m train
MobileMamba-B1
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b1 -m train
MobileMamba-B1†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b1s -m train
MobileMamba-B2
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b2 -m train
MobileMamba-B2†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b2s -m train
MobileMamba-B4
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b4 -m train
MobileMamba-B4†
python3 -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --use_env run.py -c configs/mobilemamba/mobilemamba_b4s -m train
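All of the test and train commands above follow one naming pattern: the † variant appends `s` to the config name and checkpoint folder, and test mode adds a `model.model_kwargs.checkpoint_path=...` override. The sketch below rebuilds those command strings; `mobilemamba_cmd` is a hypothetical helper that only reproduces the commands shown in this README and does not call anything in the repository.

```python
import shlex

VARIANTS = {"t2", "t4", "s6", "b1", "b2", "b4"}


def mobilemamba_cmd(variant, mode, dagger=False, nproc=8):
    """Rebuild the single-node launch command used in this README.

    variant: "T2", "T4", "S6", "B1", "B2" or "B4" (case-insensitive).
    mode: "train" or "test"; "test" appends the matching checkpoint path.
    dagger: True selects the dagger variant, whose names carry an "s" suffix.
    """
    v = variant.lower()
    if v not in VARIANTS:
        raise ValueError(f"unknown variant: {variant}")
    if mode not in ("train", "test"):
        raise ValueError(f"unknown mode: {mode}")
    suffix = "s" if dagger else ""
    name = v + suffix
    args = [
        "python3", "-m", "torch.distributed.launch",
        f"--nproc_per_node={nproc}", "--nnodes=1", "--use_env", "run.py",
        "-c", f"configs/mobilemamba/mobilemamba_{name}", "-m", mode,
    ]
    if mode == "test":
        # Checkpoint folders use an upper-case variant plus a lower-case "s".
        folder = f"MobileMamba_{v.upper()}{suffix}"
        args.append(
            f"model.model_kwargs.checkpoint_path=weights/{folder}/mobilemamba_{name}.pth"
        )
    return shlex.join(args)


if __name__ == "__main__":
    print(mobilemamba_cmd("B1", "test", dagger=True))
    print(mobilemamba_cmd("T2", "train"))
```

The returned string can be printed for review, or split with `shlex.split` and passed to `subprocess.run`.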
pip3 install terminaltables pycocotools prettytable xtcocotools
pip3 install mmpretrain==1.2.0 mmdet==3.3.0 mmsegmentation==1.2.2
pip3 install mmcv==2.1.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1/index.html
cd det/backbones/lib_mamba/kernels/selective_scan && pip install . && cd ../../../..
Download and extract the COCO2017 and ADE20k datasets into the following directory structure:
downstream
├── det
│   └── data
│       └── coco
│           ├── annotations
│           ├── train2017
│           ├── val2017
│           └── test2017
└── seg
    └── data
        └── ade
            └── ADEChallengeData2016
                ├── annotations
                └── images
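As with ImageNet, it can be worth confirming this layout before launching detection or segmentation jobs. The sketch below only checks that each directory in the tree exists; `REQUIRED_DIRS` and `missing_dirs` are hypothetical names, and the root is assumed to be the `downstream` folder shown above.

```python
from pathlib import Path

# Directories from the layout above, relative to the downstream/ root.
REQUIRED_DIRS = (
    "det/data/coco/annotations",
    "det/data/coco/train2017",
    "det/data/coco/val2017",
    "det/data/coco/test2017",
    "seg/data/ade/ADEChallengeData2016/annotations",
    "seg/data/ade/ADEChallengeData2016/images",
)


def missing_dirs(root, required=REQUIRED_DIRS):
    """Return the entries of `required` that are not directories under root."""
    base = Path(root)
    return [rel for rel in required if not (base / rel).is_dir()]


if __name__ == "__main__":
    import sys

    root = sys.argv[1] if len(sys.argv) > 1 else "downstream"
    missing = missing_dirs(root)
    print("layout OK" if not missing else "missing:\n  " + "\n  ".join(missing))
```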
Mask-RCNN
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/mask_rcnn/mask-rcnn_mobilemamba_b1_fpn_1x_coco.py 4
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/mask_rcnn/mask-rcnn_mobilemamba_b1_fpn_1x_coco.py ../../weights/downstream/det/maskrcnn.pth 4
RetinaNet
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/retinanet/retinanet_mobilemamba_b1_fpn_1x_coco.py 4
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/retinanet/retinanet_mobilemamba_b1_fpn_1x_coco.py ../../weights/downstream/det/retinanet.pth 4
SSDLite
./tools/dist_train.sh configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_coco.py 8
./tools/dist_test.sh configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_coco.py ../../weights/downstream/det/ssdlite.pth 8
./tools/dist_train.sh configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_512_coco.py 8
./tools/dist_test.sh configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_512_coco.py ../../weights/downstream/det/ssdlite_512.pth 8
DeepLabV3
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/deeplabv3/deeplabv3_mobilemamba_b4-80k_ade20k-512x512.py 4
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/deeplabv3/deeplabv3_mobilemamba_b4-80k_ade20k-512x512.py ../../weights/downstream/seg/deeplabv3.pth 4
Semantic FPN
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/sem_fpn/fpn_mobilemamba_b4-160k_ade20k-512x512.py 4
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/sem_fpn/fpn_mobilemamba_b4-160k_ade20k-512x512.py ../../weights/downstream/seg/fpn.pth 4
PSPNet
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_train.sh configs/pspnet/pspnet_mobilemamba_b4-80k_ade20k-512x512.py 4
CUDA_VISIBLE_DEVICES=0,1,2,3 ./tools/dist_test.sh configs/pspnet/pspnet_mobilemamba_b4-80k_ade20k-512x512.py ../../weights/downstream/seg/pspnet.pth 4
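The downstream commands share a simple shape: `dist_train.sh <config> <gpus>` for training and `dist_test.sh <config> <checkpoint> <gpus>` for testing, with the 4-GPU runs pinned via `CUDA_VISIBLE_DEVICES`. The sketch below rebuilds them from a small table copied from the commands above; `TASKS` and `downstream_cmd` are hypothetical names, and it assumes the commands are run from the detection or segmentation working directory implied by the relative `../../weights/...` checkpoint paths.

```python
import shlex

# task -> (config, checkpoint, num_gpus, visible_devices), copied from the
# commands above; visible_devices=None means no CUDA_VISIBLE_DEVICES prefix.
TASKS = {
    "mask_rcnn": ("configs/mask_rcnn/mask-rcnn_mobilemamba_b1_fpn_1x_coco.py",
                  "../../weights/downstream/det/maskrcnn.pth", 4, "0,1,2,3"),
    "retinanet": ("configs/retinanet/retinanet_mobilemamba_b1_fpn_1x_coco.py",
                  "../../weights/downstream/det/retinanet.pth", 4, "0,1,2,3"),
    "ssdlite": ("configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_coco.py",
                "../../weights/downstream/det/ssdlite.pth", 8, None),
    "ssdlite_512": ("configs/ssd/ssdlite_mobilemamba_b1_8gpu_2lr_512_coco.py",
                    "../../weights/downstream/det/ssdlite_512.pth", 8, None),
    "deeplabv3": ("configs/deeplabv3/deeplabv3_mobilemamba_b4-80k_ade20k-512x512.py",
                  "../../weights/downstream/seg/deeplabv3.pth", 4, "0,1,2,3"),
    "sem_fpn": ("configs/sem_fpn/fpn_mobilemamba_b4-160k_ade20k-512x512.py",
                "../../weights/downstream/seg/fpn.pth", 4, "0,1,2,3"),
    "pspnet": ("configs/pspnet/pspnet_mobilemamba_b4-80k_ade20k-512x512.py",
               "../../weights/downstream/seg/pspnet.pth", 4, "0,1,2,3"),
}


def downstream_cmd(task, mode):
    """Rebuild the dist_train.sh / dist_test.sh command for one task."""
    config, ckpt, gpus, devices = TASKS[task]
    if mode == "train":
        args = ["./tools/dist_train.sh", config, str(gpus)]
    elif mode == "test":
        args = ["./tools/dist_test.sh", config, ckpt, str(gpus)]
    else:
        raise ValueError(f"unknown mode: {mode}")
    cmd = shlex.join(args)
    return f"CUDA_VISIBLE_DEVICES={devices} {cmd}" if devices else cmd


if __name__ == "__main__":
    for task in TASKS:
        print(downstream_cmd(task, "test"))
```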
If our work is helpful for your research, please consider citing:
@article{mobilemamba,
title={MobileMamba: Lightweight Multi-Receptive Visual Mamba Network},
author={Haoyang He and Jiangning Zhang and Yuxuan Cai and Hongxu Chen and Xiaobin Hu and Zhenye Gan and Yabiao Wang and Chengjie Wang and Yunsheng Wu and Lei Xie},
journal={arXiv preprint arXiv:2411.15941},
year={2024}
}
We thank the following repositories, among others, for their assistance with our research: