
The official PyTorch implementation of “BAD: Bidirectional Auto-regressive Diffusion for Text-to-Motion Generation”

If you find our code or paper helpful, please consider starring our repository and citing us:

@article{hosseyni2024bad,
  title={BAD: Bidirectional Auto-regressive Diffusion for Text-to-Motion Generation},
  author={Hosseyni, S Rohollah and Rahmani, Ali Ahmad and Seyedmohammadi, S Jamal and Seyedin, Sanaz and Mohammadi, Arash},
  journal={arXiv preprint arXiv:2409.10847},
  year={2024}
}

News

📢 2024-09-24 --- Initialized the webpage and git project.

Getting Ready

1. Conda Environment

For training and evaluation, we used the following conda environment, which is based on the MMM environment:

conda env create -f environment.yml
conda activate bad
pip install git+https://github.com/openai/CLIP.git

We encountered issues with the environment above during generation and visualization, so we use a second environment for those steps. You may be able to make the first environment work by changing the versions of some packages, particularly numpy. The second environment is based on the MoMask environment, with additional packages such as smplx from the MDM environment.

conda env create -f environment2.yml
conda activate bad2
pip install git+https://github.com/openai/CLIP.git

2. Models and Dependencies

Download Pre-trained Models

bash dataset/prepare/download_models.sh

Download SMPL Files

For rendering.

bash dataset/prepare/download_smpl_files.sh

Download Evaluation Models and GloVe

For evaluation only.

bash dataset/prepare/download_extractor.sh
bash dataset/prepare/download_glove.sh

Troubleshooting

If gdown fails with the error "Cannot retrieve the public link of the file. You may need to change the permission to 'Anyone with the link', or have had many accesses", a potential fix, suggested in wkentaro/gdown#43, is to upgrade gdown:

pip install --upgrade --no-cache-dir gdown

(Optional) Download Manually

Visit [Google Drive] to download the models and evaluators manually.

3. Get Data

HumanML3D - We use two 3D human motion-language datasets: HumanML3D and KIT-ML. For both datasets, you can find the details as well as the download links here.

./dataset/HumanML3D/
├── new_joint_vecs/
├── texts/
├── Mean.npy # same as in [HumanML3D](https://github.com/EricGuo5513/HumanML3D) 
├── Std.npy # same as in [HumanML3D](https://github.com/EricGuo5513/HumanML3D) 
├── train.txt
├── val.txt
├── test.txt
├── train_val.txt
└── all.txt
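
These files are consumed by the data loaders. Purely as a hedged illustration of how Mean.npy and Std.npy relate to the feature vectors in new_joint_vecs/ (the motion file name below is just an example), a clip can be z-normalized and de-normalized like this:

import numpy as np

mean = np.load('./dataset/HumanML3D/Mean.npy')    # per-dimension feature mean
std = np.load('./dataset/HumanML3D/Std.npy')      # per-dimension feature std
motion = np.load('./dataset/HumanML3D/new_joint_vecs/000021.npy')  # (num_frames, feat_dim)

normalized = (motion - mean) / std                # z-normalize before feeding the model
recovered = normalized * std + mean               # invert the normalization after sampling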

KIT-ML - The KIT-ML dataset can be downloaded and extracted with the following scripts:

bash dataset/prepare/download_kit.sh
bash dataset/prepare/extract_kit.sh

If you face any issues, you can refer to this link.

Training

Stage 1: VQ-VAE

python train_vq.py --exp_name 'train_vq' \
                   --dataname t2m \
                   --total_batch_size 256
  • --exp_name: The name of your experiment.
  • --dataname: Dataset name; use t2m for HumanML3D and kit for KIT-ML dataset.
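
Stage 1 trains a VQ-VAE that turns a motion sequence into discrete tokens for Stage 2. The snippet below is only a generic, minimal sketch of the vector-quantization step common to VQ-VAEs (nearest-codebook lookup with a straight-through gradient); it is not taken from train_vq.py, and all names and shapes are illustrative.

import torch

def quantize(z, codebook):
    """Map encoder outputs z of shape (B, T, D) to their nearest entries in a (K, D) codebook."""
    # Squared distance between every latent vector and every code.
    dist = ((z.unsqueeze(-2) - codebook) ** 2).sum(-1)   # (B, T, K)
    indices = dist.argmin(dim=-1)                        # discrete motion token ids, (B, T)
    z_q = codebook[indices]                              # quantized latents, (B, T, D)
    # Straight-through estimator: gradients reach the encoder as if quantization were identity.
    z_q = z + (z_q - z).detach()
    return z_q, indices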

Stage 2: Transformer

python train_t2m_trans.py --exp_name 'train_tr' \
                          --dataname t2m \
                          --time_cond \
                          --z_0_attend_to_all \
                          --unmasked_tokens_not_attend_to_mask_tokens \
                          --total_batch_size 256 \
                          --vq_pretrained_path ./output/vq/vq_last.pth
  • --z_0_attend_to_all: Specifies the causality condition for mask tokens, where each mask token attends to the last T-p+1 mask tokens. If z_0_attend_to_all is not activated, each mask token attends to the first p mask tokens.
  • --time_cond: Uses time as one of the conditions for training the transformer.
  • --unmasked_tokens_not_attend_to_mask_tokens: Prevents unmasked tokens from attending to mask tokens.
  • --vq_pretrained_path: The path to your pretrained VQ-VAE.
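
The two causality flags above control the transformer's attention mask; the exact rule (including the "last T-p+1" versus "first p" windows toggled by --z_0_attend_to_all) is defined in the paper and in train_t2m_trans.py. The sketch below is only a rough, hypothetical illustration of the general idea (an autoregressive-style order among mask tokens, plus optionally blocking unmasked tokens from attending to mask tokens); every name in it is invented here.

import torch

def build_attention_mask(is_masked, unmasked_not_attend_to_masks=True):
    """Illustrative (T, T) boolean mask where entry [q, k] = True means query q may attend to key k."""
    T = is_masked.numel()
    allow = torch.ones(T, T, dtype=torch.bool)

    # Hypothetical ordering among mask tokens: the i-th mask token (in sequence order)
    # may attend to mask tokens 0..i but not to "future" mask tokens.
    mask_positions = torch.nonzero(is_masked, as_tuple=False).flatten()
    for i, q in enumerate(mask_positions):
        allow[q, mask_positions[i + 1:]] = False

    if unmasked_not_attend_to_masks:
        # Block unmasked (already known) tokens from attending to mask tokens.
        blocked = (~is_masked).unsqueeze(1) & is_masked.unsqueeze(0)   # (T, T), query x key
        allow[blocked] = False
    return allow

# Example: positions 1 and 3 of a 5-token sequence hold mask tokens.
print(build_attention_mask(torch.tensor([False, True, False, True, False])))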

Evaluation

For sampling with Order-Agnostic Autoregressive Sampling (OAAS), rand_pos should be set to False. rand_pos=False means that the token with the highest probability is always selected, with no top_p, top_k, or temperature applied (a generic sketch of this distinction follows the command and flag descriptions below). If rand_pos=True, the metrics worsen significantly, whereas in Confidence-Based Sampling (CBS) they improve significantly. We do not know why OAAS performance worsens with random sampling during generation. Maybe this is a bug; we are not sure! We would be extremely grateful if anyone could help fix this issue.

python GPT_eval_multi.py --exp_name "eval" \
                         --sampling_type OAAS \
                         --z_0_attend_to_all \
                         --time_cond  \
                         --unmasked_tokens_not_attend_to_mask_tokens \
                         --num_repeat_inner 1 \
                         --resume_pth ./output/vq/vq_last.pth \
                         --resume_trans ./output/t2m/trans_best_fid.pth
  • --sampling_type: Type of sampling.
  • --num_repeat_inner: To calculate MModality, this should be above 10 (e.g., 20). For the other metrics, 1 is enough.
  • --resume_pth: The path to your pretrained VQ-VAE.
  • --resume_trans: The path to your pretrained transformer.
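
To make the rand_pos distinction concrete, here is a generic, hedged sketch of selecting a token id from a vector of logits; it is not the repository's code, and the function name and defaults are invented.

import torch

def pick_token(logits, rand_pos=False, temperature=1.0, top_k=0):
    """Select one token id from logits of shape (vocab_size,)."""
    if not rand_pos:
        return logits.argmax().item()                # deterministic: highest-probability token
    logits = logits / temperature                    # sharpen or flatten the distribution
    if top_k > 0:
        kth_value = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth_value, float('-inf'))   # keep only the top-k logits
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()                # stochastic draw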

For sampling using Confidence-Based Sampling (CBS), rand_pos=True significantly improves FID compared to CBS with rand_pos=False.

python GPT_eval_multi.py --exp_name "eval" \
                         --z_0_attend_to_all \
                         --time_cond  \
                         --sampling_type CBS \
                         --rand_pos \
                         --unmasked_tokens_not_attend_to_mask_tokens \
                         --num_repeat_inner 1 \
                         --resume_pth ./output/vq/vq_last.pth \
                         --resume_trans ./output/t2m/trans_best_fid.pth
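
Confidence-Based Sampling follows the general MaskGIT-style recipe: at every iteration, tokens are sampled at the still-masked positions, the most confident predictions are kept, and the remaining positions are re-masked for the next iteration. A minimal, hedged sketch of one such iteration (illustrative only, not the code used by GPT_eval_multi.py):

import torch

def cbs_iteration(logits, tokens, is_masked, num_to_keep):
    """One confidence-based unmasking step.

    logits:      (T, vocab) model predictions for every position
    tokens:      (T,) current token ids
    is_masked:   (T,) bool, which positions are still masked
    num_to_keep: how many masked positions to commit this iteration (at most is_masked.sum())
    """
    probs = torch.softmax(logits, dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)     # candidate token per position
    confidence = probs.gather(-1, sampled.unsqueeze(-1)).squeeze(-1)  # probability of each candidate
    confidence = confidence.masked_fill(~is_masked, float('-inf'))    # only compete among masked slots

    keep = torch.topk(confidence, num_to_keep).indices                # most confident masked positions
    tokens, is_masked = tokens.clone(), is_masked.clone()
    tokens[keep] = sampled[keep]                                      # commit those predictions
    is_masked[keep] = False                                           # everything else stays masked
    return tokens, is_masked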

To evaluate the four temporal editing tasks (inpainting, outpainting, prefix prediction, and suffix prediction), use eval_edit.py. We used OAAS for the temporal editing results reported in Table 3 of the paper.

python eval_edit.py --exp_name "eval" \
                    --edit_task inbetween \
                    --z_0_attend_to_all \
                    --time_cond  \
                    --sampling_type OAAS \
                    --unmasked_tokens_not_attend_to_mask_tokens \
                    --num_repeat_inner 1 \
                    --resume_pth ./output/vq/vq_last.pth \
                    --resume_trans ./output/t2m/trans_best_fid.pth
  • --edit_task: Four edit tasks are available: inbetween, outpainting, prefix, and suffix.
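
The four edit tasks differ only in which part of the token sequence is kept and which part is regenerated. The sketch below is an assumption made for illustration (the actual split proportions and conventions, in particular which end is conditioned on for the prefix and suffix tasks, are defined in eval_edit.py):

import numpy as np

def edit_region(edit_task, T, keep):
    """Boolean array of length T: True marks positions to regenerate, False marks conditioning positions."""
    regenerate = np.ones(T, dtype=bool)
    if edit_task == 'inbetween':        # both ends are given, the middle is filled in
        regenerate[:keep] = False
        regenerate[T - keep:] = False
    elif edit_task == 'outpainting':    # the middle is given, both ends are filled in
        mid = T // 2
        regenerate[mid - keep // 2: mid + keep // 2] = False
    elif edit_task == 'prefix':         # assumed: the prefix is given and the remainder is predicted
        regenerate[:keep] = False
    elif edit_task == 'suffix':         # assumed: the suffix is given and the beginning is predicted
        regenerate[T - keep:] = False
    else:
        raise ValueError(f'unknown edit task: {edit_task}')
    return regenerate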

Generation

To generate a motion sequence, run the following:

python generate.py --caption 'a person jauntily skips forward.' \
                   --length 196 \
                   --z_0_attend_to_all \
                   --time_cond  \
                   --sampling_type OAAS \
                   --unmasked_tokens_not_attend_to_mask_tokens \
                   --resume_pth ./output/vq/vq_last.pth   \
                   --resume_trans ./output/t2m/trans_best_fid.pth
  • --length: The length of the motion sequence. If not provided, a length estimator will be used to predict the length of the motion sequence based on the caption.
  • --caption: Text prompt used for generating the motion sequence.

For temporal editing, run the following.

python generate.py --temporal_editing \
                   --caption 'a person jauntily skips forward.' \
                   --caption_inbetween 'a man walks in a clockwise circle and then sits.' \
                   --length 196 \
                   --edit_task inbetween \
                   --z_0_attend_to_all \
                   --time_cond \
                   --sampling_type OAAS \
                   --unmasked_tokens_not_attend_to_mask_tokens \
                   --resume_pth ./output/vq/vq_last.pth   \
                   --resume_trans ./output/t2m/trans_best_fid.pth
  • --caption_inbetween: Text prompt used for generating the inbetween/outpainting/prefix/suffix motion sequence.
  • --edit_task: Four edit tasks are available: inbetween, outpainting, prefix, and suffix.

For long sequence generation, run the following.

python generate.py --long_seq_generation \
                   --long_seq_captions 'a person runs forward and jumps.' 'a person crawls.' 'a person does a cart wheel.' 'a person walks forward up stairs and then climbs down.' 'a person sits on the chair and then steps up.' \
                   --long_seq_lengths 128 196 128 128 128 \
                   --z_0_attend_to_all \
                   --time_cond \
                   --sampling_type OAAS \
                   --unmasked_tokens_not_attend_to_mask_tokens \
                   --resume_pth ./output/vq/vq_last.pth   \
                   --resume_trans ./output/t2m/trans_best_fid.pth
  • --long_seq_generation: Activates long sequence generation.
  • --long_seq_captions: Specifies multiple captions.
  • --long_seq_lengths: Specifies multiple lengths (between 40 and 196) corresponding to each caption.
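
generate.py presumably handles the transitions between consecutive segments itself; purely to illustrate how the captions and lengths pair up and how the per-segment outputs would line up along the time axis, here is a hedged sketch with placeholder arrays:

import numpy as np

captions = ['a person runs forward and jumps.', 'a person crawls.']
lengths = [128, 196]

assert len(captions) == len(lengths)
assert all(40 <= length <= 196 for length in lengths)   # valid per-segment lengths

# Placeholder stand-ins for per-segment joint outputs of shape (length, num_joints, 3);
# 22 joints is the HumanML3D skeleton.
segments = [np.zeros((length, 22, 3)) for length in lengths]
long_motion = np.concatenate(segments, axis=0)          # (sum(lengths), 22, 3)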

Visualization

The above commands save .bvh and .mp4 files in the ./output/visualization/ directory. The .bvh file can be rendered in Blender. Please refer to this link for more information.

To render the motion sequence as an SMPL mesh, pass the .mp4 and .npy files generated by generate.py to visualization/render_mesh.py. The following command creates .obj files that can be easily imported into Blender. This script runs SMPLify and also requires a GPU.

python visualization/render_mesh.py \
  --input_path output/visualization/animation/a_person_jauntily_skips_forwar_196/sample103_repeat0_len196.mp4 \
  --npy_path output/visualization/joints/a_person_jauntily_skips_forwar_196/sample103_repeat0_len196.npy 
  • --input_path: Path to the .mp4 file, created by generate.py.
  • --npy_path: Path to the .npy file, created by generate.py.
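
The .npy file written by generate.py holds the recovered 3D joint positions. A quick way to inspect it (the exact array layout is an assumption; the HumanML3D skeleton has 22 joints):

import numpy as np

joints = np.load('output/visualization/joints/a_person_jauntily_skips_forwar_196/sample103_repeat0_len196.npy')
print(joints.shape)   # expected to be roughly (num_frames, 22, 3): per-frame xyz joint positions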

For rendering .obj files using Blender, you can use the scripts in the visualization/blender_scripts directory. First, open Blender, then go to File -> Import -> Wavefront (.obj), navigate to the directory containing the .obj files, and press A to select and import all of them. Next, copy and paste the script from visualization/blender_scripts/framing_coloring.py into the Scripting tab in Blender, and run the script. Finally, you can render the animation in the Render tab.
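
If you prefer to script the import rather than use File -> Import, the snippet below is a hedged example of batch-importing the .obj files from Blender's Scripting tab. The operator name depends on the Blender version: newer releases provide bpy.ops.wm.obj_import, while older ones use bpy.ops.import_scene.obj.

import glob
import os
import bpy  # available inside Blender's bundled Python

obj_dir = '/path/to/obj/files'   # change to the directory written by render_mesh.py
for path in sorted(glob.glob(os.path.join(obj_dir, '*.obj'))):
    bpy.ops.wm.obj_import(filepath=path)   # older Blender: bpy.ops.import_scene.obj(filepath=path)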

Acknowledgement

We would like to express our sincere gratitude to MMM, MoMask, MDM, and T2M-GPT for their outstanding open-source contributions.
