This codebase relies on a number of dependencies, some of which are difficult to get running. If you're using conda on Linux, use the following to create an environment and install the dependencies:
```bash
conda env create -f environment.yml
conda activate polarmae

# Install pytorch3d
cd extensions
git clone https://github.com/facebookresearch/pytorch3d.git
cd pytorch3d
MAX_JOBS=N pip install -e .

# Install C-NMS
cd ../cnms
MAX_JOBS=N pip install -e .

# Install polarmae
cd ../..  # should be in the root directory of the repository now
pip install -e .
```
Note: `environment.yml` is a full environment specification, which includes CUDA 12.4, PyTorch 2.1.5, and Python 3.9. `pytorch3d` and `cnms` are compiled from source and will only be compiled for the CUDA architecture(s) of the GPU(s) visible on the system. Change `N` in `MAX_JOBS=N` to the number of cores you want to use when building `pytorch3d` and `cnms`. At least 4 cores are recommended to compile `pytorch3d` in a reasonable amount of time.
If you'd like to install the dependencies yourself, you will need:
- CUDA
- PyTorch 2.1.0, 2.1.1, 2.1.2, 2.2.0, 2.2.1, 2.2.2, 2.3.0, 2.3.1, 2.4.0, or 2.4.1
- gcc & g++ >= 4.9 and < 13
- pytorch3d
- PyTorch Lightning
- NumPy
- Lightning-Utilities
- scikit-learn
- OmegaConf
- h5py
`pytorch3d` is usually the most difficult dependency to install. See the `pytorch3d` INSTALL.md for more details.
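The version constraints above can be checked before starting a long source build. The sketch below is a hypothetical helper, not part of this repository: it compares a PyTorch version string against the supported list and a gcc version string against the `>= 4.9, < 13` range. The function names are illustrative.

```python
# Hypothetical pre-build checks; names are illustrative, not part of polarmae.
SUPPORTED_TORCH_VERSIONS = {
    "2.1.0", "2.1.1", "2.1.2",
    "2.2.0", "2.2.1", "2.2.2",
    "2.3.0", "2.3.1",
    "2.4.0", "2.4.1",
}


def torch_version_supported(version: str) -> bool:
    """Compare a version string such as '2.4.1+cu124' against the supported list."""
    release = version.split("+", 1)[0]  # drop the local build tag, e.g. '+cu124'
    return release in SUPPORTED_TORCH_VERSIONS


def gcc_version_supported(version: str) -> bool:
    """Check 4.9 <= gcc < 13 for strings like '12.3.0' or a bare major like '11'."""
    parts = version.strip().split(".")
    major = int(parts[0])
    minor = int(parts[1]) if len(parts) > 1 else 0
    return (4, 9) <= (major, minor) < (13, 0)
```

In practice you would pass `torch.__version__` to `torch_version_supported` and the output of `gcc -dumpversion` to `gcc_version_supported`. Some distributions print only the major version there (e.g. `11`), which is why a missing minor component is treated as `0`.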
There are a couple of extra dependencies that are optional but recommended:

```bash
conda install wandb jupyter matplotlib
```
Tutorial notebooks for understanding the dataset, model architecture, pretraining, and finetuning are available in the `tutorial` directory.
We use and provide the 156 GB PILArNet-M dataset of >1M LArTPC events. See DATASET.md for more details. The dataset is available at this link, or it can be downloaded with the following command:

```bash
gdown --folder 1nec9WYPRqMn-_3m6TdM12TmpoInHDosb -O /path/to/save/dataset
```
Note: `gdown` must be installed first, e.g. via `pip install gdown` or `conda install gdown`.
Model | Num. Events | Config | SVM | Download |
---|---|---|---|---|
Point-MAE | 1M | pointmae.yml | 0.719 | here |
PoLAr-MAE | 1M | polarmae.yml | 0.732 | here |
Our evaluation consists of training an ensemble of linear SVMs to classify individual tokens (i.e., groups) as containing one or more classes. This is done with a one-vs-rest strategy, in which each SVM is trained to distinguish a single class from all others.
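A minimal sketch of this one-vs-rest setup with scikit-learn, assuming per-token embeddings have already been extracted from a frozen encoder. The embeddings and labels below are synthetic placeholders, `LinearSVC` with `C=1.0` is an assumed choice, and macro F1 is used only for illustration. This is not the repository's evaluation code, and the metric behind the SVM column may be defined differently.

```python
import numpy as np
from sklearn.metrics import f1_score
from sklearn.multiclass import OneVsRestClassifier
from sklearn.svm import LinearSVC

rng = np.random.default_rng(0)
n_tokens, dim, n_classes = 2000, 32, 4

# Synthetic stand-ins: X holds one embedding per token (group);
# Y[i, c] = 1 if token i contains at least one point of class c (multi-label).
X = rng.normal(size=(n_tokens, dim))
W = rng.normal(size=(dim, n_classes))
Y = (X @ W > 0.5).astype(int)

split = int(0.8 * n_tokens)
X_train, X_test = X[:split], X[split:]
Y_train, Y_test = Y[:split], Y[split:]

# One binary linear SVM per class, each trained as "class c" vs. "all other classes".
clf = OneVsRestClassifier(LinearSVC(C=1.0, max_iter=10_000))
clf.fit(X_train, Y_train)

Y_pred = clf.predict(X_test)  # shape (n_test_tokens, n_classes), entries 0/1
score = f1_score(Y_test, Y_pred, average="macro")
print(f"macro F1 over {n_classes} classes: {score:.3f}")
```

`OneVsRestClassifier` treats a 2D 0/1 label matrix as a multi-label problem, so a token can be assigned several classes at once, matching the "one or more classes" framing above. With real embeddings, standardizing the features (e.g. with `sklearn.preprocessing.StandardScaler`) usually helps linear SVMs converge.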
After installing the dependencies, download the pretrained checkpoints and load them in a Python session:

```bash
wget https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/mae_pretrain.ckpt
wget https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/polarmae_pretrain.ckpt
```

```python
from polarmae.models.ssl import PointMAE, PoLArMAE

model = PointMAE.load_from_checkpoint("mae_pretrain.ckpt")  # or
model = PoLArMAE.load_from_checkpoint("polarmae_pretrain.ckpt")
```
Model | Training Method | Num. Events | Config | Score | Download |
---|---|---|---|---|---|
Point-MAE | Linear probing | 10k | part_segmentation_mae_peft.yml | 0.772 | here |
PoLAr-MAE | Linear probing | 10k | part_segmentation_polarmae_peft.yml | 0.798 | here |
Point-MAE | FFT | 10k | part_segmentation_mae_fft.yml | 0.831 | here |
PoLAr-MAE | FFT | 10k | part_segmentation_polarmae_fft.yml | 0.837 | here |
Our evaluation for semantic segmentation consists of point-by-point (1:1) comparisons between the predicted and ground-truth segmentation labels.
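One way to read "1:1 comparisons" is as a point-by-point check of predicted versus true class labels, summarized per class. The sketch below assumes one integer class label per point and computes per-class F1 and overall accuracy with NumPy. The helper name is hypothetical, and the score in the table may be computed differently.

```python
import numpy as np


def per_class_f1(y_true, y_pred, num_classes):
    """Per-class F1 from a point-by-point comparison of integer labels (hypothetical helper)."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if y_true.shape != y_pred.shape:
        raise ValueError("prediction and ground truth must label the same points")
    scores = np.zeros(num_classes)
    for c in range(num_classes):
        tp = np.sum((y_pred == c) & (y_true == c))  # class c predicted and correct
        fp = np.sum((y_pred == c) & (y_true != c))  # class c predicted, but wrong
        fn = np.sum((y_pred != c) & (y_true == c))  # class c missed
        denom = 2 * tp + fp + fn
        scores[c] = 2 * tp / denom if denom > 0 else 0.0
    return scores


# Toy event with 6 points and 3 classes (labels are illustrative only).
y_true = np.array([0, 0, 1, 1, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 0])
f1 = per_class_f1(y_true, y_pred, num_classes=3)  # approximately [0.5, 0.8, 0.667]
macro_f1 = f1.mean()
accuracy = np.mean(y_true == y_pred)  # fraction of points labeled correctly: 4/6
```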
After installing the dependencies, download the fine-tuned checkpoints and load them in a Python session:

```bash
# Bash brace expansion downloads all four fine-tuned checkpoints
wget https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights/{mae,polarmae}_{fft,peft}_segsem.ckpt
```

```python
from polarmae.models.finetune import SemanticSegmentation
from polarmae.utils.checkpoint import load_finetune_checkpoint

# Replace each {a,b} placeholder with a single choice, e.g. "polarmae_fft_segsem.ckpt"
model = load_finetune_checkpoint(
    SemanticSegmentation,
    "{mae,polarmae}_{fft,peft}_segsem.ckpt",
    data_path="/path/to/pilarnet-m/dataset",
    pretrained_ckpt_path="{mae,polarmae}_pretrain.ckpt",
)
```
Here, the braces `{}` denote the model and the training method: choose one option from `{mae,polarmae}` and one from `{fft,peft}`. Note that you must use the `load_finetune_checkpoint` function to load the model, as it performs some extra setup that is not required for the pretraining phase. The `data_path` is required because the number of segmentation classes is determined by the dataset.
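Because the placeholders combine a model with a training method, there are four fine-tuned checkpoints. The hypothetical helper below enumerates them together with a matching pretrained checkpoint, assuming (as the shared `{mae,polarmae}` placeholder suggests) that each fine-tuned checkpoint goes with the pretrained checkpoint of the same model.

```python
from itertools import product

BASE_URL = "https://github.com/DeepLearnPhysics/PoLAr-MAE/releases/download/weights"


def finetune_checkpoint_pairs():
    """Return (fine-tuned ckpt, matching pretrained ckpt) for every model/method choice."""
    pairs = []
    for model, method in product(("mae", "polarmae"), ("fft", "peft")):
        # Assumption: a fine-tuned checkpoint pairs with the pretrained one of the same model.
        pairs.append((f"{model}_{method}_segsem.ckpt", f"{model}_pretrain.ckpt"))
    return pairs


for finetuned, pretrained in finetune_checkpoint_pairs():
    print(f"{BASE_URL}/{finetuned}  (pretrained: {pretrained})")
```

Each pair supplies the second positional argument and the `pretrained_ckpt_path` argument of `load_finetune_checkpoint`.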
Important: Learning rate instructions
The following commands use the configurations from our experiments. In particular, our learning rates assume an effective batch size of 128 (i.e., 32 per GPU across 4 GPUs). If you train on a single GPU with the per-GPU batch size from the configuration file, you will need to scale the learning rate accordingly. We recommend scaling the learning rate by the square root of the ratio of the batch sizes, i.e., if your effective batch size is `B`, use `lr * sqrt(B / 128)`, where `lr` is the learning rate in the configuration file.
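The square-root rule can be written as a small helper. This is a sketch: the function name is hypothetical, `1e-3` is a made-up placeholder rather than a value from the provided configs, and the effective batch size is taken to be the per-GPU batch size times the number of GPUs times the number of gradient-accumulation steps.

```python
import math

REFERENCE_BATCH_SIZE = 128  # 32 per GPU x 4 GPUs, the setup the provided configs assume


def scaled_learning_rate(config_lr, per_device_batch_size, num_devices=1,
                         accumulate_grad_batches=1,
                         reference_batch_size=REFERENCE_BATCH_SIZE):
    """Square-root scaling: config_lr * sqrt(effective_batch / reference_batch)."""
    effective = per_device_batch_size * num_devices * accumulate_grad_batches
    if effective <= 0:
        raise ValueError("effective batch size must be positive")
    return config_lr * math.sqrt(effective / reference_batch_size)


# Single GPU with the config's per-GPU batch size of 32 -> effective batch 32.
lr = scaled_learning_rate(1e-3, per_device_batch_size=32)  # placeholder config lr
print(lr)  # 0.0005
```

With 4 GPUs at 32 each, the effective batch size is 128 and the configured learning rate is used unchanged.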
To pretrain Point-MAE, modify the config file to include the path to the PILArNet-M dataset, and run the following command:

```bash
python -m polarmae.tasks.pointmae fit --config configs/pointmae.yml
```

To pretrain PoLAr-MAE, modify the config file to include the path to the PILArNet-M dataset, and run the following command:

```bash
python -m polarmae.tasks.polarmae fit --config configs/polarmae.yml
```
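The dataset path is set by editing the YAML config, and the key that holds it depends on the config layout. A hypothetical helper like the one below lists every key whose name contains "path" in a nested config, which makes the entries to edit easy to spot. Only `model.pretrained_ckpt_path` is taken from this README; the `data.init_args.data_path` entry in the example is invented for illustration.

```python
from typing import Any, Iterator, Tuple


def find_path_keys(cfg: Any, prefix: str = "") -> Iterator[Tuple[str, Any]]:
    """Yield (dotted_key, value) for every entry whose key name contains 'path'."""
    if isinstance(cfg, dict):
        for key, value in cfg.items():
            dotted = f"{prefix}.{key}" if prefix else str(key)
            if "path" in str(key).lower():
                yield dotted, value
            yield from find_path_keys(value, dotted)  # recurse into nested sections
    elif isinstance(cfg, list):
        for i, item in enumerate(cfg):
            yield from find_path_keys(item, f"{prefix}[{i}]")


# Hypothetical nested config; the key names under "data" are illustrative only.
example_cfg = {
    "model": {"pretrained_ckpt_path": None},
    "data": {"init_args": {"data_path": "/path/to/pilarnet-m"}},
}
for key, value in find_path_keys(example_cfg):
    print(key, "=", value)
```

To inspect a real config, load it into a plain dict first, for example with PyYAML's `yaml.safe_load` or with `OmegaConf.to_container(OmegaConf.load("configs/polarmae.yml"))`, and pass the result to `find_path_keys`.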
To train a semantic segmentation model, modify the config file to include the path to the PILArNet-M dataset, and run the following command:

```bash
python -m polarmae.tasks.part_segmentation fit --config configs/part_segmentation_{mae,polarmae}_{peft,fft}.yml \
    --model.pretrained_ckpt_path path/to/pretrained/checkpoint.ckpt
```

where `{mae,polarmae}` is either `mae` or `polarmae`, and `{peft,fft}` is either `peft` or `fft`. You can either specify the pretrained checkpoint path in the config file or pass it as a command-line argument, as above.
This repository is built upon the lovely Point-MAE and point2vec repositories.
If you find this work useful, please consider citing the following paper:
```bibtex
@misc{young2025particletrajectoryrepresentationlearning,
  title={Particle Trajectory Representation Learning with Masked Point Modeling},
  author={Sam Young and Yeon-jae Jwa and Kazuhiro Terao},
  year={2025},
  eprint={2502.02558},
  archivePrefix={arXiv},
  primaryClass={hep-ex},
  url={https://arxiv.org/abs/2502.02558},
}
```