
🌟 Halton Scheduler for Masked Generative Image Transformer 🌟


Official PyTorch implementation of the paper:
Halton Scheduler for Masked Generative Image Transformer
Victor Besnier, Mickael Chen, David Hurych, Eduardo Valle, Matthieu Cord
Accepted at ICLR 2025.

TL;DR: We introduce a new sampling strategy based on the Halton scheduler, which spreads the sampled tokens uniformly across the image. This approach reduces sampling errors and improves image quality.
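For intuition, here is a minimal sketch of how a 2D Halton (low-discrepancy) sequence can order the positions of a 24x24 token grid so that consecutively sampled tokens stay spread out. The function names below are ours for illustration only; the repository's actual implementation lives in Sampler/halton_sampler.py and may differ in details:

def halton(index, base):
    # Van der Corput radical inverse of `index` in the given base.
    f, r = 1.0, 0.0
    while index > 0:
        f /= base
        r += f * (index % base)
        index //= base
    return r

def halton_token_order(grid=24):
    # Visit grid cells in 2D Halton order (bases 2 and 3), keeping only the
    # first visit to each cell; the result is a permutation of all token indices.
    order, seen = [], set()
    i = 1
    while len(order) < grid * grid:
        x, y = int(halton(i, 2) * grid), int(halton(i, 3) * grid)
        if (x, y) not in seen:
            seen.add((x, y))
            order.append(y * grid + x)
        i += 1
    return order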


πŸš€ Overview

Welcome to the official implementation of our ICLR 2025 paper! πŸŽ‰

This repository introduces Halton Scheduler for Masked Generative Image Transformer (MaskGIT) and includes:

  1. Class-to-Image Model: Generates high-quality 384x384 images from ImageNet class labels.

Cls2Img

  2. Text-to-Image Model: Generates realistic images from textual descriptions (coming soon).

Txt2Img

Explore, train, and extend our easy-to-use generative models! πŸš€

The v1.0 version, previously known as "MaskGIT-pytorch", is available here!


πŸ“ Repository Structure

Halton-MaskGIT/
β”œβ”€β”€ Config/                                <- Base config files for the demo
β”‚      β”œβ”€β”€ base_cls2img.yaml
β”‚      └── base_txt2img.yaml
β”œβ”€β”€ Dataset/                               <- Data loading utilities
β”‚      β”œβ”€β”€ dataset.py                      <- PyTorch dataset class
β”‚      └── dataloader.py                   <- PyTorch dataloader
β”œβ”€β”€ launch/
β”‚      β”œβ”€β”€ run_cls_to_img.sh               <- Training script for class-to-image
β”‚      └── run_txt_to_img.sh               <- Training script for text-to-image (coming soon)
β”œβ”€β”€ Metrics/
β”‚      β”œβ”€β”€ extract_train_fid.py            <- Precompute FID stats for ImageNet
β”‚      β”œβ”€β”€ inception_metrics.py            <- Inception score and FID evaluation
β”‚      └── sample_and_eval.py              <- Sampling and evaluation
β”œβ”€β”€ Network/
β”‚      β”œβ”€β”€ ema.py                          <- EMA model
β”‚      β”œβ”€β”€ transformer.py                  <- Transformer for class-to-image
β”‚      β”œβ”€β”€ txt_transformer.py              <- Transformer for text-to-image (coming soon)
β”‚      └── vq_model.py                     <- VQGAN architecture
β”œβ”€β”€ Sampler/
β”‚      β”œβ”€β”€ confidence_sampler.py           <- Confidence scheduler
β”‚      └── halton_sampler.py               <- Halton scheduler
β”œβ”€β”€ Trainer/                               <- Training classes
β”‚      β”œβ”€β”€ abstract_trainer.py             <- Abstract trainer
β”‚      β”œβ”€β”€ cls_trainer.py                  <- Class-to-image trainer
β”‚      └── txt_trainer.py                  <- Text-to-image trainer (coming soon)
β”œβ”€β”€ statics/                               <- Sample images and assets
β”œβ”€β”€ saved_networks/                        <- Placeholder for the downloaded models
β”œβ”€β”€ colab_demo.ipynb                       <- Inference demo
β”œβ”€β”€ app.py                                 <- Gradio example
β”œβ”€β”€ LICENSE.txt                            <- MIT license
β”œβ”€β”€ env.yaml                               <- Environment setup file
β”œβ”€β”€ README.md                              <- This file! πŸ“–
└── main.py                                <- Main script

πŸ› οΈ Usage

Get started with just a few steps:

1️⃣ Clone the repository

git clone https://github.com/valeoai/Halton-MaskGIT.git
cd Halton-MaskGIT

2️⃣ Install dependencies

conda env create -f env.yaml
conda activate maskgit

3️⃣ Download pretrained models

from huggingface_hub import hf_hub_download
# The VQ-GAN
hf_hub_download(repo_id="FoundationVision/LlamaGen", 
                filename="vq_ds16_c2i.pt", 
                local_dir="./saved_networks/")

# (Optional) The MaskGIT
hf_hub_download(repo_id="llvictorll/Halton-Maskgit", 
                filename="ImageNet_384_large.pth", 
                local_dir="./saved_networks/")

4️⃣ Extract the VQGAN codes from ImageNet

python extract_vq_features.py --data_folder="/path/to/ImageNet/" --dest_folder="/your/path/" --bsize=256 --compile
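Conceptually, this step encodes every training image into its grid of discrete VQGAN token indices once, so training never has to run the VQGAN encoder again. A minimal sketch of the idea, with a hypothetical vqgan.encode call standing in for the real encoder API used by extract_vq_features.py:

import torch

@torch.no_grad()
def extract_codes(vqgan, dataloader, dest_folder):
    # Encode each batch into discrete code indices and save them to disk.
    vqgan.eval()
    for i, (images, labels) in enumerate(dataloader):
        codes = vqgan.encode(images.cuda())  # hypothetical: -> (B, 24, 24) LongTensor
        torch.save({"codes": codes.cpu(), "labels": labels},
                   f"{dest_folder}/batch_{i}.pth")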

5️⃣ Train the model

To train the class-to-image model:

bash launch/run_cls_to_img.sh

πŸ“Ÿ Quick Start for sampling

To quickly verify the functionality of our model, you can try this Python code:

import torch
from Utils.utils import load_args_from_file
from Utils.viz import show_images_grid
from huggingface_hub import hf_hub_download

from Trainer.cls_trainer import MaskGIT
from Sampler.halton_sampler import HaltonSampler

config_path = "Config/base_cls2img.yaml"        # Path to your config file
args = load_args_from_file(config_path)
args.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the VQGAN from LlamaGen 
hf_hub_download(repo_id="FoundationVision/LlamaGen", 
                filename="vq_ds16_c2i.pt", 
                local_dir="./saved_networks/")

# Download the MaskGIT
hf_hub_download(repo_id="llvictorll/Halton-Maskgit", 
                filename="ImageNet_384_large.pth", 
                local_dir="./saved_networks/")

# Initialise the model
model = MaskGIT(args)

# Select the Halton scheduler (w: classifier-free guidance scale,
# step: number of sampling steps, top_k=-1: no top-k filtering of the logits)
sampler = HaltonSampler(sm_temp_min=1, sm_temp_max=1.2, temp_pow=1, temp_warmup=0, w=2,
                        sched_pow=2, step=32, randomize=True, top_k=-1)

# [goldfish, chicken, tiger cat, hourglass, ship, dog, race car, airliner]
labels = [1, 7, 282, 604, 724, 179, 751, 404] 

gen_images = sampler(trainer=model, nb_sample=8, labels=labels, verbose=True)[0]
show_images_grid(gen_images)
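To keep the samples, something like the following should work, assuming gen_images is a (B, C, H, W) tensor with values in [0, 1] and torchvision is installed:

from torchvision.utils import save_image
save_image(gen_images, "samples.png", nrow=4)  # writes an 8-image grid to disk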

Alternatively, launch the Gradio demo πŸ–ΌοΈ with python app.py and open http://127.0.0.1:6006 in your browser.

🎨 Want to try the model but don't have a GPU? Check out the Colab notebook for an easy-to-run demo!

🧠 Pretrained Models

The pretrained MaskGIT models are available on Hugging Face. Use them to jump straight into inference or fine-tuning.

 
Model                  # Params   Input (tokens)   GFLOPs   VQGAN         MaskGIT
Halton-MaskGIT-Large   480M       24Γ—24            83.00    πŸ”— Download   πŸ”— Download

❀️ Contribute

We welcome contributions and feedback! πŸ› οΈ If you encounter any issues, have suggestions, or want to collaborate, feel free to:

  • Create an issue
  • Fork the repository and submit a pull request

Your input is highly valued. Let’s make this project even better together! πŸ™Œ

πŸ“œ License

This project is licensed under the MIT License. See the LICENSE file for details.

πŸ™ Acknowledgments

We are grateful for the support of the IT4I Karolina Cluster in the Czech Republic for powering our experiments.

The pretrained ImageNet VQGAN (f=16/8, 16384-entry codebook) comes from the official LlamaGen repository.

πŸ“– Citation

If you find our work useful, please cite us and add a star ⭐ to the repository :)

@inproceedings{besnier2025iclr,
  title={Halton Scheduler for Masked Generative Image Transformer},
  author={Victor Besnier and Mickael Chen and David Hurych and Eduardo Valle and Matthieu Cord},
  booktitle={International Conference on Learning Representations (ICLR)},
  year={2025}
}
