Prismatic VLMs

Installation | Usage | Pretrained Models | Training VLMs

A flexible and efficient codebase for training visually-conditioned language-models (VLMs):

  • Different Visual Representations. We natively support backbones such as CLIP, SigLIP, DINOv2 – and even fusions of different backbones. Adding new backbones is easy via TIMM.
  • Base and Instruct-Tuned Language Models. We support arbitrary instances of AutoModelForCausalLM including both base and instruct-tuned models (with built-in prompt handling) via Transformers. If your favorite LM isn't already supported, feel free to submit a PR!
  • Easy Scaling. Powered by PyTorch FSDP and Flash-Attention, we can quickly and efficiently train models from 1B to 34B parameters on different, easily configurable dataset mixtures.

If you're interested in rigorously evaluating existing VLMs, check out our evaluation codebase, which bundles together 12 different battle-tested vision-and-language benchmarks through a clean, automated test harness.


Installation

This repository was built using Python 3.10, but should be backwards compatible with any Python >= 3.8. We require PyTorch 2.1 or greater; follow the official PyTorch installation instructions for your platform. This repository was developed and has been thoroughly tested with:

  • [2/16/24] PyTorch 2.1.0, Torchvision 0.16.0, Transformers 4.34.1, and Flash-Attention 2.3.3.
  • [3/24/24] PyTorch 2.2.1, Torchvision 0.17.0, Transformers 4.38.1, and Flash-Attention 2.5.5.

Once PyTorch has been properly installed, you can install this package locally via an editable installation (or via pip install git+https://github.com/TRI-ML/prismatic-vlms):

git clone https://github.com/TRI-ML/prismatic-vlms
cd prismatic-vlms
pip install -e .

# Training additionally requires Flash-Attention 2 (https://github.com/Dao-AILab/flash-attention)
pip install packaging ninja

# Verify Ninja --> should return exit code "0"
ninja --version; echo $?

# Install Flash Attention 2 
#   =>> If you run into difficulty, try `pip cache remove flash_attn` first
pip install flash-attn --no-build-isolation

If you run into any problems during the installation process, please file a GitHub Issue.
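
Before filing an issue, it can help to record which versions are actually installed. The sketch below uses only the standard library to compare installed distributions against the [3/24/24] tested configuration listed above. The helper names (`installed_version`, `version_report`) are illustrative and not part of this repository, and the distribution names are assumed to match what was passed to pip.

```python
from __future__ import annotations

from importlib import metadata

# Versions copied from the [3/24/24] tested configuration in this README.
TESTED_VERSIONS = {
    "torch": "2.2.1",
    "torchvision": "0.17.0",
    "transformers": "4.38.1",
    "flash-attn": "2.5.5",
}


def installed_version(distribution: str) -> str | None:
    """Return the installed version of a distribution, or None if it is missing."""
    try:
        return metadata.version(distribution)
    except metadata.PackageNotFoundError:
        return None


def version_report(tested: dict[str, str]) -> dict[str, tuple[str | None, str]]:
    """Map each distribution name to (installed version or None, tested version)."""
    return {name: (installed_version(name), want) for name, want in tested.items()}


if __name__ == "__main__":
    for name, (have, want) in version_report(TESTED_VERSIONS).items():
        status = "missing" if have is None else have
        print(f"{name:<14} installed: {status:<14} tested: {want}")
```

Installed versions may carry a local suffix (for example a CUDA tag), so this report is meant for eyeballing rather than strict equality checks.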

Usage

Once installed, loading and running inference with pretrained prismatic models is easy:

import requests
import torch

from PIL import Image
from pathlib import Path

from prismatic import load

# For gated LMs like Llama-2, make sure to request official access, and generate an access token
hf_token = Path(".hf_token").read_text().strip()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load a pretrained VLM (either local path, or ID to auto-download from the HF Hub) 
model_id = "prism-dinosiglip+7b"
vlm = load(model_id, hf_token=hf_token)
vlm.to(device, dtype=torch.bfloat16)

# Download an image and specify a prompt
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/beignets-task-guide.png"
image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
user_prompt = "What is going on in this image?"

# Build prompt
prompt_builder = vlm.get_prompt_builder()
prompt_builder.add_turn(role="human", message=user_prompt)
prompt_text = prompt_builder.get_prompt()

# Generate!
generated_text = vlm.generate(
    image,
    prompt_text,
    do_sample=True,
    temperature=0.4,
    max_new_tokens=512,
    min_length=1,
)

For a complete terminal-based CLI for interacting with our VLMs, check out scripts/generate.py.
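
The snippet above reads the access token from a `.hf_token` file and fails with a bare `FileNotFoundError` if the file is absent. A minimal sketch of a more forgiving lookup follows. `read_hf_token` is a hypothetical helper, not part of the `prismatic` package, and the `HF_TOKEN` environment variable name is an assumption about how you store the token.

```python
from __future__ import annotations

import os
from pathlib import Path


def read_hf_token(token_file: str | Path = ".hf_token", env_var: str = "HF_TOKEN") -> str:
    """Return an access token from a file, falling back to an environment variable."""
    path = Path(token_file)
    if path.is_file():
        token = path.read_text().strip()
        if token:
            return token

    # Fall back to the environment when the file is missing or empty.
    token = os.environ.get(env_var, "").strip()
    if token:
        return token

    raise RuntimeError(f"No access token found in {path} or ${env_var}")
```

The result can be passed wherever the usage example passes `hf_token`, for instance `load(model_id, hf_token=read_hf_token())`.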

Pretrained Models

We release all 49 VLMs trained as part of our work, with a range of different visual representations, language models, data, and scale. The exhaustive set of models (with structured descriptions) can be found in prismatic/models/registry.py; we will continue to update this registry as we train additional models.

We also provide a top-level API for instantiating models from the names mentioned in the various Figures of our paper, as well as for generally browsing our pretrained models by description:

from prismatic import available_model_ids_and_names, available_model_ids, get_model_description
from pprint import pprint

# List all Pretrained VLMs (by HF Hub IDs)
pprint(available_model_ids())

# List all Pretrained VLMs with both HF Hub IDs AND semantically meaningful names from paper
pprint(available_model_ids_and_names())

# Print and return a targeted description of a model (by name or ID) 
#   =>> See `prismatic/models/registry.py` for explicit schema
description = get_model_description("Prism-DINOSigLIP 13B (Controlled)")

Currently, our best performing models are the Prism-DINOSigLIP series, with especially strong performance on spatial understanding and localization tasks.


Explicit Notes on Model Licensing & Commercial Use: While all code in this repository is released under an MIT License, our pretrained models may inherit restrictions from the datasets and underlying LMs we use for training.

[02/09/24] Our current VLMs are all derived from Llama-2, and as such are subject to the Llama Community License, which does permit commercial use. We additionally train on the LLaVa Instruct Tuning data.

[05/05/24] Our new VLMs derived from Mistral and Phi-2 are subject to the original Apache and MIT Licenses attached to each model.

As we train new models, we will update this section of the README (and the LICENSE files associated with each model) appropriately. If there are any questions, please file an Issue!

Training VLMs

In addition to providing all pretrained VLMs trained in this work, we also provide full instructions and configurations for reproducing all results (down to controlling for the batch order of examples seen during training).

Pretraining Datasets

For the LLaVa v1.5 Instruct Dataset we use for all of our models, we provide an automated download script in scripts/preprocess.py:

# Download the `llava-v1.5-instruct` (Instruct Tuning) Image and Language Data (includes extra post-processing)
python scripts/preprocess.py --dataset_id "llava-v1.5-instruct" --root_dir <PATH-TO-DATA-ROOT>

# (In case you also wish to download the explicit vision-language alignment data)
python scripts/preprocess.py --dataset_id "llava-laion-cc-sbu-558k" --root_dir <PATH-TO-DATA-ROOT>

As part of our work, we also train on mixtures of datasets including LVIS-Instruct-4V and LRV-Instruct. We provide instructions and scripts for downloading these datasets in scripts/additional-datasets.

We welcome any and all contributions and pull requests to add new datasets!

Model Configuration & Training Script

The entry point for training models is scripts/pretrain.py. We employ draccus to provide a modular, dataclass-based interface for specifying model configurations; all 42 VLM configurations are in prismatic/conf/models.py.
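
To illustrate what a dataclass-based interface means in practice, the sketch below uses only the standard library (not draccus itself) to show how fields of a nested dataclass can map onto dotted command-line overrides such as `--model.type`. The `ModelConfig` and `PretrainConfig` classes here are simplified stand-ins whose fields mirror the flags used in this README; the real configuration classes in prismatic/conf/models.py may be organized differently.

```python
from __future__ import annotations

import shlex
from dataclasses import dataclass, fields, is_dataclass


@dataclass
class ModelConfig:
    # Hypothetical stand-in: field names mirror the `--model.*` flags in this README.
    type: str = "one-stage+7b"
    model_id: str = "<NAME OF NEW MODEL>"
    vision_backbone_id: str = "dinosiglip-vit-so-384px"
    image_resize_strategy: str = "letterbox"
    llm_backbone_id: str = "vicuna-v15-7b"


@dataclass
class PretrainConfig:
    model: ModelConfig


def to_flags(config, prefix: str = "") -> list[str]:
    """Flatten a (possibly nested) dataclass into ["--a.b", "value", ...] pairs."""
    flags: list[str] = []
    for field in fields(config):
        value = getattr(config, field.name)
        name = f"{prefix}{field.name}"
        if is_dataclass(value):
            # Nested dataclasses contribute a dotted prefix, e.g. "model.".
            flags.extend(to_flags(value, prefix=f"{name}."))
        else:
            flags.extend([f"--{name}", str(value)])
    return flags


if __name__ == "__main__":
    cfg = PretrainConfig(model=ModelConfig(model_id="my-new-vlm"))
    print(shlex.join(to_flags(cfg)))
```

Keeping the configuration in dataclasses means each override on the command line corresponds to exactly one typed field, which is what makes the base configurations easy to extend.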

We use PyTorch Fully Sharded Data Parallel (FSDP) to distribute training across GPUs, though we also provide a simpler Distributed Data Parallel training implementation (for smaller LM backbones and debugging). You can run a pretraining job via torchrun.

As a compact example, here's how you would train a VLM derived from Vicuña-v1.5 7B, using fused DINOv2 + SigLIP representations and a "letterbox padding" transform for non-square images, on a single node with 8 GPUs:

# Run from the root of the repository
torchrun --standalone --nnodes 1 --nproc-per-node 8 scripts/pretrain.py \
  --model.type "one-stage+7b" \
  --model.model_id "<NAME OF NEW MODEL>" \
  --model.vision_backbone_id "dinosiglip-vit-so-384px" \
  --model.image_resize_strategy "letterbox" \
  --model.llm_backbone_id "vicuna-v15-7b" 

Note that specifying model.type is important for identifying the base configuration that you want to build on top of; the full list of model types is available in our config file, under the model_id key for each dataclass.
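
For readers unfamiliar with the "letterbox" resize strategy selected in the command above: letterboxing generally means padding the shorter side of an image until it is square, so a later resize does not distort the aspect ratio. The sketch below computes such padding amounts with plain arithmetic. `letterbox_padding` is a hypothetical helper, and the repository's actual transform (fill color, rounding, subsequent resize) may differ.

```python
from __future__ import annotations


def letterbox_padding(width: int, height: int) -> tuple[int, int, int, int]:
    """Return (left, top, right, bottom) padding that makes a width x height image square."""
    if width <= 0 or height <= 0:
        raise ValueError("width and height must be positive")

    side = max(width, height)
    pad_w, pad_h = side - width, side - height
    left, top = pad_w // 2, pad_h // 2

    # Any odd leftover pixel goes to the right/bottom edge.
    return left, top, pad_w - left, pad_h - top


if __name__ == "__main__":
    print(letterbox_padding(640, 480))  # (0, 80, 0, 80)
```

A 640x480 image therefore gains 80 rows above and below, yielding a 640x640 canvas with the original content centered.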


Repository Structure

High-level overview of repository/project file-tree:

  • prismatic/ - Package source; provides core utilities for model loading, training, data preprocessing, etc.
  • scripts/ - Standalone scripts for preprocessing, training VLMs, and generating from pretrained models.
  • LICENSE - All code is made available under the MIT License; happy hacking!
  • Makefile - Top-level Makefile (by default, supports linting - checking & auto-fix); extend as needed.
  • pyproject.toml - Full project configuration details (including dependencies), as well as tool configurations.
  • README.md - You are here!

Citation

If you find our code or models useful in your work, please cite our paper:

@inproceedings{karamcheti2024prismatic,
  title = {Prismatic VLMs: Investigating the Design Space of Visually-Conditioned Language Models},
  author = {Siddharth Karamcheti and Suraj Nair and Ashwin Balakrishna and Percy Liang and Thomas Kollar and Dorsa Sadigh},
  booktitle = {International Conference on Machine Learning (ICML)},
  year = {2024},
}
