基于paddle框架的Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering实现
本项目基于paddle复现Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering中所提出的基于bottom-up
和top-down
注意力机制的Image Captioning
模型。论文作者提出了著名的bottom-up
注意力机制,与以往的grid-level
的注意力不同,作者提出了object-level
的注意力机制。作者将该注意力机制应用到image captioning
和visual question answering (vqa)
任务中,均取得了显著的效果。
论文:
- [1] P. Anderson, X. He, C. Buehler, D. Teney, M. Johnson, S. Gould, L. Zhang, "Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering", CVPR, 2018.
参考项目:
该指标为模型在COCO2014的测试集评估而得
指标 | 原论文 | 复现精度 |
---|---|---|
BlEU-1 | 0.798 | 0.791 |
本项目所使用的数据集为COCO2014。该数据集共包含123287张图像,每张图像对应5个标题。训练集、验证集和测试集分别为113287、5000、5000张图像及其对应的标题。本项目使用作者提供的预提取的bottom-up
特征,可以从这里下载得到(我们提供了脚本下载该数据集的标题以及图像特征,见download_dataset.sh)。
-
硬件:CPU、GPU
-
软件:
- Python 3.8
- Java 1.8.0
- PaddlePaddle == 2.1.0
# clone this repo
git clone /~https://github.com/fuqianya/bottom-up-attention-paddle.git
cd bottom-up-attention-paddle
pip install -r requirements.txt
# 下载数据集及特征
bash ./download_dataset.sh
python prepro.py
训练过程过程分为两步(详情见论文3.3节):
-
Training with Cross Entropy (XE) Loss
python train.py --train_mode xe --learning_rate 4e-4
-
CIDEr-D Score Optimization
python train.py --train_mode rl --learning_rate 4e-5 --resume ./checkpoint/xe/epoch_25.pth
python eval.py --train_mode rl --eval_model ./checkpoint/rl/epoch_25.pth --result_file epoch25_results.json
模型下载: 谷歌云盘
将下载的模型权重放到checkpoints
目录下, 运行step6
的指令进行测试。
├── checkpoint # 存储训练的模型
├── config
│ └── config.py # 模型的参数设置
├── data # 预处理的数据
├── model
│ └── captioner.py # 定义模型结构
│ └── dataloader.py # 加载训练数据
│ └── loss.py # 定义损失函数
├── pyutils
│ └── cap_eval # 计算评价指标工具
│ └── self_critical # rl阶段计算reward工具
├── result # 存放生成的标题
├── utils
│ └── utils.py # 工具类
├── download_dataset.sh # 数据集下载脚本
├── prepro.py # 数据预处理
├── train.py # 训练主函数
├── eval.py # 测试主函数
└── requirement.txt # 依赖包
模型、训练的所有参数信息都在config.py
中进行了详细注释,详情见config/config.py
。
关于模型的其他信息,可以参考下表:
信息 | 说明 |
---|---|
发布者 | fuqianya |
时间 | 2021.08 |
框架版本 | Paddle 2.1.0 |
应用场景 | 多模态 |
支持硬件 | GPU、CPU |
下载链接 | 预训练模型 | 训练日志 |