This project is a port of Blealtan's efficient-kan to JAX.
We have also ported its mnist.py example to use our JAX-based KAN implementation.
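For readers new to efficient-kan, the core idea is to evaluate every B-spline basis function for each input feature at once and then apply one dense contraction, alongside a SiLU "base" branch. The following is a minimal pure-JAX sketch of that idea, not the code in this repository. It assumes efficient-kan's defaults (grid_size=5, spline_order=3, grid range [-1, 1]). The helper names (`make_grid`, `b_spline_bases`, `init_kan_layer`, `kan_layer`) are hypothetical, and the random initialization is simplified compared with the original.

```python
import jax
import jax.numpy as jnp


def make_grid(grid_size=5, spline_order=3, grid_range=(-1.0, 1.0)):
    # Uniform knots, extended by `spline_order` extra knots on each side
    h = (grid_range[1] - grid_range[0]) / grid_size
    return jnp.arange(-spline_order, grid_size + spline_order + 1) * h + grid_range[0]


def b_spline_bases(x, grid, spline_order=3):
    # x: (batch, in_features) -> (batch, in_features, grid_size + spline_order)
    x = x[..., None]
    # Degree-0 bases: indicator of the knot interval containing x
    bases = ((x >= grid[:-1]) & (x < grid[1:])).astype(x.dtype)
    # Cox-de Boor recursion raises the degree one step at a time
    for k in range(1, spline_order + 1):
        left = (x - grid[: -(k + 1)]) / (grid[k:-1] - grid[: -(k + 1)]) * bases[..., :-1]
        right = (grid[k + 1:] - x) / (grid[k + 1:] - grid[1:-k]) * bases[..., 1:]
        bases = left + right
    return bases


def init_kan_layer(key, in_features, out_features, grid_size=5, spline_order=3):
    # Hypothetical, simplified initialization (efficient-kan uses Kaiming init
    # and fits spline weights to noise)
    k1, k2 = jax.random.split(key)
    return {
        "base_weight": jax.random.normal(k1, (out_features, in_features)) / jnp.sqrt(in_features),
        "spline_weight": 0.1 * jax.random.normal(
            k2, (out_features, in_features, grid_size + spline_order)
        ),
    }


def kan_layer(params, x, grid, spline_order=3):
    # Base branch: SiLU followed by an ordinary linear map
    base = jax.nn.silu(x) @ params["base_weight"].T
    # Spline branch: contract all basis values with the spline coefficients
    bases = b_spline_bases(x, grid, spline_order)
    spline = jnp.einsum("bik,oik->bo", bases, params["spline_weight"])
    return base + spline


key = jax.random.PRNGKey(0)
grid = make_grid()
params = init_kan_layer(key, in_features=4, out_features=3)
x = jax.random.uniform(key, (8, 4), minval=-1.0, maxval=1.0)
y = jax.jit(kan_layer)(params, x, grid)
print(y.shape)  # (8, 3)
```

Because the cubic bases form a partition of unity on [-1, 1), summing `b_spline_bases` over its last axis gives 1 for any input in that range. This is a quick sanity check for the recursion.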
pip install -r requirements.txt
This will install JAX, Optax, Flax, PyTorch, Torchvision, and TQDM.
After installing the dependencies, you can run the MNIST example using the following command:
python mnist_efficient_kan_jax.py
This will download the MNIST dataset the first time it is run and then start training the model, displaying the training and validation progress.
In addition, we have also ported Ziyao Li's FastKAN to JAX. Its MNIST example can be run with:
python mnist_fastkan_jax.py
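FastKAN, as we understand it, replaces the B-spline basis with Gaussian radial basis functions on a uniform grid (by default 8 centers on [-2, 2]). It layer-normalizes the input before the RBF expansion and keeps a SiLU base branch. The following is a minimal pure-JAX sketch of such a layer, not the code in this repository. The names `rbf_bases`, `layer_norm`, `init_fastkan_layer`, and `fastkan_layer` are hypothetical. The layer norm here also has no learnable scale or shift, which is a simplification.

```python
import jax
import jax.numpy as jnp


def rbf_bases(x, grid_min=-2.0, grid_max=2.0, num_grids=8):
    # Gaussian bumps centered on a uniform grid, width = grid spacing
    grid = jnp.linspace(grid_min, grid_max, num_grids)
    h = (grid_max - grid_min) / (num_grids - 1)
    return jnp.exp(-(((x[..., None] - grid) / h) ** 2))


def layer_norm(x, eps=1e-5):
    # Normalize each example over its feature axis (no learnable affine params)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)


def init_fastkan_layer(key, in_features, out_features, num_grids=8):
    # Hypothetical, simplified initialization
    k1, k2 = jax.random.split(key)
    return {
        "spline_weight": 0.1 * jax.random.normal(k1, (in_features * num_grids, out_features)),
        "base_weight": jax.random.normal(k2, (in_features, out_features)) / jnp.sqrt(in_features),
        "base_bias": jnp.zeros(out_features),
    }


def fastkan_layer(params, x):
    # (batch, in, num_grids) -> flatten -> one linear map over all basis values
    bases = rbf_bases(layer_norm(x))
    spline = bases.reshape(x.shape[0], -1) @ params["spline_weight"]
    # SiLU base branch on the un-normalized input
    base = jax.nn.silu(x) @ params["base_weight"] + params["base_bias"]
    return spline + base


key = jax.random.PRNGKey(0)
params = init_fastkan_layer(key, in_features=4, out_features=3)
x = jax.random.normal(key, (8, 4))
y = jax.jit(fastkan_layer)(params, x)
print(y.shape)  # (8, 3)
```

Evaluating one `exp` per grid point avoids the degree-by-degree recursion of B-splines. This is a plausible reason FastKAN trains faster per epoch in the benchmark below.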
To compare the performance of the JAX ports of EfficientKAN and FastKAN, we ran a benchmark on the MNIST dataset. Both models were trained for 10 epochs with a batch size of 64 on a MacBook Pro (M2). The results are below:
Benchmarking EfficientKAN JAX
Average Epoch Time: 11.81s
Final Training Loss: 0.0122
Final Validation Loss: 0.1102
Final Validation Accuracy: 0.9706
Benchmarking FastKAN JAX
Average Epoch Time: 7.34s
Final Training Loss: 0.0002
Final Validation Loss: 0.1180
Final Validation Accuracy: 0.9723
The benchmark can be run with the following command:
python benchmark.py
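Timing JAX code needs some care because JAX dispatches work asynchronously, so a timer can stop before the device has finished. The following sketch shows one way to measure an average epoch time like the one reported above. The `time_epochs` helper and the toy least-squares `sgd_step` are hypothetical, and this is not necessarily how benchmark.py measures time. With `jax.jit`, the first epoch also includes compilation time.

```python
import time

import jax
import jax.numpy as jnp


def time_epochs(step_fn, params, batches, num_epochs):
    """Hypothetical helper: run `num_epochs` passes, return (params, avg seconds/epoch)."""
    epoch_times = []
    for _ in range(num_epochs):
        start = time.perf_counter()
        for batch in batches:
            params = step_fn(params, batch)
        # Wait for all pending device work before reading the clock
        jax.block_until_ready(params)
        epoch_times.append(time.perf_counter() - start)
    return params, sum(epoch_times) / len(epoch_times)


@jax.jit
def sgd_step(w, batch):
    # Toy model: one plain gradient-descent step on mean squared error
    x, y = batch
    grad = jax.grad(lambda w: jnp.mean((x @ w - y) ** 2))(w)
    return w - 0.1 * grad


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (64, 3))
y = x @ jnp.array([1.0, -2.0, 0.5])
batches = [(x, y)] * 10
w, avg_epoch_time = time_epochs(sgd_step, jnp.zeros(3), batches, num_epochs=3)
print(f"Average Epoch Time: {avg_epoch_time:.4f}s")
```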
- 2024.06.23
- Added FastKAN JAX port to repo.
- Added benchmark.
- 2024.06.22
- Initial repository setup and first commit.
MIT