Official PyTorch implementation of the paper:
Halton Scheduler for Masked Generative Image Transformer
Victor Besnier, Mickael Chen, David Hurych, Eduardo Valle, Matthieu Cord
Accepted at ICLR 2025.
TL;DR: We introduce a new sampling strategy using the Halton Scheduler, which spreads tokens uniformly across the image. This approach reduces sampling errors and improves image quality.
Welcome to the official implementation of our ICLR 2025 paper!
This repository introduces Halton Scheduler for Masked Generative Image Transformer (MaskGIT) and includes:
- Class-to-Image Model: Generates high-quality 384x384 images from ImageNet class labels.
- Text-to-Image Model: Generates realistic images from textual descriptions (coming soon).
Explore, train, and extend our easy-to-use generative models!
The v1.0 version, previously known as "MaskGIT-pytorch", is available here!
Halton-MaskGIT/
├── Config/                   <- Base config files for the demo
│   ├── base_cls2img.yaml
│   └── base_txt2img.yaml
├── Dataset/                  <- Data loading utilities
│   ├── dataset.py            <- PyTorch dataset class
│   └── dataloader.py         <- PyTorch dataloader
├── launch/
│   ├── run_cls_to_img.sh     <- Training script for class-to-image
│   └── run_txt_to_img.sh     <- Training script for text-to-image (coming soon)
├── Metrics/
│   ├── extract_train_fid.py  <- Precompute FID stats for ImageNet
│   ├── inception_metrics.py  <- Inception score and FID evaluation
│   └── sample_and_eval.py    <- Sampling and evaluation
├── Network/
│   ├── ema.py                <- EMA model
│   ├── transformer.py        <- Transformer for class-to-image
│   ├── txt_transformer.py    <- Transformer for text-to-image (coming soon)
│   └── vq_model.py           <- VQGAN architecture
├── Sampler/
│   ├── confidence_sampler.py <- Confidence scheduler
│   └── halton_sampler.py     <- Halton scheduler
├── Trainer/                  <- Training classes
│   ├── abstract_trainer.py   <- Abstract trainer
│   ├── cls_trainer.py        <- Class-to-image trainer
│   └── txt_trainer.py        <- Text-to-image trainer (coming soon)
├── statics/                  <- Sample images and assets
├── saved_networks/           <- Placeholder for the downloaded models
├── colab_demo.ipynb          <- Inference demo
├── app.py                    <- Gradio example
├── LICENSE.txt               <- MIT license
├── env.yaml                  <- Environment setup file
├── README.md                 <- This file!
└── main.py                   <- Main script
Get started with just a few steps:
git clone https://github.com/valeoai/Halton-MaskGIT.git
cd Halton-MaskGIT
conda env create -f env.yaml
conda activate maskgit
Then download the pretrained weights from Hugging Face:

from huggingface_hub import hf_hub_download

# The VQGAN (from LlamaGen)
hf_hub_download(repo_id="FoundationVision/LlamaGen",
                filename="vq_ds16_c2i.pt",
                local_dir="./saved_networks/")

# (Optional) The MaskGIT
hf_hub_download(repo_id="llvictorll/Halton-Maskgit",
                filename="ImageNet_384_large.pth",
                local_dir="./saved_networks/")
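After the downloads, both checkpoints should sit in ./saved_networks/. A quick check (the filenames below simply match the downloads above):

import os
print(sorted(os.listdir("./saved_networks/")))
# expect: ['ImageNet_384_large.pth', 'vq_ds16_c2i.pt']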
To speed up training, pre-extract the VQGAN token codes for ImageNet:

python extract_vq_features.py --data_folder="/path/to/ImageNet/" --dest_folder="/your/path/" --bsize=256 --compile
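Under the hood, pre-extraction simply runs the frozen VQGAN encoder once over the dataset and stores the discrete token ids, so training never has to encode pixels again. A minimal sketch of the idea (illustrative only; encode_to_ids is a hypothetical wrapper around the VQGAN encoder, not the repository's API):

import torch

@torch.no_grad()
def extract_codes(encode_to_ids, dataloader, dest_file):
    all_codes, all_labels = [], []
    for images, labels in dataloader:
        # encode_to_ids (hypothetical): [B, 3, H, W] images -> [B, 24, 24] token ids
        all_codes.append(encode_to_ids(images.cuda()).cpu())
        all_labels.append(labels)
    torch.save({"codes": torch.cat(all_codes),
                "labels": torch.cat(all_labels)}, dest_file)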
To train the class-to-image model:
bash launch/run_cls_to_img.sh
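Conceptually, training reduces to masked-token prediction over the 24x24 VQ grid: hide a random subset of token ids and train the transformer to recover them with cross-entropy. A schematic sketch of that objective (illustrative; the transformer call signature here is an assumption, not the repository's exact API):

import torch
import torch.nn.functional as F

def maskgit_step(transformer, codes, labels, mask_token_id):
    # codes: [B, N] VQ token ids (N = 24*24), labels: [B] class ids
    B, N = codes.shape
    r = torch.rand(B, 1, device=codes.device)          # per-sample masking ratio
    mask = torch.rand(B, N, device=codes.device) < r   # tokens to hide
    inputs = codes.masked_fill(mask, mask_token_id)    # swap in the [MASK] id
    logits = transformer(inputs, labels)               # assumed: [B, N, vocab] logits
    return F.cross_entropy(logits[mask], codes[mask])  # loss on masked positions only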
To quickly verify that the model runs end to end, try this Python code:
import torch
from Utils.utils import load_args_from_file
from Utils.viz import show_images_grid
from huggingface_hub import hf_hub_download
from Trainer.cls_trainer import MaskGIT
from Sampler.halton_sampler import HaltonSampler
config_path = "Config/base_cls2img.yaml" # Path to your config file
args = load_args_from_file(config_path)
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Download the VQGAN from LlamaGen
hf_hub_download(repo_id="FoundationVision/LlamaGen",
                filename="vq_ds16_c2i.pt",
                local_dir="./saved_networks/")

# Download the MaskGIT
hf_hub_download(repo_id="llvictorll/Halton-Maskgit",
                filename="ImageNet_384_large.pth",
                local_dir="./saved_networks/")

# Initialize the model
model = MaskGIT(args)

# Select the scheduler
sampler = HaltonSampler(sm_temp_min=1, sm_temp_max=1.2, temp_pow=1, temp_warmup=0, w=2,
                        sched_pow=2, step=32, randomize=True, top_k=-1)
# [goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner]
labels = [1, 7, 282, 604, 724, 179, 751, 404]
gen_images = sampler(trainer=model, nb_sample=8, labels=labels, verbose=True)[0]
show_images_grid(gen_images)
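For intuition about what HaltonSampler does: it visits the token positions in the order of a 2D low-discrepancy Halton sequence (bases 2 and 3), so the tokens revealed at each step are spread uniformly over the image instead of clustering. A minimal sketch of that ordering in plain Python (illustrative, not the repository implementation):

def radical_inverse(base, i):
    # Van der Corput radical inverse of i in the given base, in [0, 1)
    f, x = 1.0, 0.0
    while i > 0:
        f /= base
        x += f * (i % base)
        i //= base
    return x

def halton_token_order(grid=24):
    # Visit grid cells in 2D Halton order; keep only the first visit to each cell
    order, seen, i = [], set(), 1
    while len(order) < grid * grid:
        pos = (int(radical_inverse(2, i) * grid),
               int(radical_inverse(3, i) * grid))
        if pos not in seen:
            seen.add(pos)
            order.append(pos)
        i += 1
    return order

print(halton_token_order()[:8])  # early positions land far apart from each other

Because this order is fixed and uniform, newly revealed tokens are never clustered next to each other, which is where the reduction in sampling error comes from.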
Or run the Gradio demo with python app.py and open http://127.0.0.1:6006 in your browser.
Want to try the model but don't have a GPU? Check out the Colab notebook for an easy-to-run demo!
The pretrained MaskGIT models are available on Hugging Face. Use them to jump straight into inference or fine-tuning.
| Model | # Params | Input size | GFLOPs | VQGAN | MaskGIT |
|---|---|---|---|---|---|
| Halton-MaskGIT-Large | 480M | 24x24 tokens | 83.0 | Download | Download |
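As a sanity check on the MaskGIT download, you can roughly count the checkpoint's parameters (a sketch assuming the .pth file is, or wraps, a plain state dict; the "model" key below is a guess for unwrapping nested checkpoints):

import torch

state = torch.load("./saved_networks/ImageNet_384_large.pth", map_location="cpu")
if isinstance(state, dict) and "model" in state:  # unwrap if nested (assumption)
    state = state["model"]
n_params = sum(t.numel() for t in state.values() if torch.is_tensor(t))
print(f"~{n_params / 1e6:.0f}M parameters")  # the table above says ~480M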
We welcome contributions and feedback! If you encounter any issues, have suggestions, or want to collaborate, feel free to:
- Create an issue
- Fork the repository and submit a pull request
Your input is highly valued. Let's make this project even better together!
This project is licensed under the MIT License. See the LICENSE file for details.
We are grateful for the support of the IT4I Karolina Cluster in the Czech Republic for powering our experiments.
The pretrained ImageNet VQGAN (f=16/8, 16384-entry codebook) is from the official LlamaGen repository.
If you find our work useful, please cite us and add a star ⭐ to the repository :)
@inproceedings{besnier2025iclr,
    title={Halton Scheduler for Masked Generative Image Transformer},
    author={Besnier, Victor and Chen, Mickael and Hurych, David and Valle, Eduardo and Cord, Matthieu},
    booktitle={International Conference on Learning Representations (ICLR)},
    year={2025}
}