# Particle Trajectory Representation Learning with Masked Point Modeling (PoLAr-MAE)


*(Figure: PoLAr-MAE architecture overview.)*

## Installation

This codebase relies on a number of dependencies, some of which are difficult to get running. If you're using conda on Linux, use the following to create an environment and install the dependencies:

```bash
conda env create -f environment.yml
conda activate polarmae

# Install pytorch3d
cd extensions
git clone https://github.com/facebookresearch/pytorch3d.git
cd pytorch3d
MAX_JOBS=N pip install -e .

# Install C-NMS
cd ../cnms
MAX_JOBS=N pip install -e .

# Install polarmae
cd ../..  # should be in the root directory of the repository now
pip install -e .
```
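
After the editable installs finish, a quick import check can confirm that the compiled extensions are visible. This is a minimal sketch, assuming the C-NMS extension installs under the module name `cnms`:

```python
# Post-install sanity check (sketch). Assumes the C-NMS extension is importable as `cnms`.
import torch
import pytorch3d
import cnms        # assumption: module name matches the extensions/cnms directory
import polarmae

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("pytorch3d:", pytorch3d.__version__)
```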

> [!NOTE]
> `environment.yml` is a full environment specification, which includes an install of CUDA 12.4, PyTorch 2.1.5, and Python 3.9.
>
> `pytorch3d` and `cnms` are compiled from source, and will only be built for the CUDA architectures of the GPU(s) visible on the system.
>
> Change `N` in `MAX_JOBS=N` to the number of cores you want to use when compiling `pytorch3d` and `cnms`. At least 4 cores are recommended to compile `pytorch3d` in a reasonable amount of time.

If you'd like to do the installation on your own, you will need to install the dependencies listed in `environment.yml` manually.

`pytorch3d` is usually the most difficult dependency to install. See the pytorch3d INSTALL.md for more details.

There are a couple of extra dependencies that are optional, but recommended:

```bash
conda install wandb jupyter matplotlib
```

## Tutorial Notebooks

Tutorial notebooks for understanding the dataset, model architecture, pretraining, and finetuning are available in the tutorial directory.

## PILArNet-M Dataset

We use and provide the 156 GB PILArNet-M dataset of >1M LArTPC events. See DATASET.md for more details. The dataset is hosted on Google Drive and can be downloaded with the following command:

```bash
gdown --folder 1nec9WYPRqMn-_3m6TdM12TmpoInHDosb -O /path/to/save/dataset
```

> [!NOTE]
> `gdown` must be installed first, e.g. via `pip install gdown` or `conda install gdown`.

## Models

### Pretraining

| Model | Num. Events | Config | SVM $F_1$ | Download |
|---|---|---|---|---|
| Point-MAE | 1M | `pointmae.yml` | 0.719 | here |
| PoLAr-MAE | 1M | `polarmae.yml` | 0.732 | here |

Our evaluation consists of training an ensemble of linear SVMs to classify individual tokens (i.e., groups) as containing one or more classes. This is done via a One vs Rest strategy, where each SVM is trained to classify a single class against all others. $F_1$ is the mean $F_1$ score over all semantic categories in the validation set of the PILArNet-M dataset.
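
As a rough illustration of this protocol (a minimal sketch with placeholder data, not the evaluation code used in this repository), the per-class linear SVMs and the mean $F_1$ could be computed along these lines:

```python
# Sketch of the One-vs-Rest linear SVM probe described above.
# Features and labels are random placeholders standing in for frozen per-token
# embeddings and multi-label targets (which classes each token contains).
import numpy as np
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC
from sklearn.metrics import f1_score

rng = np.random.default_rng(0)
train_feats = rng.normal(size=(1000, 384))           # per-token embeddings (placeholder dimension)
train_labels = rng.integers(0, 2, size=(1000, 5))    # multi-hot class membership per token
val_feats = rng.normal(size=(200, 384))
val_labels = rng.integers(0, 2, size=(200, 5))

clf = OneVsRestClassifier(LinearSVC())               # one linear SVM per semantic class
clf.fit(train_feats, train_labels)

pred = clf.predict(val_feats)
mean_f1 = f1_score(val_labels, pred, average="macro")  # mean F1 over classes
print(f"SVM mean F1: {mean_f1:.3f}")
```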

After installing the dependencies, you can run the following commands in an interactive IPython/Jupyter session (the `!wget` lines shell out to download the checkpoints) to load the pretrained model(s):

```python
>>> from polarmae.models.ssl import PointMAE, PoLArMAE
>>> !wget https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/mae_pretrain.ckpt
>>> !wget https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/polarmae_pretrain.ckpt
>>> model = PointMAE.load_from_checkpoint("mae_pretrain.ckpt")  # or
>>> model = PoLArMAE.load_from_checkpoint("polarmae_pretrain.ckpt")
```

### Semantic Segmentation

| Model | Training Method | Num. Events | Config | $F_1$ | Download |
|---|---|---|---|---|---|
| Point-MAE | Linear probing | 10k | `part_segmentation_mae_peft.yml` | 0.772 | here |
| PoLAr-MAE | Linear probing | 10k | `part_segmentation_polarmae_peft.yml` | 0.798 | here |
| Point-MAE | FFT | 10k | `part_segmentation_mae_fft.yml` | 0.831 | here |
| PoLAr-MAE | FFT | 10k | `part_segmentation_polarmae_fft.yml` | 0.837 | here |

Our evaluation for semantic segmentation consists of 1:1 comparisons between the predicted and ground truth segmentations. $F_1$ is the mean $F_1$ score over all semantic categories in the validation set of the PILArNet-M dataset.
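
As a minimal sketch of the metric only (placeholder arrays, not this repository's evaluation code), the mean $F_1$ over semantic categories can be computed from per-point class predictions as:

```python
# Sketch: mean F1 over semantic classes, comparing predicted vs. ground-truth labels per point.
import numpy as np
from sklearn.metrics import f1_score

rng = np.random.default_rng(0)
gt = rng.integers(0, 5, size=10_000)     # ground-truth semantic class id per point (placeholder)
pred = rng.integers(0, 5, size=10_000)   # predicted class id per point (placeholder)

mean_f1 = f1_score(gt, pred, average="macro")  # unweighted mean of per-class F1 scores
print(f"mean F1 over semantic categories: {mean_f1:.3f}")
```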

After installing the dependencies, you can run the following commands in an interactive IPython/Jupyter session to load the finetuned model(s):

```python
>>> from polarmae.models.finetune import SemanticSegmentation
>>> from polarmae.utils.checkpoint import load_finetune_checkpoint
>>> !wget https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/{mae,polarmae}_{fft,peft}_segsem.ckpt
>>> model = load_finetune_checkpoint(SemanticSegmentation,
...                                  "{mae,polarmae}_{fft,peft}_segsem.ckpt",
...                                  data_path="/path/to/pilarnet-m/dataset",
...                                  pretrained_ckpt_path="{mae,polarmae}_pretrain.ckpt")
```

Here, the braces `{}` denote a choice of model and training method: pick one of `mae` or `polarmae` and one of `fft` or `peft`. Note that you must use the `load_finetune_checkpoint` function to load the model, as it performs extra setup not required during pretraining. The `data_path` argument is required because the number of segmentation classes is determined by the dataset.

## Training

> [!IMPORTANT]
> **Learning rate instructions**
>
> The following commands use the configurations from our experiments. In particular, the learning rates assume an effective batch size of 128 (i.e., 32 per GPU across 4 GPUs). If you train with a different effective batch size, e.g. a single GPU with the per-GPU batch size from the configuration file, you will need to scale the learning rate accordingly. We recommend scaling by the square root of the ratio of batch sizes: if your effective batch size is $b$, set the learning rate $l \rightarrow l \times \sqrt{b/128}$.
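
As a concrete instance of this scaling rule (a small sketch; the base learning rate below is a placeholder, not a value from the configs):

```python
import math

base_lr = 1e-3          # learning rate from a config, tuned for an effective batch size of 128 (placeholder value)
base_batch_size = 128   # effective batch size used in our experiments
batch_size = 32         # your effective batch size, e.g. a single GPU with the default per-GPU batch

scaled_lr = base_lr * math.sqrt(batch_size / base_batch_size)
print(f"scaled learning rate: {scaled_lr:.2e}")  # 5.00e-04 here, i.e. l * sqrt(32/128)
```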

### Pretraining

To pretrain Point-MAE, modify the config file to include the path to the PILArNet-M dataset, and run the following command:

```bash
python -m polarmae.tasks.pointmae fit --config configs/pointmae.yml
```

To pretrain PoLAr-MAE, modify the config file to include the path to the PILArNet-M dataset, and run the following command:

```bash
python -m polarmae.tasks.polarmae fit --config configs/polarmae.yml
```

*(Figures: example training plots and SVM evaluation plots.)*

### Semantic Segmentation

To train a semantic segmentation model, modify the config file to include the path to the PILArNet-M dataset, and run the following command:

```bash
python -m polarmae.tasks.part_segmentation fit --config configs/part_segmentation_{mae,polarmae}_{peft,fft}.yml \
                        --model.pretrained_ckpt_path path/to/pretrained/checkpoint.ckpt
```

where {mae,polarmae} is either mae or polarmae, and {peft,fft} is either peft or fft. You can either specify the pretrained checkpoint path in the config, or pass it as an argument to the command like above.

*(Figures: example training plots and SVM evaluation plots.)*

## Acknowledgements

This repository is built upon the lovely Point-MAE and point2vec repositories.

## Citing PoLAr-MAE

If you find this work useful, please consider citing the following paper:

```bibtex
@misc{young2025particletrajectoryrepresentationlearning,
      title={Particle Trajectory Representation Learning with Masked Point Modeling},
      author={Sam Young and Yeon-jae Jwa and Kazuhiro Terao},
      year={2025},
      eprint={2502.02558},
      archivePrefix={arXiv},
      primaryClass={hep-ex},
      url={https://arxiv.org/abs/2502.02558},
}
```
