A Julia implementation of Wavelet Kolmogorov-Arnold Networks (wavKAN). Multi-layer Perceptron (MLP) and wavKAN implementations of the Transformer and Recurrent Neural Operator (RNO) are applied to the 1D unit cell problem with a viscoelastic constitutive relation.
This dataset is particularly difficult for the Transformer to learn, but easy for the RNO. The wavKAN is investigated here to see if it can improve the Transformer's performance, and perhaps even surpass the RNO.
The MLP models were developed in a previous side project. The commit history attributed to their development can be found there.
1. Get dependencies:

   ```bash
   julia requirements.jl
   ```

2. Tune hyperparameters:

   ```bash
   julia src/models/Vanilla_RNO/hyperparameter_tuning.jl
   julia src/models/wavKAN_RNO/hyperparameter_tuning.jl
   julia src/models/Vanilla_Transformer/hyperparameter_tuning.jl
   julia src/models/wavKAN_Transformer/hyperparameter_tuning.jl
   ```

   (Alternatively to 2) Manually configure hyperparameters in the respective `config.ini` files (see the sketch just after this list).

3. Train the models (the `model_name` variable is set on line 26) and log the results:

   ```bash
   julia train.jl
   ```

4. Compare the training loops:

   ```bash
   python results.py
   ```

5. Visualize the results:

   ```bash
   julia predict_stress.jl
   ```
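The `config.ini` files are plain INI-style text, so they can be edited by hand. The section and key names vary per model (the ones in the usage comment below are hypothetical, not the repo's actual schema), but as a rough sketch, such a file can be read into a dictionary with nothing more than base Julia:

```julia
# Illustrative only: parse a simple INI-style config into a nested Dict.
function parse_ini(path::AbstractString)
    config = Dict{String,Dict{String,String}}()
    section = "global"
    for line in eachline(path)
        line = strip(line)
        (isempty(line) || startswith(line, "#") || startswith(line, ";")) && continue
        if startswith(line, "[") && endswith(line, "]")
            section = strip(line, ['[', ']'])          # new section header
        else
            key, value = strip.(split(line, "="; limit=2))
            get!(config, section, Dict{String,String}())[key] = value
        end
    end
    return config
end

# Example usage (hypothetical section/key names):
# cfg = parse_ini("src/models/wavKAN_RNO/config.ini")
# hidden_dim = parse(Int, cfg["architecture"]["hidden_dim"])
```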
Wavelet KAN models seem to perform poorly compared to their MLP counterparts. Additionally, the wavKAN Transformer had to be limited in complexity to fit on the GPU, which may have contributed to its poor performance. The wavKAN RNO, however, performed decently, though it optimised towards a greater complexity than its MLP counterpart. The MLP RNO was the best-performing model, with the lowest test loss and BIC, and the highest predictive power and consistency.
Below are the resulting best predictions of the models. The MLPs consistently outperformed the wavKANs, with the RNOs performing better than the Transformers.
Model | Train Loss | Test Loss | BIC | Time (mins) | Param Count |
---|---|---|---|---|---|
MLP RNO | 1.35 ± 0.20 | 0.41 ± 0.07 | 0.82 ± 0.13 | 62.61 ± 19.06 | 52 |
wavKAN RNO | 2.62 ± 0.72 | 0.97 ± 0.39 | 10163.26 ± 0.77 | 43.53 ± 0.53 | 4,413 |
MLP Transformer | 9.43 ± 2.28 | 34.52 ± 61.56 | 9692121.72 ± 123.13 | 5.01 ± 0.72 | 4,209,205 |
wavKAN Transformer | 584.57 ± 153.44 | 187.15 ± 44.61 | 788293.94 ± 89.23 | 23.31 ± 0.22 | 489,562 |
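For context on the BIC column: BIC penalises the parameter count, which is why the heavily parameterised Transformer models score so much worse on it. The exact likelihood term used to produce the table is not restated here, so the following is only a generic sketch of a Gaussian-likelihood BIC in Julia, not necessarily the repo's formula:

```julia
# Generic sketch: BIC = k·ln(n) − 2·ln(L̂), assuming a Gaussian likelihood whose
# log-likelihood reduces to a mean-squared-error term.
# `mse` = test mean squared error, `n` = number of test samples, `k` = parameter count.
function bic(mse::Real, n::Integer, k::Integer)
    loglik = -n / 2 * (log(2π) + log(mse) + 1)   # Gaussian log-likelihood at the MLE variance
    return k * log(n) - 2 * loglik
end

# Example with hypothetical numbers: 4,413 parameters, 100 test samples, MSE of 0.97.
# bic(0.97, 100, 4_413)
```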
Training time was recorded for each of the models, but this is not considered a reliable estimate of their computational cost, given that they were not all run on the same hardware and other tasks were running on the same machines. The number of FLOPs for each model will be calculated and compared in the future, once GFlops.jl is updated to work with the latest Julia version.
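When that becomes possible, the measurement would presumably look something like the sketch below, which uses GFlops.jl's `@count_ops` macro on a dummy stand-in for a model's forward pass (the repo's actual models and inputs are not reproduced here):

```julia
using GFlops

# Dummy stand-in for a model's forward pass; in practice this would be the
# trained RNO/Transformer applied to a representative test batch.
dummy_model(x) = sum(abs2, x)

x = rand(Float32, 1000)
flop_counter = @count_ops dummy_model(x)   # tally of float adds/muls at each precision
println(flop_counter)
```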
There were three intentions behind the development of this repo:
- For me to learn about and verify wavelet transforms for function approximation in the context of KANs.
- To show off some scientific machine learning and demonstrate that the same techniques used for NLP could instead be applied to something else.
- To showcase empirically why you can't just chuck Transformers at sequence modelling problems outside of NLP and expect them to be the most efficient or optimal architecture.
I expected the discrete wavelet transform to work well here, since it's good at representing both spatial and temporal dependencies (which is what you need for viscoelastic material deformation). However, while the wavelet KAN was able to learn the solution operator when realised as a Recurrent Neural Operator, it struggled wildly during tuning, was outperformed by its MLP counterpart, and completely failed when realised as a Transformer (although its complexity was much reduced compared to the MLP Transformer).
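For intuition on what "wavelets as the univariate functions" means in a KAN, below is a minimal, self-contained sketch of a single wavelet-KAN layer. It is not the implementation in `src/models/wavKAN_RNO` or `src/models/wavKAN_Transformer` (whose wavelet choice and parameterisation may differ); it just illustrates the Wav-KAN idea of a learnable per-edge wavelet with trainable scale, translation, and weight, using the Mexican hat mother wavelet:

```julia
# Minimal sketch of a single wavelet-KAN layer (illustrative, not the repo's code).
# Following the KAN formulation, each output is a sum of learnable univariate
# functions of each input: y_j = Σ_i w_ji · ψ((x_i - t_ji) / s_ji),
# where ψ is here the Mexican hat (Ricker) mother wavelet.

ψ(z) = (2 / (sqrt(3) * π^0.25)) * (1 - z^2) * exp(-z^2 / 2)

struct WavKANLayer
    weight::Matrix{Float64}       # w_ji, n_out × n_in
    scale::Matrix{Float64}        # s_ji, n_out × n_in
    translation::Matrix{Float64}  # t_ji, n_out × n_in
end

WavKANLayer(n_in::Int, n_out::Int) =
    WavKANLayer(randn(n_out, n_in), ones(n_out, n_in), zeros(n_out, n_in))

function (layer::WavKANLayer)(x::AbstractVector)
    n_out, n_in = size(layer.weight)
    [sum(layer.weight[j, i] * ψ((x[i] - layer.translation[j, i]) / layer.scale[j, i])
         for i in 1:n_in) for j in 1:n_out]
end

# Example: a 4 → 2 layer applied to a random input vector.
layer = WavKANLayer(4, 2)
y = layer(rand(4))
```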
That being said, in a different project, a wavelet-KAN realisation of a Convolutional Neural Network completely outshone its MLP variant in terms of generalisation when predicting 2D Darcy flow. This suggests that the choice of univariate function matters a lot in KAN architectures: the wavelets were more suitable for learning the Darcy flow problem than this viscoelastic material modelling problem.
One of the strengths of the KAN seems to be the ability to embed priors and shape its architecture through the choice of univariate function. Wavelets may be too restrictive compared to some of the other KAN models arising from the community. Architectural flexibility is really important for real-world problems, especially when data is limited, noisy, and expensive to obtain. Even AlphaFold v2 is not just a Transformer - it's an 'Evoformer' with embedded physical and biological priors to help it generalise.
So, I think KANs are awesome and an incredible opportunity for scientific machine learning. If you want to solve the big problems, (e.g. the climate crisis, growth of cancer cells, neurocognitive disorders), you can't just disregard centuries of accumulated knowledge and throw black-box algorithms at them. Different architectures are better at different things, and the KAN's flexibility, along with its capacity for symbolic regression, has the potential to be instrumental in expanding human knowledge.
The dataset has been sourced from the University of Cambridge Engineering Department's Part IIB course on Data-Driven and Learning-Based Methods in Mechanics and Materials.
It consists of the unit cell of a three-phase viscoelastic composite material. The objective is to understand the macroscopic behavior of the material by learning the constitutive relation that maps the macroscopic strain history $\bar{\varepsilon}(t)$ to the macroscopic stress $\bar{\sigma}(t)$.
Both a Transformer and a Recurrent Neural Operator (RNO) are implemented in their MLP and wavKAN forms. From a previous project, I found this dataset to be especially difficult for the Transformer to learn, but easy enough for the RNO. It is also one-dimensional, making it a prime candidate for comparing wavKAN against its MLP equivalents.
The behavior of the unit cell is described by the following equations.

The strain field $\varepsilon(x, t)$ is the gradient of the displacement field $u(x, t)$:

$$\varepsilon(x, t) = \frac{\partial u}{\partial x}(x, t)$$

The equilibrium condition is given by:

$$\frac{\partial \sigma}{\partial x}(x, t) = 0$$

The stress field $\sigma(x, t)$ follows a Kelvin-Voigt-type viscoelastic constitutive relation, combining an elastic and a viscous contribution:

$$\sigma(x, t) = E(x)\,\varepsilon(x, t) + \eta(x)\,\frac{\partial \varepsilon}{\partial t}(x, t)$$

where $E(x)$ is the Young's modulus and $\eta(x)$ is the viscosity at position $x$ in the unit cell.
The composite material is made up of three different phases, each with distinct values of Young's modulus $E$ and viscosity $\eta$.
The objective is to learn the macroscopic constitutive relation that maps the macroscopic strain $\bar{\varepsilon}(t)$ (the average of $\varepsilon(x, t)$ over the unit cell) to the macroscopic stress $\bar{\sigma}(t)$ (the average of $\sigma(x, t)$ over the unit cell) using a macroscopic constitutive model. This model should capture the complex viscoelastic behavior of the composite material.
The input to the macroscopic constitutive model at each time step $t_i$ is the macroscopic strain $\bar{\varepsilon}(t_i)$.

The output of the macroscopic constitutive model at each time step $t_i$ is the macroscopic stress $\bar{\sigma}(t_i)$.

In essence, the macroscopic constitutive model aims to learn the mapping between the applied macroscopic strain history $\{\bar{\varepsilon}(t_i)\}$ and the resulting macroscopic stress history $\{\bar{\sigma}(t_i)\}$.
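To make the sequence-to-sequence structure of that mapping concrete, here is a minimal sketch of how a recurrent constitutive model can march through the strain history one step at a time, carrying a small vector of hidden internal variables. This only illustrates the general RNO idea (an output map plus an internal-variable update); it is not the architecture in `src/models/Vanilla_RNO` or `src/models/wavKAN_RNO`:

```julia
# Illustrative sketch: a recurrent constitutive model mapping a macroscopic strain
# history {ε̄(tᵢ)} to a stress history {σ̄(tᵢ)}, one step at a time, via hidden
# internal variables ξ. The two maps below are tiny stand-ins for learned networks.

predict_stress(strain, ξ) = 2.0 * strain + sum(ξ)     # σ̄ₙ = f(ε̄ₙ, ξₙ)
update_state(strain, ξ) = 0.9 .* ξ .+ 0.1 * strain    # ξₙ₊₁ = g(ε̄ₙ, ξₙ)

function rollout(strain_history::AbstractVector; n_internal::Int = 2)
    ξ = zeros(n_internal)                              # internal variables start at zero
    stress_history = similar(strain_history)
    for (n, strain) in enumerate(strain_history)
        stress_history[n] = predict_stress(strain, ξ)  # output at this time step
        ξ = update_state(strain, ξ)                    # carry memory to the next step
    end
    return stress_history
end

# Example: a ramp in macroscopic strain produces a stress history of the same length.
strain_hist = collect(range(0.0, 0.05; length = 100))
stress_hist = rollout(strain_hist)
```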
- Bozorgasl, Z., & Chen, H. (2024). Wav-KAN: Wavelet Kolmogorov-Arnold Networks.
- Liu, Z., Wang, Y., Vaidya, S., Ruehle, F., Halverson, J., Soljačić, M., Hou, T. Y., & Tegmark, M. (2024). KAN: Kolmogorov-Arnold Networks.
- Mejía-de-Dios, J.-A., Mezura-Montes, E., & Quiroz-Castellanos, M. (2021). Automated parameter tuning as a bilevel optimization problem solved by a surrogate-assisted population-based approach.
- Liu, B., Cicirello, A. (2024). Cambridge University Engineering Department Part IIB Course on Data-Driven and Learning-Based Methods in Mechanics and Materials.