PyTorch implementation of the paper "ALICE: Adapt your Learnable Image Compression modEl for variable bitrates", published at VCIP 2024. This repository is based on CompressAI and STF.
When training a Learned Image Compression (LIC) model, the loss function is minimized so that the encoder and the decoder attain a target Rate-Distortion trade-off. Therefore, a distinct model must be trained and stored at the transmitter and receiver for each target rate, fostering the quest for efficient variable bitrate compression schemes. This paper proposes plugging Low-Rank Adapters into a transformer-based pre-trained LIC model and training them to meet different target rates. With our method, encoding an image at a variable rate is as simple as training the corresponding adapters and plugging them into the frozen pre-trained model. Our experiments show performance comparable with state-of-the-art fixed-rate LIC models at a fraction of the training and deployment cost.
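The adapter idea can be illustrated with a generic LoRA layer: the pre-trained weight W stays frozen and only a low-rank update (alpha / r) · B · A is trained. The sketch below assumes the standard LoRA formulation; the class name LoRALinear and the r = 8, alpha = 8 defaults are hypothetical choices for illustration and are not taken from this repository's code.

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical LoRA wrapper: y = W x + b + (alpha / r) * B A x, with W and b frozen."""

    def __init__(self, base: nn.Linear, r: int = 8, alpha: float = 8.0):
        super().__init__()
        self.base = base
        # Freeze the pre-trained weight and bias.
        for p in self.base.parameters():
            p.requires_grad = False
        self.scaling = alpha / r
        # A is small and random, B starts at zero, so the adapter
        # initially leaves the pre-trained output unchanged.
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.scaling * (x @ self.lora_A.T @ self.lora_B.T)


base = nn.Linear(64, 64)
layer = LoRALinear(base, r=8, alpha=8.0)
x = torch.randn(4, 64)
# With B initialized to zero, the adapted layer matches the frozen one.
same = torch.allclose(layer(x), base(x))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(same, trainable)  # True 1024
```

Only 1,024 of the 5,184 parameters in this toy layer are trainable, which is the property that keeps per-rate adapters cheap to train and store.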
- First, download the OpenImages dataset using src/downloader_openimages.py. The script relies on fiftyone, so please install fiftyone first.
- For evaluation and validation during training, also download the Kodak dataset.
- conda env create -f environment.yml
- conda activate alice
- Download our pretrained models into the ALICE directory from here.
- Extract pretrained.zip.
ALICE
│   README.md
│
├───pretrained
│   ├───adapt_0483_seed_42_conf_lora_8_8_opt_adam_sched_cosine_lr_0_0001
│   │       inference.json
│   │       inference_merge.json
│   │       013_checkpoint_best.pth.tar
│   │       025_checkpoint_best.pth.tar
│   │       ...
│   │
│   ├───adapt_0483_seed_42_conf_vanilla_adapt_opt_adam_sched_cosine_lr_0_0001
│   │       inference.json
│   │       013_checkpoint_best.pth.tar
│   │       025_checkpoint_best.pth.tar
│   │       ...
│   │
│   └───stf
│           inference.json
│           stf_013_best.pth.tar
│           stf_025_best.pth.tar
│           ...
│
└───src
        train.py
        ...
- Finally, run the following command:
cd src
python -m evaluate.eval --test-dir /path/to/kodak/ --file-name results_kodak --save-path /path/to/save/results
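Learned codecs are usually compared by rate in bits per pixel (bpp) and distortion as PSNR. As a minimal sketch, assuming images scaled to [0, 1] and a compressed stream whose total size in bytes is known, the two quantities can be computed as below; the function names are hypothetical, and this is not necessarily how src/evaluate/eval.py computes them.

```python
import math

import torch


def psnr(x: torch.Tensor, x_hat: torch.Tensor) -> float:
    """PSNR in dB for images scaled to [0, 1], so the peak value is 1."""
    mse = torch.mean((x - x_hat) ** 2).item()
    return 10 * math.log10(1.0 / mse)  # undefined for a perfect reconstruction (mse == 0)


def bits_per_pixel(num_bytes: int, height: int, width: int) -> float:
    """Rate in bpp: total bits in the compressed stream divided by the pixel count."""
    return num_bytes * 8 / (height * width)


torch.manual_seed(0)
# A 768x512 image (the Kodak resolution) and a slightly perturbed reconstruction.
x = torch.rand(1, 3, 512, 768)
x_hat = (x + 0.01).clamp(0, 1)
print(bits_per_pixel(24576, 512, 768))  # 0.5
print(f"{psnr(x, x_hat):.2f} dB")
```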
To evaluate your own model, follow the same structure proposed above:
- First, create a folder named pretrained/your_model.
- In this folder, collect all of the trained models, making sure that each file name ends with best.pth.tar, for example:
pretrained/your_model/your_model_013_best.pth.tar
pretrained/your_model/your_model_0018_best.pth.tar
pretrained/your_model/your_model_025_best.pth.tar
...
- In the same folder, create a file named pretrained/your_model/inference.json with the following content:
{
"model":"your_model",
"checkpoints_path":"pretrained/your_model"
}
- You can compare multiple models at the same time by creating one folder per model, each containing an inference.json file and the corresponding checkpoints.
- List all of the models that you want to compare in the configs variable in the main function of src/evaluate/eval.py.
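The folder convention above can be automated. The sketch below writes an inference.json with the two keys shown in the example and lists the checkpoints ending in best.pth.tar. It assumes, hypothetically, that the digits right before "_best.pth.tar" (or "_checkpoint_best.pth.tar") encode lambda without its leading "0.", as in 013 → 0.013 and 0018 → 0.0018. Judging from the MALICE entry, configs presumably holds paths to such JSON files, but that is also an assumption.

```python
import json
import re
import tempfile
from pathlib import Path

# Hypothetical pattern: lambda digits (without the leading "0.") right before
# "_best.pth.tar" or "_checkpoint_best.pth.tar", e.g. "013" -> 0.013.
CKPT_PATTERN = re.compile(r"(\d+)_(?:checkpoint_)?best\.pth\.tar$")


def write_inference_config(model_dir: Path, model_name: str, checkpoints_path: str) -> Path:
    """Write inference.json with the "model" and "checkpoints_path" keys."""
    path = model_dir / "inference.json"
    cfg = {"model": model_name, "checkpoints_path": checkpoints_path}
    path.write_text(json.dumps(cfg, indent=1))
    return path


def find_checkpoints(model_dir: Path) -> dict:
    """Map lambda -> checkpoint path for every file ending in best.pth.tar."""
    found = {}
    for ckpt in sorted(model_dir.glob("*best.pth.tar")):
        match = CKPT_PATTERN.search(ckpt.name)
        if match:
            found[float("0." + match.group(1))] = ckpt
    return found


with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "pretrained" / "your_model"
    model_dir.mkdir(parents=True)
    for name in ("your_model_013_best.pth.tar", "your_model_0018_best.pth.tar"):
        (model_dir / name).touch()  # empty placeholder files
    cfg_path = write_inference_config(model_dir, "your_model", "pretrained/your_model")
    config = json.loads(cfg_path.read_text())
    lambdas = sorted(find_checkpoints(model_dir))

print(config, lambdas)
# {'model': 'your_model', 'checkpoints_path': 'pretrained/your_model'} [0.0018, 0.013]
```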
You can evaluate MALICE by including the following configuration in the configs variable defined in the main function of src/evaluate/eval.py:
../pretrained/adapt_0483_seed_42_conf_lora_8_8_opt_adam_sched_cosine_lr_0_0001/inference_merge.json
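The file name inference_merge.json suggests that MALICE evaluates the adapters merged into the frozen weights. Assuming the usual LoRA merge W' = W + (alpha / r) · B · A, a merged layer costs no extra computation at inference. The sketch below, with hypothetical names, checks that a merged nn.Linear reproduces the output of the base layer plus its adapter; it is not taken from this repository.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def merge_lora(base: nn.Linear, lora_A: torch.Tensor, lora_B: torch.Tensor,
               scaling: float) -> nn.Linear:
    """Return a plain nn.Linear whose weight is W + scaling * B @ A."""
    merged = nn.Linear(base.in_features, base.out_features, bias=base.bias is not None)
    merged.weight.copy_(base.weight + scaling * lora_B @ lora_A)
    if base.bias is not None:
        merged.bias.copy_(base.bias)
    return merged


torch.manual_seed(0)
base = nn.Linear(32, 16)
r, alpha = 8, 8.0
lora_A = torch.randn(r, 32) * 0.1  # stands in for trained adapter weights
lora_B = torch.randn(16, r) * 0.1
scaling = alpha / r

x = torch.randn(5, 32)
with torch.no_grad():
    adapted = base(x) + scaling * (x @ lora_A.T @ lora_B.T)  # base + adapter
    merged_out = merge_lora(base, lora_A, lora_B, scaling)(x)  # single merged layer
print(torch.allclose(adapted, merged_out, atol=1e-4))  # True
```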
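The training script is provided in src/train.py. To adapt the pre-trained STF model (lambda = 0.0483) with LoRA adapters to a new target lambda (here 0.013), run:
cd src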
python train.py --batch-size=16 --checkpoint=../pretrained/stf/stf_0483_best.pth.tar --cuda=1 --dataset=../../../data/openimages/ --epochs=15 --lambda=0.013 --learning-rate=0.0001 --lora=1 --lora-config=../configs/lora_8_8.yml --lora-opt=adam --lora-sched=cosine --model=stf --save=1 --save-dir=../results/adapt_models_lora/adapt_0483 --test-dir=../../../data/kodak/
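The flags --lora-opt=adam, --lora-sched=cosine, --learning-rate=0.0001 and --epochs=15 suggest an optimizer built only over the adapter parameters of an otherwise frozen model. The following is a minimal sketch of that pattern in plain PyTorch; the toy model, the assumption that adapter parameters can be recognized by a "lora_" name prefix, the placeholder loss and the once-per-epoch scheduler step are all illustrative assumptions, not the behavior of src/train.py.

```python
import torch
import torch.nn as nn


class ToyAdapted(nn.Module):
    """Hypothetical stand-in for a pre-trained model with one LoRA adapter."""

    def __init__(self, dim: int = 16, r: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)  # "pre-trained" layer, to be frozen
        self.lora_A = nn.Parameter(torch.randn(r, dim) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(dim, r))

    def forward(self, x):
        return self.proj(x) + x @ self.lora_A.T @ self.lora_B.T


model = ToyAdapted()
# Freeze everything except parameters whose name starts with "lora_".
for name, p in model.named_parameters():
    p.requires_grad = name.startswith("lora_")

adapter_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(adapter_params, lr=1e-4)
epochs = 15
# Cosine decay from 1e-4 down to 0 over `epochs` scheduler steps.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

frozen_before = model.proj.weight.clone()
for epoch in range(epochs):
    x = torch.randn(8, 16)
    loss = (model(x) - x).pow(2).mean()  # placeholder loss, one batch per "epoch"
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

print(torch.equal(frozen_before, model.proj.weight))  # True: base weights untouched
```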
To train vanilla adapters instead (--vanilla-adapt=1), run:
cd src
python train.py --batch-size=16 --checkpoint=../pretrained/stf/stf_0483_best.pth.tar --cuda=1 --dataset=../../../data/openimages/ --epochs=15 --lambda=0.013 --learning-rate=0.0001 --lora=1 --lora-opt=adam --lora-sched=cosine --model=stf --save=1 --save-dir=../results/adapt_models_vanilla/adapt_0483 --test-dir=../../../data/kodak/ --vanilla-adapt=1
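Both commands set the target trade-off with --lambda. In CompressAI, on which this repository is based, the rate-distortion loss is L = λ · 255² · MSE + bpp, where bpp is estimated from the entropy model's likelihoods as the sum of -log2(likelihood) over all latents divided by the number of pixels. The sketch below reimplements that formula for illustration; whether src/train.py uses exactly this weighting is an assumption.

```python
import math

import torch


def rate_distortion_loss(x, x_hat, likelihoods, lmbda):
    """CompressAI-style loss: lmbda * 255**2 * MSE + estimated bits per pixel."""
    n, _, h, w = x.shape
    num_pixels = n * h * w
    # Estimated rate: -log2 of the likelihoods of all latent elements, per pixel.
    bpp = sum(torch.log(l).sum() for l in likelihoods) / (-math.log(2) * num_pixels)
    mse = torch.mean((x - x_hat) ** 2)
    return lmbda * 255 ** 2 * mse + bpp, bpp, mse


x = torch.rand(2, 3, 64, 64)
x_hat = x.clone()  # perfect reconstruction, so the loss reduces to the rate
# Each of the 2*192*4*4 = 6144 latent elements has likelihood 0.5, i.e. costs 1 bit.
likelihoods = [torch.full((2, 192, 4, 4), 0.5)]
loss, bpp, mse = rate_distortion_loss(x, x_hat, likelihoods, lmbda=0.013)
print(round(bpp.item(), 4))  # 0.75
```

A larger lambda weights distortion more heavily, so each lambda in the pretrained-model table corresponds to a different operating point on the rate-distortion curve.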
Sweep files to adapt the model for all lambda values are saved in:
- Rate-Distortion
- Complexity Considerations
Pretrained models (optimized for MSE), trained from scratch on 300k images randomly chosen from the OpenImages dataset.
Method | Lambda | Link |
---|---|---|
STF | 0.0018 | stf_0018 |
STF | 0.0035 | stf_0035 |
STF | 0.0067 | stf_0067 |
STF | 0.013 | stf_013 |
STF | 0.025 | stf_025 |
STF | 0.0483 | stf_0483 |
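To adapt the 0.0483 base model to every other lambda in the table, one could generate one training command per target rate. The sketch below only builds the command lines by reusing the LoRA flags from the training example; the helper name lora_train_command, the choice to skip the base lambda, and reusing a single save directory are hypothetical, and the repository's own sweep files may be organized differently.

```python
import shlex

# Lambda values of the pretrained STF models listed in the table above.
LAMBDAS = [0.0018, 0.0035, 0.0067, 0.013, 0.025, 0.0483]
BASE_LAMBDA = 0.0483  # the adapters start from stf_0483


def lora_train_command(target_lambda: float) -> list:
    """Build one hypothetical sweep entry reusing the LoRA training flags."""
    return [
        "python", "train.py",
        "--batch-size=16",
        "--checkpoint=../pretrained/stf/stf_0483_best.pth.tar",
        "--cuda=1",
        "--dataset=../../../data/openimages/",
        "--epochs=15",
        f"--lambda={target_lambda}",
        "--learning-rate=0.0001",
        "--lora=1",
        "--lora-config=../configs/lora_8_8.yml",
        "--lora-opt=adam",
        "--lora-sched=cosine",
        "--model=stf",
        "--save=1",
        "--save-dir=../results/adapt_models_lora/adapt_0483",
        "--test-dir=../../../data/kodak/",
    ]


# Skip the base lambda: the pre-trained model already targets it.
commands = [lora_train_command(lam) for lam in LAMBDAS if lam != BASE_LAMBDA]
for cmd in commands:
    print(shlex.join(cmd))  # run from inside src/
```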
- STF: https://github.com/Googolxx/STF
- CompressAI: https://github.com/InterDigitalInc/CompressAI
- Swin-Transformer: https://github.com/microsoft/Swin-Transformer
- TensorFlow compression library by Ballé et al.: https://github.com/tensorflow/compression
- Range Asymmetric Numeral System code from Fabian 'ryg' Giesen: https://github.com/rygorous/ryg_rans
- Kodak Images Dataset: http://r0k.us/graphics/kodak/
- Open Images Dataset: https://github.com/openimages
- fiftyone: https://github.com/voxel51/fiftyone
- CLIC: https://www.compression.cc/