Skip to content

fuqianya/bottom-up-attention-paddle

Repository files navigation

bottom-up-attention-paddle

基于paddle框架的Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering实现

一、简介

本项目基于paddle复现Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering中所提出的基于bottom-uptop-down注意力机制的Image Captioning模型。论文作者提出了著名的bottom-up注意力机制,与以往的grid-level的注意力不同,作者提出了object-level的注意力机制。作者将该注意力机制应用到image captioningvisual question answering (vqa)任务中,均取得了显著的效果。

论文:

  • [1] P. Anderson, X. He, C. Buehler, D. Teney, M. Johnson, S. Gould, L. Zhang, "Bottom-Up and Top-Down Attention for Image Captioning and Visual Question Answering", CVPR, 2018.

参考项目:

二、复现精度

该指标为模型在COCO2014的测试集评估而得

指标 原论文 复现精度
BlEU-1 0.798 0.791

三、数据集

本项目所使用的数据集为COCO2014。该数据集共包含123287张图像,每张图像对应5个标题。训练集、验证集和测试集分别为113287、5000、5000张图像及其对应的标题。本项目使用作者提供的预提取的bottom-up特征,可以从这里下载得到(我们提供了脚本下载该数据集的标题以及图像特征,见download_dataset.sh)。

四、环境依赖

  • 硬件:CPU、GPU

  • 软件:

    • Python 3.8
    • Java 1.8.0
    • PaddlePaddle == 2.1.0

五、快速开始

step1: clone

# clone this repo
git clone /~https://github.com/fuqianya/bottom-up-attention-paddle.git
cd bottom-up-attention-paddle

step2: 安装环境及依赖

pip install -r requirements.txt

step3: 下载数据

# 下载数据集及特征
bash ./download_dataset.sh

step4: 数据集预处理

python prepro.py

step5: 训练

训练过程过程分为两步(详情见论文3.3节):

  • Training with Cross Entropy (XE) Loss

    python train.py --train_mode xe --learning_rate 4e-4
  • CIDEr-D Score Optimization

    python train.py --train_mode rl --learning_rate 4e-5 --resume ./checkpoint/xe/epoch_25.pth

step6: 测试

python eval.py --train_mode rl --eval_model ./checkpoint/rl/epoch_25.pth --result_file epoch25_results.json

使用预训练模型进行预测

模型下载: 谷歌云盘

将下载的模型权重放到checkpoints目录下, 运行step6的指令进行测试。

六、代码结构与详细说明

├── checkpoint          # 存储训练的模型
├── config
│  └── config.py        # 模型的参数设置
├── data                # 预处理的数据
├── model
│   └── captioner.py    # 定义模型结构
│   └── dataloader.py   # 加载训练数据
│   └── loss.py         # 定义损失函数
├── pyutils 
│   └── cap_eval        # 计算评价指标工具
│   └── self_critical   # rl阶段计算reward工具
├── result              # 存放生成的标题
├── utils 
│   └── utils.py        # 工具类
├── download_dataset.sh # 数据集下载脚本
├── prepro.py           # 数据预处理
├── train.py            # 训练主函数
├── eval.py             # 测试主函数
└── requirement.txt     # 依赖包

模型、训练的所有参数信息都在config.py中进行了详细注释,详情见config/config.py

七、模型信息

关于模型的其他信息,可以参考下表:

信息 说明
发布者 fuqianya
时间 2021.08
框架版本 Paddle 2.1.0
应用场景 多模态
支持硬件 GPU、CPU
下载链接 预训练模型 | 训练日志

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published