
Training continuous time Recurrent Neural Networks (RNNs) on various behavioral tasks

This project establishes a pipeline for seamlessly defining time-constrained behavioral tasks and training RNNs on them with backpropagation through time (BPTT) in PyTorch.

It also contains post-training task-performance analysis classes, as well as a class for analyzing the RNN dynamics by computing the fixed points of the dynamics for a given input.

Some examples:

Fixed-point structure revealed after training an RNN with a Tanh activation function to perform a 3-bit flip-flop task


Random trials after training the RNN on a 2-bit flip-flop task


Fixed-point structure for the MemoryAntiNumber task: the first line attractor (blue-red points, appearing for the input during the stimulus-presentation stage) lies in the nullspace of W_out. The second line attractor (violet-tomato points, appearing for the input provided during the recall stage) has some projection onto the output axes.


Fixed-point structure in the MemoryAntiAngle task: same as the line attractors in the MemoryAntiNumber task, except that instead of line attractors the network forms ring attractors.


Continuous-time RNN description

The dynamics of the RNN are captured by the following equations:

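In the standard leaky continuous-time form assumed here,

$$\tau \,\dot{x} = -x + W_{rec}\,\phi(x) + W_{inp}\,u + b_{rec} + \xi(t), \qquad y = W_{out}\,\phi(x),$$

where $\phi$ is the activation function (ReLU or Tanh, depending on the network).
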
Where "\tau" is the time constant, "x" is the state vector of the RNN, "u" is an input vector, "W rec" is the recurrent connectivity of the RNN, "W inp" - matrix of input connectivities distributing input vector "u" to the neural nodes, "b rec" is a bias in the recurrent connectivity, "\xi" is some gaussian random noise. The output of the network is provided by the readout matrix "W out" applied to the neural nodes.

There are two classes implementing RNN dynamics:

  • RNN_pytorch -- used for training the network on the task
  • RNN_numpy -- used for performance analysis, fixed point analysis, and easy plotting.

Task definition

Each task has its own class specifying the structure of the (input, target output) pair that defines the behavior. Each task class should contain two main methods:

  • generate_input_target_stream(**kwargs) -- generates a single (input, target output, condition) tuple with specified parameters
  • get_batch(**kwargs) -- generates a batch of inputs, targets and conditions. The batch dimension is the last.

The implemented example tasks are:

  • Context-Dependent Decision Making
  • Delayed Match to Sample
  • 3 Bit Flip-Flop
  • MemoryAntiNumber
  • MemoryAntiAngle

Descriptions of these tasks are provided in the comments in the relevant task classes.

One can easily define their own task following the provided general template.
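As an illustration, a hypothetical task class following this interface might look as follows (the class name, trial structure, and array shapes are illustrative and not taken from the package):

```python
import numpy as np

class DelayedRecallTask:
    """Illustrative task following the interface described above (not the package's base class)."""

    def __init__(self, T=100, stim_on=10, stim_off=30, recall_on=70):
        self.T, self.stim_on, self.stim_off, self.recall_on = T, stim_on, stim_off, recall_on

    def generate_input_target_stream(self, value=1.0):
        # One trial: a scalar stimulus must be reproduced on the output after a delay.
        inputs = np.zeros((2, self.T))               # channel 0: stimulus, channel 1: recall cue
        target = np.zeros((1, self.T))
        inputs[0, self.stim_on:self.stim_off] = value
        inputs[1, self.recall_on:] = 1.0
        target[0, self.recall_on:] = value
        condition = {"value": value}
        return inputs, target, condition

    def get_batch(self, batch_size=64):
        # Stack trials along the last (batch) dimension, as described above.
        trials = [self.generate_input_target_stream(value=np.random.uniform(-1, 1))
                  for _ in range(batch_size)]
        inputs = np.stack([t[0] for t in trials], axis=-1)     # (n_inputs, T, batch)
        targets = np.stack([t[1] for t in trials], axis=-1)    # (n_outputs, T, batch)
        conditions = [t[2] for t in trials]
        return inputs, targets, conditions
```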

Training

During training, the connectivity matrices W_rec, W_inp, and W_out are iteratively modified to minimize a loss function: the lower the loss, the better the network performs the task.

The training loop is implemented in the Trainer class, which accepts a task and an RNN_pytorch instance. Trainer implements three main methods:

  • train_step(input_batch, target_batch) -- returns the loss value for a given batch (attached to the computational graph, so that gradients w.r.t. the connectivity weights can be computed), as well as the vector of losses on each individual trial
  • eval_step(input_batch, target_batch) -- returns the loss value for a given batch, detached from the computational graph.
  • run_training(**kwargs) -- implements the iterative update of the connectivity parameters that minimizes the loss function (a generic sketch of such a loop is shown below)
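Conceptually, this loop is a standard PyTorch optimization loop with BPTT. The sketch below is a generic version under assumed names and tensor layouts (a torch.nn.Module mapping input batches to output batches, MSE loss); it is illustrative rather than the Trainer's actual code:

```python
import torch

def training_loop_sketch(rnn, task, n_steps=1000, lr=1e-3, batch_size=64):
    # rnn: a torch.nn.Module mapping an input batch to an output batch (an RNN_pytorch-like model)
    optimizer = torch.optim.Adam(rnn.parameters(), lr=lr)
    mse = torch.nn.MSELoss()
    for step in range(n_steps):
        inputs, targets, _ = task.get_batch(batch_size=batch_size)
        inputs = torch.as_tensor(inputs, dtype=torch.float32)
        targets = torch.as_tensor(targets, dtype=torch.float32)
        optimizer.zero_grad()
        outputs = rnn(inputs)          # forward pass unrolled through time
        loss = mse(outputs, targets)   # lower loss = better task performance
        loss.backward()                # backpropagation through time (BPTT)
        optimizer.step()
    return rnn
```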

Performance Analysis

The PerformanceAnalyzer class accepts an RNN_numpy instance and a Task instance and has two main methods:

  • get_validation_score(scoring_function, input_batch, target_batch, **kwargs) -- runs the network on the specified inputs and calculates the mean loss between the predicted and target outputs using the specified scoring function (see the example below).
  • plot_trials(input_batch, target_batch, **kwargs) -- generates a figure plotting multiple predicted outputs in response to the specified inputs, together with the target outputs for comparison.
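A scoring_function is simply a callable comparing predicted and target output batches; for example, a plain mean-squared-error score (illustrative, not necessarily the package's built-in choice):

```python
import numpy as np

def mse_scoring_function(predicted_output, target_output):
    """Mean squared error between predicted and target output batches."""
    return np.mean((np.asarray(predicted_output) - np.asarray(target_output)) ** 2)
```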

One can extend the base class by defining a task-specific PerformanceAnalyzer (see AnalyzerCDDM as an example).

Fixed-point Analysis

The fixed-point analysis is implemented in the DynamicSystemAnalyzer class, which accepts an RNN_numpy instance.

The class contains three methods:

  • get_fixed_points(Input_vector, mode, **kwargs) -- calculates stable and unstable fixed points of the RNN's dynamics for a given input. If mode='exact', it searches for exact fixed points using scipy.optimize.fsolve applied to the right-hand side (RHS) of the dynamics equations (see the sketch after this list). Alternatively, if mode='approx', it searches for 'slow points' -- points where the RHS of the dynamics is approximately zero. In the latter case, the cut-off threshold for accepting a point is controlled by the parameter 'fun_tol'.
  • plot_fixed_points(projection, P) -- assumes that the fixed points have been calculated with the get_fixed_points method for at most three input vectors. If the projection matrix P is not specified, it assembles the fixed points into an array, performs PCA on them, and projects the points onto either the first 2 (projection='2D') or first 3 (projection='3D') PCs, returning a figure with the projected fixed points.
  • compute_point_analytics(point, Input_vector, **kwargs) -- for a given point in state space and a given input to the RNN, calculates statistics of the point: the value of |RHS|^2, the Jacobian, its eigenvalues, and the principal left and right eigenvectors.
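In the 'exact' mode, finding a fixed point amounts to root-finding on the RHS of the dynamics. A minimal sketch with scipy.optimize.fsolve, using the equation above (illustrative; not the class's actual implementation):

```python
import numpy as np
from scipy.optimize import fsolve

def find_fixed_point(x_init, u, W_rec, W_inp, b_rec,
                     phi=lambda x: np.maximum(x, 0.0)):
    """Solve rhs(x) = 0 for a constant input u, starting from x_init."""
    def rhs(x):
        return -x + W_rec @ phi(x) + W_inp @ u + b_rec
    x_star = fsolve(rhs, x_init)
    residual = np.linalg.norm(rhs(x_star))   # |RHS| at the candidate point; near zero for a true fixed point
    return x_star, residual
```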

Saving the data

When initialized, DataSaver creates a dedicated data folder and stores its path in the 'data_folder' attribute. It has two methods:

  • save_data(data, filename) -- saves either a pickle or JSON file containing the data into the 'data_folder' (see the sketch below)
  • save_figure(figure, filename) -- saves a figure as a PNG file into the 'data_folder'
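The saving logic itself is straightforward; a minimal sketch of the pickle/JSON dispatch described above (illustrative, not the DataSaver implementation):

```python
import json
import os
import pickle

def save_data(data, filename, data_folder="data"):
    """Save data into data_folder as JSON or pickle, depending on the file extension."""
    os.makedirs(data_folder, exist_ok=True)
    path = os.path.join(data_folder, filename)
    if filename.endswith(".json"):
        with open(path, "w") as f:
            json.dump(data, f)
    else:
        with open(path, "wb") as f:
            pickle.dump(data, f)
```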

Integration with DataJoint is coming

Where to get started

I suggest starting with this Jupyter notebook to get an idea of how to use the package: [training an RNN to perform context-dependent decision-making](jupyter/Training CDDM.ipynb)
