This repository provides the code and model checkpoints of the research paper: Scalable Pre-training of Large Autoregressive Image Models
Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar, Joshua M Susskind, and Armand Joulin
To appear at ICML 2024
We introduce AIM, a collection of vision models pre-trained with an autoregressive generative objective. We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its textual counterpart (i.e., Large Language Models). Specifically, we highlight two findings:
Please install PyTorch using the official installation instructions. Afterward, install the package as:
pip install git+https://git@github.com/apple/ml-aim.git
We also offer MLX backend support for research and experimentation on Apple silicon. To enable MLX support, simply run:
pip install mlx
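Before loading a model, it can help to confirm which backends are importable in the current environment. The snippet below is a minimal sketch that uses only the Python standard library. The helper name `available_backends` is hypothetical and not part of the `aim` package; `torch`, `mlx` and `jax` are the top-level module names of the three backends used in this README.

```python
import importlib.util
from typing import Dict

# Top-level module names of the backends referenced in this README.
BACKEND_MODULES = {"torch": "torch", "mlx": "mlx", "jax": "jax"}


def available_backends() -> Dict[str, bool]:
    """Map each backend name to True if its module can be found, without importing it."""
    return {
        name: importlib.util.find_spec(module) is not None
        for name, module in BACKEND_MODULES.items()
    }


for name, found in available_backends().items():
    print(f"{name}: {'installed' if found else 'missing'}")
```

A backend reported as missing only needs to be installed if you plan to pass it as the `backend` argument of `load_pretrained`.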
Below we provide an example of usage in PyTorch:
from PIL import Image
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
logits, features = model(inp)
The same example with the MLX backend:
from PIL import Image
import mlx.core as mx
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = mx.array(inp.numpy())
logits, features = model(inp)
The same example with the JAX backend:
from PIL import Image
import jax.numpy as jnp
from aim.utils import load_pretrained
from aim.torch.data import val_transforms
img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()
inp = transform(img).unsqueeze(0)
inp = jnp.array(inp)
(logits, features), _ = model.apply(params, inp, mutable=['batch_stats'])
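Each usage example returns a `(logits, features)` pair. The sketch below shows one way to turn a single row of logits into top-k class probabilities without depending on any particular backend. It assumes that `logits` has shape `(batch, num_classes)` and holds unnormalized class scores, which this README does not state explicitly; the helper name `top_k_probs` is hypothetical.

```python
import math
from typing import List, Sequence, Tuple


def top_k_probs(logits_row: Sequence[float], k: int = 5) -> List[Tuple[int, float]]:
    """Return the k (class_index, probability) pairs with the highest softmax probability."""
    # Subtract the maximum before exponentiating for numerical stability.
    shift = max(logits_row)
    exps = [math.exp(x - shift) for x in logits_row]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Sort class indices by descending probability and keep the first k.
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]
```

With the PyTorch example this could be called as `top_k_probs(logits[0].tolist())`. MLX and JAX arrays also expose `tolist()`, so the same call should work for the other two backends under the shape assumption above.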
The pre-trained models can be accessed via PyTorch Hub as:
import torch
aim_600m = torch.hub.load("apple/ml-aim", "aim_600M")
aim_1b = torch.hub.load("apple/ml-aim", "aim_1B")
aim_3b = torch.hub.load("apple/ml-aim", "aim_3B")
aim_7b = torch.hub.load("apple/ml-aim", "aim_7B")
or via HuggingFace Hub as:
from aim.torch.models import AIMForImageClassification
aim_600m = AIMForImageClassification.from_pretrained("apple/aim-600M")
aim_1b = AIMForImageClassification.from_pretrained("apple/aim-1B")
aim_3b = AIMForImageClassification.from_pretrained("apple/aim-3B")
aim_7b = AIMForImageClassification.from_pretrained("apple/aim-7B")
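Loading all four checkpoints at once is rarely necessary. The sketch below wraps the two access paths shown above in a single function that loads one model by size. The function `load_aim` and its `source` argument are hypothetical conveniences, not part of the `aim` package; the entry-point and repository names are copied from the snippets above.

```python
# Names copied from the PyTorch Hub and HuggingFace Hub snippets above.
HUB_ENTRYPOINTS = {"600M": "aim_600M", "1B": "aim_1B", "3B": "aim_3B", "7B": "aim_7B"}
HF_REPOS = {"600M": "apple/aim-600M", "1B": "apple/aim-1B", "3B": "apple/aim-3B", "7B": "apple/aim-7B"}


def load_aim(size: str, source: str = "torch_hub"):
    """Load a single AIM model by size from PyTorch Hub or HuggingFace Hub."""
    if size not in HUB_ENTRYPOINTS:
        raise ValueError(f"unknown size {size!r}, expected one of {sorted(HUB_ENTRYPOINTS)}")
    if source == "torch_hub":
        import torch  # imported lazily so the mapping can be inspected without torch

        return torch.hub.load("apple/ml-aim", HUB_ENTRYPOINTS[size])
    if source == "hf":
        from aim.torch.models import AIMForImageClassification

        return AIMForImageClassification.from_pretrained(HF_REPOS[size])
    raise ValueError(f"unknown source {source!r}, expected 'torch_hub' or 'hf'")
```

For example, `load_aim("600M")` fetches only the smallest model through PyTorch Hub, and `load_aim("600M", source="hf")` does the same through HuggingFace Hub.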
The following table contains pre-trained backbones used in our paper.
The table below contains the classification results on the ImageNet-1k validation set.
The commands below reproduce the attention probe results on the ImageNet-1k validation set. We run the evaluation on 1 node with 8 GPUs:
torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
--model=aim-7B \
--batch-size=64 \
--data-path=/path/to/imagenet \
--probe-layers=best \
--backbone-ckpt-path=/path/to/backbone_ckpt.pth \
--head-ckpt-path=/path/to/head_ckpt.pth
By default, we probe features from the intermediate 6 layers that provide the best performance. To change this, simply pass --probe-layers=last.
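When running the evaluation for several models or probe settings, it may be convenient to assemble the command programmatically. The sketch below builds the same `torchrun` invocation as an argument list, using only the flags shown above. The function name `attnprobe_command` and its parameter names are hypothetical, and the paths are the placeholders from the command above.

```python
import shlex
from typing import List


def attnprobe_command(
    model: str,
    data_path: str,
    backbone_ckpt: str,
    head_ckpt: str,
    probe_layers: str = "best",
    batch_size: int = 64,
    gpus: int = 8,
) -> List[str]:
    """Build the single-node torchrun command for main_attnprobe.py as an argv list."""
    return [
        "torchrun", "--standalone", "--nnodes=1", f"--nproc-per-node={gpus}",
        "main_attnprobe.py",
        f"--model={model}",
        f"--batch-size={batch_size}",
        f"--data-path={data_path}",
        f"--probe-layers={probe_layers}",
        f"--backbone-ckpt-path={backbone_ckpt}",
        f"--head-ckpt-path={head_ckpt}",
    ]


cmd = attnprobe_command(
    "aim-7B",
    "/path/to/imagenet",
    "/path/to/backbone_ckpt.pth",
    "/path/to/head_ckpt.pth",
    probe_layers="last",
)
print(shlex.join(cmd))  # shell-quoted form, suitable for copy-pasting
```

The list can be passed directly to `subprocess.run(cmd, check=True)` from the repository root, which avoids shell quoting issues with paths that contain spaces.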
If you find our work useful, please consider citing us as:
@article{el2024scalable,
title={Scalable Pre-training of Large Autoregressive Image Models},
author={El-Nouby, Alaaeldin and Klein, Michal and Zhai, Shuangfei and Bautista, Miguel Angel and Toshev, Alexander and Shankar, Vaishaal and Susskind, Joshua M and Joulin, Armand},
journal={International Conference on Machine Learning},
year={2024}
}