Halton Scheduler for Masked Generative Image Transformer
Official PyTorch implementation of the paper:
Halton Scheduler for Masked Generative Image Transformer
Victor Besnier, Mickael Chen, David Hurych, Eduardo Valle, Matthieu Cord
Accepted at ICLR 2025.
TL;DR: We introduce a new sampling strategy using the Halton Scheduler, which spreads tokens uniformly across the image. This approach reduces sampling errors and improves image quality.
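For intuition, here is a minimal standalone sketch (not the repository code; the actual scheduler lives in Sampler/halton_sampler.py) of how a 2D Halton sequence in bases 2 and 3 can order the positions of a 24x24 token grid so that even the first few unmasked tokens are spread over the whole image:

def radical_inverse(index, base):
    """Van der Corput radical inverse of `index` in the given base (value in [0, 1))."""
    result, frac = 0.0, 1.0 / base
    while index > 0:
        result += frac * (index % base)
        index //= base
        frac /= base
    return result

def halton_token_order(grid_size=24):
    """Visit every position of a grid_size x grid_size token grid in a low-discrepancy order."""
    order, seen, i = [], set(), 1
    while len(order) < grid_size * grid_size:
        x = int(radical_inverse(i, 2) * grid_size)  # column from the base-2 sequence
        y = int(radical_inverse(i, 3) * grid_size)  # row from the base-3 sequence
        if (x, y) not in seen:                      # drop duplicates caused by rounding to the grid
            seen.add((x, y))
            order.append(y * grid_size + x)         # flat token index
        i += 1
    return order

print(halton_token_order()[:8])  # the first scheduled tokens already cover distant image regions

In the repository, Sampler/halton_sampler.py implements the actual scheduler, and Sampler/confidence_sampler.py provides the standard confidence-based baseline for comparison.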
Overview
Welcome to the official implementation of our ICLR 2025 paper!
This repository introduces Halton Scheduler for Masked Generative Image Transformer (MaskGIT) and includes:
- Class-to-Image Model: Generates high-quality 384x384 images from ImageNet class labels.
- Text-to-Image Model: Generates realistic images from textual descriptions (coming soon).
Explore, train, and extend our easy-to-use generative models!
The v1.0 version, previously known as "MaskGIT-pytorch", is available here!
Repository Structure
Halton-MaskGIT/
├── Config/                      <- Base config files for the demo
│   ├── base_cls2img.yaml
│   └── base_txt2img.yaml
├── Dataset/                     <- Data loading utilities
│   ├── dataset.py               <- PyTorch dataset class
│   └── dataloader.py            <- PyTorch dataloader
├── launch/
│   ├── run_cls_to_img.sh        <- Training script for class-to-image
│   └── run_txt_to_img.sh        <- Training script for text-to-image (coming soon)
├── Metrics/
│   ├── extract_train_fid.py     <- Precompute FID stats for ImageNet
│   ├── inception_metrics.py     <- Inception score and FID evaluation
│   └── sample_and_eval.py       <- Sampling and evaluation
├── Network/
│   ├── ema.py                   <- EMA model
│   ├── transformer.py           <- Transformer for class-to-image
│   ├── txt_transformer.py       <- Transformer for text-to-image (coming soon)
│   └── vq_model.py              <- VQGAN architecture
├── Sampler/
│   ├── confidence_sampler.py    <- Confidence scheduler
│   └── halton_sampler.py        <- Halton scheduler
├── Trainer/                     <- Training classes
│   ├── abstract_trainer.py      <- Abstract trainer
│   ├── cls_trainer.py           <- Class-to-image trainer
│   └── txt_trainer.py           <- Text-to-image trainer (coming soon)
├── statics/                     <- Sample images and assets
├── saved_networks/              <- Placeholder for the downloaded models
├── colab_demo.ipynb             <- Inference demo
├── app.py                       <- Gradio example
├── LICENSE.txt                  <- MIT license
├── env.yaml                     <- Environment setup file
├── README.md                    <- This file!
└── main.py                      <- Main script
Usage
Get started with just a few steps:
1. Clone the repository
git clone https://github.com/valeoai/Halton-MaskGIT.git
cd Halton-MaskGIT
2. Install dependencies
conda env create -f env.yaml
conda activate maskgit
3. Download pretrained models
from huggingface_hub import hf_hub_download
# The VQ-GAN
hf_hub_download(repo_id="FoundationVision/LlamaGen",
filename="vq_ds16_c2i.pt",
local_dir="./saved_networks/")
# (Optional) The MaskGIT
hf_hub_download(repo_id="llvictorll/Halton-Maskgit",
filename="ImageNet_384_large.pth",
local_dir="./saved_networks/")
4. Extract the VQGAN token codes from ImageNet
python extract_vq_features.py --data_folder="/path/to/ImageNet/" --dest_folder="/your/path/" --bsize=256 --compile
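Conceptually, this step runs the frozen VQGAN encoder over your ImageNet folder once and stores the discrete token indices, so training never has to re-encode images. The sketch below is purely illustrative; the real logic, argument handling, and VQGAN API are in extract_vq_features.py and Network/vq_model.py, and encode_to_tokens is a hypothetical placeholder name:

import torch

@torch.no_grad()
def extract_codes(vqgan, dataloader, dest_folder, device="cuda"):
    """Hypothetical sketch: encode each batch once and save the discrete token indices."""
    vqgan.eval().to(device)
    for step, (images, labels) in enumerate(dataloader):
        # encode_to_tokens stands in for the VQGAN's image -> token-index mapping;
        # for 384x384 inputs and f=16 downsampling the codes form a 24x24 grid per image.
        codes = vqgan.encode_to_tokens(images.to(device))
        torch.save({"codes": codes.cpu(), "labels": labels},
                   f"{dest_folder}/codes_{step:06d}.pth")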
5. Train the model
To train the class-to-image model:
bash launch/run_cls_to_img.sh
Quick Start for Sampling
To quickly verify the functionality of our model, you can try this Python code:
import torch
from Utils.utils import load_args_from_file
from Utils.viz import show_images_grid
from huggingface_hub import hf_hub_download
from Trainer.cls_trainer import MaskGIT
from Sampler.halton_sampler import HaltonSampler
config_path = "Config/base_cls2img.yaml" # Path to your config file
args = load_args_from_file(config_path)
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download the VQGAN from LlamaGen
hf_hub_download(repo_id="FoundationVision/LlamaGen",
filename="vq_ds16_c2i.pt",
local_dir="./saved_networks/")
# Download the MaskGIT
hf_hub_download(repo_id="llvictorll/Halton-Maskgit",
filename="ImageNet_384_large.pth",
local_dir="./saved_networks/")
# Initialisation of the model
model = MaskGIT(args)
# select your scheduler
sampler = HaltonSampler(sm_temp_min=1, sm_temp_max=1.2, temp_pow=1, temp_warmup=0, w=2,
                        sched_pow=2, step=32, randomize=True, top_k=-1)
# [goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner]
labels = [1, 7, 282, 604, 724, 179, 751, 404]
gen_images = sampler(trainer=model, nb_sample=8, labels=labels, verbose=True)[0]
show_images_grid(gen_images)
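Optionally, you can also keep the samples on disk with torchvision's save_image (a small addition, assuming gen_images is a batch of image tensors as passed to show_images_grid above):

from torchvision.utils import save_image

# Write the 8 samples as one 4x2 grid; normalize=True rescales values to [0, 1] before saving.
save_image(gen_images, "halton_samples.png", nrow=4, normalize=True)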
Or run the Gradio demo with python app.py and open http://127.0.0.1:6006 in your browser.
Want to try the model but don't have a GPU? Check out the Colab Notebook for an easy-to-run demo!
Pretrained Models
The pretrained MaskGIT models are available on Hugging Face. Use them to jump straight into inference or fine-tuning.
Model | Params | Input (tokens) | GFLOPs | VQGAN | MaskGIT |
---|---|---|---|---|---|
Halton-MaskGIT-Large | 480M | 24x24 | 83.00 | Download | Download |
Contribute
We welcome contributions and feedback! If you encounter any issues, have suggestions, or want to collaborate, feel free to:
- Create an issue
- Fork the repository and submit a pull request
Your input is highly valued. Let's make this project even better together!
License
This project is licensed under the MIT License. See the LICENSE file for details.
Acknowledgments
We are grateful to the IT4I Karolina Cluster in the Czech Republic for powering our experiments.
The pretrained ImageNet VQGAN (f=16/8, 16384 codebook) is from the official LlamaGen repository.
Citation
If you find our work useful, please cite us and add a star to the repository :)
@inproceedings{besnier2025iclr,
title={Halton Scheduler for Masked Generative Image Transformer},
author={Besnier, Victor and Chen, Mickael and Hurych, David and Valle, Eduardo and Cord, Matthieu},
booktitle={International Conference on Learning Representations (ICLR)},
year={2025}
}