Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

nn.Sequential fix #13

Merged
merged 2 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
![GitHub repo size](https://img.shields.io/github/repo-size/QVPR/VPRTempo.svg?style=flat-square)
[![PyPI downloads](https://img.shields.io/pypi/dw/VPRTempo.svg)](https://pypistats.org/packages/vprtempo)

This repository contains code for VPRTempo, a spiking neural network that uses temporally encoding to perform visual place recognition tasks. The network is based off of [BLiTNet](https://arxiv.org/pdf/2208.01204.pdf) and adapted to the [VPRSNN](https://github.com/QVPR/VPRSNN) framework.
This repository contains code for [VPRTempo](vprtempo.github.io), a spiking neural network that uses temporally encoding to perform visual place recognition tasks. The network is based off of [BLiTNet](https://arxiv.org/pdf/2208.01204.pdf) and adapted to the [VPRSNN](https://github.com/QVPR/VPRSNN) framework.

<p style="width: 50%; display: block; margin-left: auto; margin-right: auto">
<img src="./assets/github_image.png" alt="VPRTempo method diagram"/>
Expand All @@ -31,15 +31,14 @@ To use VPRTempo, please follow the instructions below for installation and usage
## License & Citation
This repository is licensed under the [MIT License](./LICENSE)

If you use our code, please cite the following [paper](https://arxiv.org/abs/2309.10225):
If you use our code, please cite our IEEE ICRA [paper](https://arxiv.org/abs/2309.10225):
```
@misc{hines2023vprtempo,
@inproceedings{hines2024vprtempo,
title={VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition},
author={Adam D. Hines and Peter G. Stratton and Michael Milford and Tobias Fischer},
year={2023},
eprint={2309.10225},
archivePrefix={arXiv},
primaryClass={cs.RO}
year={2024},
booktitle={2024 IEEE International Conference on Robotics and Automation (ICRA)}

}
```
## Installation and setup
Expand Down Expand Up @@ -106,9 +105,7 @@ For convenience, all data should be organised in the `./dataset` folder in the f
|--winter
```
### Custom Datasets
To define your own custom dataset to use with VPRTempo, you will need to follow the conventions for [PyTorch Datasets & Dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html). We provide a simple script `./dataset/custom_dataset.py` which will rename images in user defined directories and generate the necessary `.csv` file to load into VPRTempo.

To learn how to use custom datasets, please see the [CustomDatasets.ipynb](https://github.com/AdamDHines/VPRTempo-quant/tree/main/tutorials) tutorial.
To define your own custom dataset to use with VPRTempo, you will need to follow the conventions for [PyTorch Datasets & Dataloaders](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html).

## Usage
Running VPRTempo and VPRTempoQuant is handlded by `main.py`, which can be operated either through the command terminal or directly running the script. See below for more details.
Expand Down
5 changes: 0 additions & 5 deletions docs/.gitignore

This file was deleted.

Empty file removed docs/.gitkeep
Empty file.
1 change: 0 additions & 1 deletion docs/README.md

This file was deleted.

2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def parse_network(use_quantize=False, train_new_model=False):
help="Ground truth tolerance for matching")

# Define training parameters
parser.add_argument('--filter', type=int, default=1,
parser.add_argument('--filter', type=int, default=8,
help="Images to skip for training and/or inferencing")
parser.add_argument('--epoch', type=int, default=4,
help="Number of epochs to train the model")
Expand Down
6 changes: 2 additions & 4 deletions vprtempo/VPRTempo.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,11 +131,9 @@ def evaluate(self, models, test_loader):
for model in models:
self.inferences.append(nn.Sequential(
model.feature_layer.w,
nn.Hardtanh(0, 0.9),
nn.ReLU(),
nn.Hardtanh(0,1.0),
model.output_layer.w,
nn.Hardtanh(0, 0.9),
nn.ReLU()
nn.Hardtanh(0,1.0)
))
self.inferences[-1].to(torch.device(self.device))
# Initiliaze the output spikes variable
Expand Down
4 changes: 1 addition & 3 deletions vprtempo/VPRTempoQuant.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,10 +134,8 @@ def evaluate(self, models, test_loader, layers=None):
self.inferences.append(nn.Sequential(
model.feature_layer.w,
nn.Hardtanh(0, maxSpike),
nn.ReLU(),
model.output_layer.w,
nn.Hardtanh(0, maxSpike),
nn.ReLU()
nn.Hardtanh(0, maxSpike)
))
# Initialize the tqdm progress bar
pbar = tqdm(total=self.query_places,
Expand Down
Git LFS file not shown