Update documentation (#49)
* quickstart snn_simulation-zh

* fix a bug in snn_sim

* 3 en version

---------

Co-authored-by: LyuMuyang <[email protected]>
senngadaisuki and LyuMuyang authored Dec 16, 2024
1 parent 1c8446b commit 8c90098
Showing 5 changed files with 2,187 additions and 28 deletions.
Binary file added docs/_static/snn-simulation1.png
374 changes: 367 additions & 7 deletions docs/quickstart/ann_training-en.ipynb
@@ -2,13 +2,373 @@
"cells": [
{
"cell_type": "markdown",
"source": [
"# Training Artificial Neural Networks"
],
"id": "3e91a13480dd82aa",
"metadata": {
"collapsed": false
},
"id": "3e91a13480dd82aa"
"source": [
"# Training Artificial Neural Networks"
]
},
{
"cell_type": "markdown",
"id": "98b8e8f3",
"metadata": {},
"source": [
"In recent years, artificial neural networks have developed rapidly and play an important role in neuroscience research. As a high-performance computational framework for brain dynamics modeling, brainstate also supports the training of artificial neural networks, facilitating the integration of neural dynamics models with artificial neural networks.\n",
"\n",
"Here, we will introduce how to train an artificial neural network using brainstate, with an example of a simple 2-layer multilayer perceptron (MLP) for handwritten digit recognition (MNIST)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6d512371",
"metadata": {},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"import numpy as np\n",
"from datasets import load_dataset\n",
"\n",
"import brainstate as bst\n",
"from braintools.metric import softmax_cross_entropy_with_integer_labels"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "fd228991",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'0.1.0'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"bst.__version__"
]
},
{
"cell_type": "markdown",
"id": "9889df14",
"metadata": {},
"source": [
"## Preparing the Dataset\n",
"\n",
"First, we need to obtain the dataset and wrap it into an iterable object that automatically samples and shuffles the data according to the batch size."
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "01e96ac0",
"metadata": {},
"outputs": [],
"source": [
"dataset = load_dataset('mnist')\n",
"X_train = np.array(np.stack(dataset['train']['image']), dtype=np.uint8)\n",
"X_test = np.array(np.stack(dataset['test']['image']), dtype=np.uint8)\n",
"X_train = (X_train > 0).astype(jnp.float32)\n",
"X_test = (X_test > 0).astype(jnp.float32)\n",
"Y_train = np.array(dataset['train']['label'], dtype=np.int32)\n",
"Y_test = np.array(dataset['test']['label'], dtype=np.int32)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b886cbc",
"metadata": {},
"outputs": [],
"source": [
"class Dataset:\n",
" def __init__(self, X, Y, batch_size, shuffle=True):\n",
" self.X = X\n",
" self.Y = Y\n",
" self.batch_size = batch_size\n",
" self.shuffle = shuffle\n",
" self.indices = np.arange(len(X))\n",
" self.current_index = 0\n",
" if self.shuffle:\n",
" np.random.shuffle(self.indices)\n",
"\n",
" def __iter__(self):\n",
" self.current_index = 0\n",
" if self.shuffle:\n",
" np.random.shuffle(self.indices)\n",
" return self\n",
"\n",
" def __next__(self):\n",
" # Check if all samples have been processed\n",
" if self.current_index >= len(self.X):\n",
" raise StopIteration\n",
"\n",
" # Define the start and end of the current batch\n",
" start = self.current_index\n",
" end = start + self.batch_size\n",
" if end > len(self.X):\n",
" end = len(self.X)\n",
" \n",
" # Update current index\n",
" self.current_index = end\n",
"\n",
" # Select batch samples\n",
" batch_indices = self.indices[start:end]\n",
" batch_X = self.X[batch_indices]\n",
" batch_Y = self.Y[batch_indices]\n",
"\n",
" # Ensure batch has consistent shape\n",
" if batch_X.ndim == 1:\n",
" batch_X = np.expand_dims(batch_X, axis=0)\n",
"\n",
" return batch_X, batch_Y"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d6c383f1",
"metadata": {},
"outputs": [],
"source": [
"# Initialize training and testing datasets\n",
"batch_size = 32\n",
"train_dataset = Dataset(X_train, Y_train, batch_size, shuffle=True)\n",
"test_dataset = Dataset(X_test, Y_test, batch_size, shuffle=False)"
]
},
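{
"cell_type": "markdown",
"id": "b1c2d3e4",
"metadata": {},
"source": [
"As a quick sanity check, we can draw a single batch from the iterator. This is an illustrative sketch: with the settings above, each batch should hold 32 binarized 28x28 images and their integer labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c2d3e4f5",
"metadata": {},
"outputs": [],
"source": [
"# Draw one batch and inspect its shapes (illustrative check)\n",
"x_batch, y_batch = next(iter(train_dataset))\n",
"print(x_batch.shape) # expected: (32, 28, 28)\n",
"print(y_batch.shape) # expected: (32,)"
]
},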
{
"cell_type": "markdown",
"id": "8ba1b50a",
"metadata": {},
"source": [
"## Defining the Artificial Neural Network\n",
"\n",
"When defining an artificial neural network in brainstate, you need to inherit the base class ``brainstate.nn.Module``. In the class method ``__init__()``, define the layers in the network (make sure to initialize the base class first using ``super().__init__()``). In the class method ``__call__()``, define the forward pass method of the network.\n",
"\n",
"brainstate also supports defining operations for individual layers in the network. For these custom layers, you need to inherit from the base class ``brainstate.nn.Module``, similar to defining a network.\n",
"\n",
"All quantities that need to change in the model should be encapsulated in the ``State`` object. Parameters that need to be updated during training should be encapsulated in a subclass of ``State`` called ``ParamState``. Other quantities that need to be updated during training are encapsulated in another subclass of ``State`` called ``ShortTermState``."
]
},
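{
"cell_type": "markdown",
"id": "d3e4f5a6",
"metadata": {},
"source": [
"Before building the model, here is a minimal sketch of how a ``State`` is used below: the wrapped array is read and written through the ``.value`` attribute (the shapes here are illustrative)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e4f5a6b7",
"metadata": {},
"outputs": [],
"source": [
"# Minimal State usage sketch: access the wrapped array via .value\n",
"w = bst.ParamState(bst.random.rand(2, 3)) # a trainable parameter state\n",
"print(w.value.shape) # (2, 3)\n",
"w.value = jnp.zeros((2, 3)) # states are updated in place through .value"
]
},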
{
"cell_type": "code",
"execution_count": null,
"id": "cea8b5d5",
"metadata": {},
"outputs": [],
"source": [
"# Define linear layer\n",
"class Linear(bst.nn.Module):\n",
" def __init__(self, din: int, dout: int):\n",
" super().__init__()\n",
" self.w = bst.ParamState(bst.random.rand(din, dout)) # Initialize weight parameters\n",
" self.b = bst.ParamState(jnp.zeros((dout,))) # Initialize bias parameters\n",
"\n",
" def __call__(self, x):\n",
" return x @ self.w.value + self.b.value # Perform linear transformation"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8e541917",
"metadata": {},
"outputs": [],
"source": [
"# Define a short-term state for counting times called\n",
"class Count(bst.ShortTermState):\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f9445bd8",
"metadata": {},
"outputs": [],
"source": [
"# Define MLP model\n",
"class MLP(bst.nn.Module):\n",
" def __init__(self, din, dhidden, dout):\n",
" super().__init__()\n",
" self.count = Count(jnp.array(0)) # Count how many times model is called\n",
" self.linear1 = Linear(din, dhidden) # brainstate有常规层的实现,可以直接写 self.linear1 = bst.nn.Linear(din, dhidden)\n",
" self.linear2 = Linear(dhidden, dout)\n",
" self.flatten = bst.nn.Flatten(start_axis=1) # Flatten images to 1D\n",
" self.relu = bst.nn.ReLU() # ReLU activation function\n",
"\n",
" def __call__(self, x):\n",
" self.count.value += 1 # Increment call count\n",
"\n",
" x = self.flatten(x)\n",
" x = self.linear1(x)\n",
" x = self.relu(x) # 也兼容jax函数,可以直接写 x = jax.nn.relu(x)\n",
" x = self.linear2(x)\n",
" return x"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97aefeaa",
"metadata": {},
"outputs": [],
"source": [
"# Initialize model with input, hidden, and output layer sizes\n",
"model = MLP(din=28*28, dhidden=512, dout=10)"
]
},
{
"cell_type": "markdown",
"id": "d7edf945",
"metadata": {},
"source": [
"## Optimizer Setup\n",
"\n",
"``brainstate.optim`` provides various optimizers to choose from.\n",
"\n",
"After instantiating the optimizer, you need to specify which parameters the optimizer should update by calling ``optimizer.register_trainable_weights()``.\n",
"\n",
"In this case, we use ``brainstate.nn.Module.states()`` to collect all the ``State`` objects of the network nodes and their sub-nodes in the model. We restrict the types of ``State`` collected to ``brainstate.ParamState`` (in this model, ``State`` instances may also have other types like ``Count``, which do not need to be updated by the optimizer, so we apply type restrictions)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e8dd079f",
"metadata": {},
"outputs": [],
"source": [
"# Initialize optimizer and register model parameters\n",
"optimizer = bst.optim.SGD(lr = 1e-3) # Initialize SGD optimizer with learning rate\n",
"optimizer.register_trainable_weights(model.states(bst.ParamState)) # Register parameters for optimization"
]
},
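{
"cell_type": "markdown",
"id": "f5a6b7c8",
"metadata": {},
"source": [
"To see the effect of the type filter, we can compare collecting every ``State`` in the model with collecting only the trainable parameters (an illustrative sketch; the exact counts depend on the model defined above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a6b7c8d9",
"metadata": {},
"outputs": [],
"source": [
"# Compare unfiltered and filtered state collections (illustrative)\n",
"all_states = model.states() # includes the ParamState instances and Count\n",
"param_states = model.states(bst.ParamState) # only trainable parameters\n",
"print(len(all_states), len(param_states))"
]
},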
{
"cell_type": "markdown",
"id": "0bbbfad2",
"metadata": {},
"source": [
"## Model Training\n",
"\n",
"During model training, use the ``brainstate.augment.grad`` function to calculate gradients. This function requires the loss function and the parameters (``State``) for which gradients should be computed.\n",
"\n",
"Then, the gradients are passed to the previously defined optimizer via ``update()`` for the update.\n",
"\n",
"To improve computational efficiency and performance, use the ``brainstate.compile.jit`` function to decorate the training step function, enabling just-in-time compilation."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "99272b18",
"metadata": {},
"outputs": [],
"source": [
"# Training step function\n",
"@bst.compile.jit\n",
"def train_step(batch):\n",
" x, y = batch\n",
" # Define loss function\n",
" def loss_fn():\n",
" return softmax_cross_entropy_with_integer_labels(model(x), y).mean()\n",
" \n",
" # Compute gradients of the loss with respect to model parameters\n",
" grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()\n",
" optimizer.update(grads) # Update parameters using optimizer"
]
},
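{
"cell_type": "markdown",
"id": "b7c8d9e0",
"metadata": {},
"source": [
"The loss used in ``loss_fn`` takes raw logits and integer class labels and returns one loss value per example, which we average with ``.mean()``. Here is a quick check on dummy data (semantics assumed from the usage above; each example's correct class is given the largest logit):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c8d9e0f1",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the loss function on dummy logits\n",
"dummy_logits = jnp.array([[2.0, 0.5, 0.1], [0.1, 3.0, 0.2]])\n",
"dummy_labels = jnp.array([0, 1]) # correct classes get the largest logits\n",
"print(softmax_cross_entropy_with_integer_labels(dummy_logits, dummy_labels).mean())"
]
},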
{
"cell_type": "markdown",
"id": "e8c2dfe0",
"metadata": {},
"source": [
"## Model Testing\n",
"\n",
"Similarly, use the ``brainstate.compile.jit`` function to decorate the testing step function, allowing for just-in-time compilation to improve computational efficiency and performance."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7456afa1",
"metadata": {},
"outputs": [],
"source": [
"# Testing step function\n",
"@bst.compile.jit\n",
"def test_step(batch):\n",
" x, y = batch\n",
" y_pred = model(x) # Perform forward pass\n",
" loss = softmax_cross_entropy_with_integer_labels(y_pred, y).mean() # Compute loss\n",
" correct = (y_pred.argmax(1) == y).sum() # Count correct predictions\n",
"\n",
" return {'loss': loss, 'correct': correct}"
]
},
{
"cell_type": "markdown",
"id": "a550e569",
"metadata": {},
"source": [
"## Training Process\n",
"\n",
"This completes the setup and the process for training an artificial neural network with brainstate."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eae5f682",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, test loss: 410.2366638183594, test accuracy: 0.24890001118183136\n",
"epoch: 1, test loss: 278.79864501953125, test accuracy: 0.6233000159263611\n",
"epoch: 2, test loss: 75.72823333740234, test accuracy: 0.7638000249862671\n",
"epoch: 3, test loss: 59.49712371826172, test accuracy: 0.7830000519752502\n",
"epoch: 4, test loss: 38.07597351074219, test accuracy: 0.8623000383377075\n",
"epoch: 5, test loss: 54.225074768066406, test accuracy: 0.8329000473022461\n",
"epoch: 6, test loss: 74.46405792236328, test accuracy: 0.7676000595092773\n",
"epoch: 7, test loss: 35.6864128112793, test accuracy: 0.867900013923645\n",
"epoch: 8, test loss: 140.0616912841797, test accuracy: 0.7529000639915466\n",
"epoch: 9, test loss: 42.05353927612305, test accuracy: 0.8574000597000122\n",
"times called: 21880\n"
]
}
],
"source": [
"# Execute training and testing\n",
"total_steps = 20\n",
"for epoch in range(10):\n",
" for step, batch in enumerate(train_dataset):\n",
" train_step(batch) # Perform training step for each batch\n",
"\n",
" # Calculate test loss and accuracy\n",
" test_loss, correct = 0, 0\n",
" for step_, test_ in enumerate(test_dataset):\n",
" logs = test_step(test_)\n",
" test_loss += logs['loss']\n",
" correct += logs['correct']\n",
" test_loss += logs['loss']\n",
" test_loss = test_loss / (step_ + 1)\n",
" test_accuracy = correct / len(X_test)\n",
" print(f\"epoch: {epoch}, test loss: {test_loss}, test accuracy: {test_accuracy}\")\n",
"\n",
"print('times model called:', model.count.value) # Output number of model calls"
]
}
],
"metadata": {
Expand All @@ -20,14 +380,14 @@
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
},
"nbformat": 4,