rodrigodzf/physmodjax

Towards Efficient Modelling of String Dynamics: A Comparison of State Space and Koopman based Deep Learning Methods

This is the accompanying repository for the paper "Towards Efficient Modelling of String Dynamics: A Comparison of State Space and Koopman based Deep Learning Methods".

The paper is available on arXiv.

Install

Create a conda environment with Python 3.10 or later and activate it:

conda create -n physmodjax python=3.10
conda activate physmodjax

Install JAX first (the command below installs the CUDA 12 build):

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Then install the remaining dependencies, including the development extras, in editable mode:

pip install -e '.[dev]'
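Before generating data or training, it can help to confirm that the interpreter meets the Python 3.10 minimum and that JAX is installed. The following is a minimal sketch using only the standard library; `check_environment` is a hypothetical helper, not part of physmodjax.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 10)):
    """Report whether the interpreter meets a minimum version and whether JAX is installed.

    importlib.util.find_spec locates the jax package without importing it,
    so this check is cheap and does not initialise any GPU backend.
    """
    return {
        "python_ok": sys.version_info[:2] >= tuple(min_python),
        "jax_installed": importlib.util.find_spec("jax") is not None,
    }


print(check_environment())
```

If `jax_installed` is true, calling `jax.devices()` in the same environment lists the devices JAX can use; after the CUDA 12 install this should include a GPU.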

Generate the dataset

To generate the datasets, run the following commands from the root of the repository after installing the library:

generate_dataset -m +dataset=ftm_string_linear
generate_dataset -m +dataset=ftm_string_nonlinear

Each command creates four folders, one for each combination of initial condition and sampling rate, for the linear and the nonlinear string model respectively.
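Since each folder corresponds to one combination of initial condition and sampling rate, it can be handy to recover those settings from the folder name. The sketch below assumes folder names follow the pattern of the example used in the conversion step, ftm_string_lin_1000_Gaussian_4000Hz; the pattern, the helper `parse_dataset_folder`, and its field names are hypothetical, and this README does not say what the numeric field (1000) means.

```python
import re
from pathlib import Path

# Hypothetical pattern inferred from a single example folder name
# (ftm_string_lin_1000_Gaussian_4000Hz); real folders may differ.
FOLDER_PATTERN = re.compile(
    r"^(?P<prefix>.+?)_(?P<number>\d+)_(?P<initial_condition>[A-Za-z]+)_(?P<rate_hz>\d+)Hz$"
)


def parse_dataset_folder(path):
    """Split a dataset folder name into its parts, or return None if it does not match."""
    match = FOLDER_PATTERN.match(Path(path).name)
    if match is None:
        return None
    fields = match.groupdict()
    # Convert the numeric parts; the meaning of "number" is not documented here.
    fields["number"] = int(fields["number"])
    fields["rate_hz"] = int(fields["rate_hz"])
    return fields


print(parse_dataset_folder("data/ftm_linear/ftm_string_lin_1000_Gaussian_4000Hz"))
```

Iterating over `Path("data/ftm_linear").iterdir()` with this helper would list which initial condition and sampling rate each generated folder holds, under the naming assumption above.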

Each dataset folder must then be converted to a single .npy file, for example:

convert_to_single_file \
data/ftm_linear/ftm_string_lin_1000_Gaussian_4000Hz \
data/ftm_linear/ftm_string_lin_1000_Gaussian_4000Hz.npy

This speeds up data loading during training.
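The conversion step can be pictured as stacking one array per trajectory into a single array shaped (initial_conditions, timesteps, gridpoints, state_variables), matching the data convention described next. This is only a sketch of the idea, not the implementation of `convert_to_single_file`: it assumes, hypothetically, that each folder holds one .npy file per trajectory, and `folder_to_single_file` and `load_single_file` are illustrative names. Opening the result with NumPy's `mmap_mode` shows one way a single file can be read cheaply, since slices are paged in on demand.

```python
from pathlib import Path

import numpy as np


def folder_to_single_file(folder, output):
    """Stack every per-trajectory .npy file in `folder` into one array and save it.

    Assumes (hypothetically) one file per initial condition, each shaped
    (timesteps, gridpoints, state_variables).
    """
    files = sorted(Path(folder).glob("*.npy"))
    if not files:
        raise FileNotFoundError(f"no .npy files found in {folder}")
    # Axis 0 of the stacked array indexes the initial conditions.
    stacked = np.stack([np.load(f) for f in files], axis=0)
    np.save(output, stacked)
    return stacked.shape


def load_single_file(path):
    # Memory-mapping reads only the header up front; data is paged in
    # lazily when a slice (e.g. a training batch) is accessed.
    return np.load(path, mmap_mode="r")
```

Under the folder-layout assumption, `folder_to_single_file("data/ftm_linear/ftm_string_lin_1000_Gaussian_4000Hz", "data/ftm_linear/ftm_string_lin_1000_Gaussian_4000Hz.npy")` would write the stacked array.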

Data convention

Arrays use the following axis convention for a single trajectory:

(timesteps, gridpoints[x,y,z], state_variables[u,v])

and, for multiple trajectories (one per initial condition):

(initial_conditions, timesteps, gridpoints[x,y,z], state_variables[u,v])
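As an illustration of this layout, the sketch below builds a zero-filled array for a 1D string, assuming one array axis per spatial dimension (so only x here) and the state variables [u, v] stored in that order on the last axis. The sizes are placeholders, not the dimensions of the actual datasets.

```python
import numpy as np

# Placeholder sizes (hypothetical); real datasets define their own counts.
n_initial_conditions, n_timesteps, n_gridpoints = 4, 4000, 32
U, V = 0, 1  # positions of the state variables [u, v] on the last axis

# 1D string: a single spatial axis (x) between timesteps and state variables.
data = np.zeros((n_initial_conditions, n_timesteps, n_gridpoints, 2), dtype=np.float32)

trajectory = data[0]      # one initial condition: (timesteps, gridpoints, state_variables)
u = data[..., U]          # (initial_conditions, timesteps, gridpoints)
v = data[..., V]          # (initial_conditions, timesteps, gridpoints)
first_frame = data[:, 0]  # state at t = 0: (initial_conditions, gridpoints, state_variables)
```

For a 2D domain the same convention would add a y axis after x, giving (initial_conditions, timesteps, x, y, 2).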

Premade datasets

Premade datasets are stored on the Apocrita cluster at:

/data/EECS-Sandler-Lab/physical_modelling

Train

Every experiment needs the path to the data file. Pass it by appending this override to the training command:

++datamodule.data_array=data.npy
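Because the data path is just another command-line override, a command can also be assembled programmatically, for example when scripting several runs. `build_train_command` below is a hypothetical helper; it only builds the argument list and does not launch anything.

```python
import shlex


def build_train_command(experiment, data_path, *overrides):
    """Assemble a train_rnn argument list with the required data path override.

    `overrides` are extra arguments such as "++epochs=1000", passed through unchanged.
    """
    return [
        "train_rnn",
        f"+experiment={experiment}",
        *overrides,
        f"++datamodule.data_array={data_path}",
    ]


cmd = build_train_command(
    "1d_koopman",
    "data/ftm_linear/ftm_string_lin_1000_Gaussian_4000Hz.npy",
    "++epochs=1000",
)
# shlex.join renders the list as a shell-safe command line.
print(shlex.join(cmd))
```

Passing `cmd` to `subprocess.run(cmd, check=True)` would run it the same way as typing the command in a shell.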

Train 1D models (default: the whole trajectory, truncated to 4000 steps)

| Model | Command |
| --- | --- |
| 1d Koopman | train_rnn +experiment=1d_koopman ++epochs=1000 ++epochs_val=50 ++optimiser.learning_rate=0.001 ++model.d_vars=1 |
| 1d Koopman time-varying | train_rnn +experiment=1d_koopman_varying ++epochs=1000 ++epochs_val=50 ++optimiser.learning_rate=0.001 ++model.d_vars=1 |
| 1d LRU | train_rnn +experiment=1d_lru ++epochs=1000 ++epochs_val=50 ++optimiser.learning_rate=0.001 ++model.d_vars=1 |
| 1d S5 | train_rnn +experiment=1d_s5 ++epochs=1000 ++epochs_val=50 ++optimiser.learning_rate=0.001 ++model.d_vars=1 |
| 1d FNO | train_rnn +experiment=1d_fno ++epochs=1000 ++epochs_val=50 ++optimiser.learning_rate=0.001 ++model.d_vars=1 |

Train 1D models on non-overlapping 400-step segments (for AR, i.e. autoregressive, mode)

| Model | Command |
| --- | --- |
| 1d Koopman | train_rnn +experiment=1d_koopman datamodule=string_windowed |
| 1d Koopman time-varying | train_rnn +experiment=1d_koopman_varying datamodule=string_windowed |
| 1d LRU | train_rnn +experiment=1d_lru datamodule=string_windowed |
| 1d S5 | train_rnn +experiment=1d_s5 datamodule=string_windowed |
| 1d FNO | train_rnn +experiment=1d_fno datamodule=string_windowed |
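To make the segment arithmetic concrete: a 4000-step trajectory yields 4000 / 400 = 10 non-overlapping windows. The sketch below shows one way to cut a trajectory like this with a reshape; it is not the `string_windowed` datamodule's code, `split_non_overlapping` is a hypothetical name, and dropping leftover steps that do not fill a whole window is an assumption.

```python
import numpy as np


def split_non_overlapping(trajectory, window=400):
    """Cut a (timesteps, ...) trajectory into consecutive, non-overlapping windows.

    Returns an array shaped (n_windows, window, ...). Trailing steps that do
    not fill a whole window are dropped (an assumption of this sketch).
    """
    n_windows = trajectory.shape[0] // window
    trimmed = trajectory[: n_windows * window]
    # Splitting the time axis into (n_windows, window) keeps each window contiguous.
    return trimmed.reshape(n_windows, window, *trajectory.shape[1:])


trajectory = np.zeros((4000, 32, 2), dtype=np.float32)  # (timesteps, gridpoints, [u, v])
print(split_non_overlapping(trajectory).shape)
```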

Train 1D models on random (overlapping) 400-step segments drawn from each trajectory (for AR mode)

This setup also allows comparison against FNO.

| Model | Command |
| --- | --- |
| 1d Koopman | train_rnn +experiment=1d_koopman datamodule=string_tb ++epochs=200 ++epochs_val=20 |
| 1d Koopman time-varying | train_rnn +experiment=1d_koopman_varying datamodule=string_tb ++epochs=200 ++epochs_val=20 |
| 1d LRU | train_rnn +experiment=1d_lru datamodule=string_tb ++epochs=200 ++epochs_val=20 |
| 1d S5 | train_rnn +experiment=1d_s5 datamodule=string_tb ++epochs=200 ++epochs_val=20 |
| 1d FNO | train_rnn datamodule=string_tb +experiment=1d_fno_tb ++epochs=200 ++epochs_val=20 |
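In contrast to the windowed setup, random segments can start at any step, so two windows from the same trajectory may overlap. The sketch below draws start offsets uniformly from [0, timesteps - 400]; `sample_random_windows` is a hypothetical helper, and whether `string_tb` samples uniformly, and how many windows it draws per trajectory, are assumptions not stated here.

```python
import numpy as np


def sample_random_windows(trajectory, n_windows, window=400, seed=0):
    """Draw `n_windows` windows of `window` steps starting at uniformly random offsets.

    Windows may overlap. Returns (windows, starts), where windows has shape
    (n_windows, window, ...) and starts holds each window's first timestep.
    """
    max_start = trajectory.shape[0] - window
    if max_start < 0:
        raise ValueError("trajectory is shorter than one window")
    rng = np.random.default_rng(seed)
    # integers() excludes the upper bound, so add 1 to allow a start at max_start.
    starts = rng.integers(0, max_start + 1, size=n_windows)
    windows = np.stack([trajectory[s : s + window] for s in starts])
    return windows, starts


trajectory = np.zeros((4000, 32, 2), dtype=np.float32)  # (timesteps, gridpoints, [u, v])
windows, starts = sample_random_windows(trajectory, n_windows=8)
print(windows.shape, starts)
```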

Testing the library

For development, run the test suite with:

JAX_PLATFORMS=cpu nbdev_test

Setting JAX_PLATFORMS=cpu restricts JAX to the CPU backend; this matters because the tests are not optimized for GPU execution.

Use the same variable when exporting the README:

JAX_PLATFORMS=cpu nbdev_readme
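The JAX_PLATFORMS variable must be set before JAX initializes its backends, which a freshly started process guarantees. When calling these tools from Python, the shell prefix can be mimicked by passing a modified environment to `subprocess.run`; `cpu_only_env` and `run_cpu_only` are hypothetical helpers sketched for illustration.

```python
import os
import subprocess


def cpu_only_env(base=None):
    """Return a copy of the environment with JAX restricted to the CPU backend.

    The input mapping (os.environ by default) is copied, never modified.
    """
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = "cpu"
    return env


def run_cpu_only(command):
    # Equivalent to prefixing the shell command with JAX_PLATFORMS=cpu.
    return subprocess.run(command, env=cpu_only_env(), check=True)
```

For example, `run_cpu_only(["nbdev_test"])` and `run_cpu_only(["nbdev_readme"])` correspond to the two shell commands above.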