diff --git a/docs/_static/artificial-neural-network-tutorial1.png b/docs/_static/artificial-neural-network-tutorial1.png
new file mode 100644
index 0000000..eccb7dc
Binary files /dev/null and b/docs/_static/artificial-neural-network-tutorial1.png differ
diff --git a/docs/_static/artificial-neural-network-tutorial2.png b/docs/_static/artificial-neural-network-tutorial2.png
new file mode 100644
index 0000000..481aad8
Binary files /dev/null and b/docs/_static/artificial-neural-network-tutorial2.png differ
diff --git a/docs/_static/artificial-neural-network-tutorial3.png b/docs/_static/artificial-neural-network-tutorial3.png
new file mode 100644
index 0000000..6d8a54d
Binary files /dev/null and b/docs/_static/artificial-neural-network-tutorial3.png differ
diff --git a/docs/_static/artificial-neural-network-tutorial4.png b/docs/_static/artificial-neural-network-tutorial4.png
new file mode 100644
index 0000000..2099033
Binary files /dev/null and b/docs/_static/artificial-neural-network-tutorial4.png differ
diff --git a/docs/_static/artificial-neural-network-tutorial5.jpg b/docs/_static/artificial-neural-network-tutorial5.jpg
new file mode 100644
index 0000000..659a33a
Binary files /dev/null and b/docs/_static/artificial-neural-network-tutorial5.jpg differ
diff --git a/docs/tutorials/artificial_neural_networks-zh.ipynb b/docs/tutorials/artificial_neural_networks-zh.ipynb
index d8e1e37..8a81b33 100644
--- a/docs/tutorials/artificial_neural_networks-zh.ipynb
+++ b/docs/tutorials/artificial_neural_networks-zh.ipynb
@@ -2,13 +2,503 @@
"cells": [
{
"cell_type": "markdown",
- "source": [
- "# 构建人工神经网络"
- ],
+ "id": "cff08e5afff144b2",
"metadata": {
"collapsed": false
},
- "id": "cff08e5afff144b2"
+ "source": [
+    "# Building Artificial Neural Networks"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "1c216cfe",
+ "metadata": {},
+ "source": [
+    "Artificial neural networks (ANNs), also known simply as neural networks (NNs), are mathematical or computational models used in machine learning and cognitive science that mimic the structure and function of biological neural networks (the central nervous system of animals, in particular the brain) in order to estimate or approximate functions. Much as neurons in the human brain are interconnected, the neurons of an artificial neural network are connected across the layers of the network.\n",
+    "\n",
+    "Compared with spiking neural networks, the neurons of an artificial neural network are much more simplified and have no intrinsic dynamics of their own. The information passed between neurons is not a discrete 0/1 action potential but a continuous floating-point value, which can be read as the neuron's firing rate within one time step. Although artificial neural networks were originally inspired by biological brains and can exhibit some of their properties, we mainly use them as powerful models for solving concrete problems, without insisting on a correspondence between their parts and biology.\n",
+ "\n",
+    "<img src=\"../_static/artificial-neural-network-tutorial1.png\" />\n",
+    "<img src=\"../_static/artificial-neural-network-tutorial2.png\" />\n",
+ "\n",
+ "\n",
+    "<img src=\"../_static/artificial-neural-network-tutorial3.png\" />\n",
+ "\n",
+ "\n",
+    "## Architecture of an artificial neural network\n",
+    "The neurons of an artificial neural network are distributed across layers, and information propagates forward layer by layer. These layers fall into three categories:\n",
+    "- Input layer: the first layer of the network. It receives the input and passes it on to the hidden layers.\n",
+    "- Hidden layers: the hidden layers perform the various computations and feature extraction on the information passed in from the input layer. There is usually more than one hidden layer, and information flows through all of them in turn.\n",
+    "- Output layer: finally, the output layer takes the information computed by the hidden layers and produces the final result.\n",
+ "\n",
+ "\n",
+    "<img src=\"../_static/artificial-neural-network-tutorial4.png\" />\n",
+ "\n",
+ "\n",
+    "Artificial neural networks contain many kinds of layers, distinguished by the computation they perform; the simplest is the linear layer. A linear layer takes its input, computes a weighted sum of it, and adds a bias term. As a formula:\n",
+    "$$\n",
+    "\\sum_{i=1}^{n} W_i X_i + b. \\tag{1}\n",
+    "$$\n",
+    "Dot products and sums like this are linear operations. If we added more layers but used only linear operations, the extra layers would have no effect: by commutativity and associativity, the whole stack is equivalent to a single linear transformation (a short numerical sketch of this collapse follows this cell). We therefore insert nonlinear **activation functions** to increase the model's expressive power. Intuitively, the activation function decides whether a neuron is activated, and only activated neurons produce a (non-zero) output, much as in biological neurons. There are many kinds of activation functions, chosen according to what works for the task. In the network implemented in this tutorial we will use the ReLU (rectified linear unit) and Softmax (normalized exponential) activation functions, described in more detail below.\n",
+ "\n",
+    "## Workflow of an artificial neural network\n",
+    "Artificial neural networks are data-driven statistical models. To train a model to solve a problem, we do not write rules explicitly; instead we provide training data from which the model learns how to solve it.\n",
+    "\n",
+    "Concretely, we provide a dataset that fixes the desired input-output mapping, run the model to obtain its current outputs, and use a **loss function** to measure the difference between the current outputs and the correct ones. Then, by **backpropagation**, we take partial derivatives of the loss with respect to the parameters (mainly $W$ and $b$) layer by layer, obtaining the magnitude and direction for updating them. Finally, an **optimizer** updates the parameters so that the model's outputs move closer to the reference outputs given by the dataset.\n"
+ ]
+ },
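+  {
+   "cell_type": "markdown",
+   "id": "3c9d2b1e",
+   "metadata": {},
+   "source": [
+    "As a quick illustration of the point above (an extra cell, not part of the original tutorial), here is a minimal NumPy sketch with made-up shapes and random weights showing that two stacked linear layers without an activation collapse into a single linear layer.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7e4f8a21",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "rng = np.random.default_rng(0)\n",
+    "x = rng.normal(size=4)  # a toy 4-dimensional input\n",
+    "W1, b1 = rng.normal(size=(3, 4)), rng.normal(size=3)  # first linear layer\n",
+    "W2, b2 = rng.normal(size=(2, 3)), rng.normal(size=2)  # second linear layer\n",
+    "\n",
+    "# Two linear layers applied in sequence ...\n",
+    "two_layers = W2 @ (W1 @ x + b1) + b2\n",
+    "\n",
+    "# ... equal one linear layer with a merged weight and bias.\n",
+    "one_layer = (W2 @ W1) @ x + (W2 @ b1 + b2)\n",
+    "print(np.allclose(two_layers, one_layer))  # True"
+   ]
+  },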
+ {
+ "cell_type": "markdown",
+ "id": "c7262143",
+ "metadata": {},
+ "source": [
+    "# Building your first artificial neural network\n",
+    "As an example, we will now use `brainstate` to build a three-layer multilayer perceptron (MLP) for a handwritten-digit-recognition task (MNIST).\n",
+    "\n",
+    "We feed images of handwritten digits into the MLP and ask it to output which digit each image shows.\n",
+ "\n",
+ "\n",
+    "<img src=\"../_static/artificial-neural-network-tutorial5.jpg\" />\n",
+ "\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "0825d40e",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import jax.numpy as jnp\n",
+ "import numpy as np\n",
+ "from datasets import load_dataset\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import brainstate as bst\n",
+ "from braintools.metric import softmax_cross_entropy_with_integer_labels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "8bf694f8",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'0.1.0'"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "bst.__version__"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d67a35ad",
+ "metadata": {},
+ "source": [
+    "## Preparing the dataset\n",
+    "First we prepare the dataset. A dataset provides many matched input-output sample pairs; in this task, these are images of handwritten digits together with labels saying which digit each image shows.\n",
+    "\n",
+    "The dataset is split into a training set and a test set (their samples do not overlap). The model is trained on the training set, where its parameters are adjusted, and its training quality is then evaluated on the test set, where the parameters are not updated."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "e37df3d9",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = load_dataset('mnist')\n",
+ "X_train = np.array(np.stack(dataset['train']['image']), dtype=np.uint8)\n",
+ "X_test = np.array(np.stack(dataset['test']['image']), dtype=np.uint8)\n",
+ "X_train = (X_train > 0).astype(jnp.float32)\n",
+ "X_test = (X_test > 0).astype(jnp.float32)\n",
+ "Y_train = np.array(dataset['train']['label'], dtype=np.int32)\n",
+ "Y_test = np.array(dataset['test']['label'], dtype=np.int32)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0cc63f9d",
+ "metadata": {},
+ "source": [
+    "The MNIST training set has 60,000 samples and the test set has 10,000. Each input is a $28 \\times 28$ single-channel image, and each output is a single label value."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "d3150ef6",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "iVBORw0KGgoAAAANSUhEUgAAARYAAAEUCAYAAADuhRlEAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAASvUlEQVR4nO3dX2hU6RnH8d/ETU51NzPZqMlkMEmz3XaFSi2I2mCRgsE/BWl2vSjbXlgoLu6OC7r0D7nQdKGQ1oVebCvsRalSqFqERtmFCm7USCFJaaqI3SWoK3W2ycSuNGdMNGNI3l7s7rSjMX+f8ZyJ3w88F3POm5nHd5wfZ857ZibinHMCAEMlQTcAYOEhWACYI1gAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmHsq6AYeNDExof7+fpWXlysSiQTdDoDPOOd0584dJRIJlZRMc0ziCuQ3v/mNq6+vd57nuXXr1rmenp4Z/V0qlXKSKIoKaaVSqWlfxwUJluPHj7uysjL3u9/9zv3jH/9wu3btchUVFW5wcHDavx0aGgp84iiKenQNDQ1N+zouSLCsW7fOJZPJ3O3x8XGXSCRcW1vbtH/r+37gE0dR1KPL9/1pX8fmJ2/v37+v3t5eNTU15baVlJSoqalJXV1dD43PZrPKZDJ5BaC4mQfLJ598ovHxcVVXV+dtr66uVjqdfmh8W1ubYrFYrmpra61bAvCYBb7c3NLSIt/3c5VKpYJuCcA8mS83L1u2TIsWLdLg4GDe9sHBQcXj8YfGe54nz/Os2wAQIPMjlrKyMq1Zs0YdHR25bRMTE+ro6FBjY6P1wwEIo3kt/zzC8ePHned57siRI+6DDz5wr7zyiquoqHDpdHrav2VViKLCXTNZFSrIlbff/e539e9//1sHDhxQOp3W17/+dZ0+ffqhE7oAFqaIc+H6Mu1MJqNYLBZ0GwAewfd9RaPRKccEvioEYOEhWACYI1gAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmDMPlp/97GeKRCJ5tXLlSuuHwWPinKNCUsXkqULc6Ve/+lW9//77/3uQpwryMABCqiCv+KeeekrxeLwQdw2gCBTkHMvVq1eVSCT03HPP6fvf/75u3rxZiIcBEFIRZ/zm7c9//rOGh4f1wgsvaGBgQG+++ab+9a9/6cqVKyovL39ofDabVTabzd3OZDKqra21bAnzUGzv7ReySCQSdAuSJN/3FY1Gpx7kCuw///mPi0aj7re//e2k+1tbW50kKqSF8Aj6/8Ln5fv+tL0WfLm5oqJCX/nKV3Tt2rVJ97e0tMj3/VylUqlCtwSgwAoeLMPDw7p+/bpqamom3e95nqLRaF4BKG7mq0I/+tGPtH37dtXX16u/v1+tra1atGiRXn75ZeuHWvAc5zdQpMyD5eOPP9bLL7+s27dva/ny5frmN7+p7u5uLV++3PqhAISU+arQfGUyGcVisaDbCIWQPTUIWDGtCvFZIQDmCBYA5ggWAOYIFgDmCBYA5ggWAOb4opQAsZyM/xeW5WQLHLEAMEewADBHsAAwR7AAMEewADBHsAAwR7AAMMd1LFjwFtL1IcWCIxYA5ggWAOYIFgDmCBYA5ggWAOYIFgDmCBYA5riOJUDTXV/xpHxfC/Ow8HDEAsAcwQLAHMECwBzBAsAcwQLAHMECwBzBAsAc17GE2EK4vsPiu1D4PpXiM+sjlgsXLmj79u1KJBKKRCI6efJk3n7nnA4cOKCamhotXrxYTU1Nunr1qlW/AIrArINlZGREq1ev1qFDhybdf/DgQb399tt655131NPTo6efflpbtmzR6OjovJsFUCTcPEhy7e3tudsTExMuHo+7t956K7dtaGjIeZ7njh07NqP79H3fSaJmUMUg6Dmi7Mv3/Wmfd9OTtzdu3FA6nVZTU1NuWywW0/r169XV1TXp32SzWWUymbwCUNxMgyWdTkuSqqur87ZXV1fn9j2ora1NsVgsV7W1tZYtAQhA4MvNLS0t8n0/V6lUKuiWAMyTabDE43FJ0uDgYN72wcHB3L4HeZ6naDSaVwCKm2mwNDQ0KB6Pq6OjI7ctk8mop6dHjY2Nlg8FIMRmfYHc8PCwrl27lrt948YNXbp0SZWVlaqrq9PevXv185//XF/+8pfV0NCg/fv3K5FIqLm52bJvyObCMVfgi+xmcv9cALcAzXb58Ny5c5MuQe3cudM59+mS8/79+111dbXzPM9t2rTJ9fX1zfj+WW5+vBUGQc8BNbuayXJz5LMnNjQymYxisVjQbTwxwvD0c8RSXHzfn/ZcaOCrQgAWHoIFgDmCBYA5ggWAOYIFgDmCBYA5ggWAOYIFgDmCBYA5ggWAOYIFgDmCBYA5ggWAOX6w7AkXhh9FK/Rj8Onpx48jFgDmCBYA5ggWAOYIFgDmCBYA5ggWAOYIFgDmuI4FUwrDdS7zNZMeudbFFkcsAMwRLADMESwAzBEsAMwRLADMESwAzBEsAMwRLADMzTpYLly4oO3btyuRSCgSiejkyZN5+3/wgx8oEonk1datW636Rcg8+Fw/WMXCOTdlYXZmHSwjIyNavXq1Dh069MgxW7du1cDAQK6OHTs2ryYBFJdZX9K/bds2bdu2bcoxnucpHo/PuSkAxa0g51jOnz+vqqoqvfDCC3r11Vd1+/btR47NZrPKZDJ5BaC4mQfL1q1b9fvf/14dHR365S9/qc7OTm3btk3j4+OTjm9ra1MsFstVbW2tdUsAHrOIm8eZqUgkovb2djU3Nz9yzEcffaQvfelLev/997Vp06aH9mezWWWz2dztTCZDuCwgC+XEZzGdiC403/cVjUanHFPw5ebnnntOy5Yt07Vr1ybd73meotFoXgEobgUPlo8//li3b99WTU1NoR8KQEjMelVoeHg47+jjxo0bunTpkiorK1VZWak333xTO3bsUDwe1/Xr1/WTn/xEzz//vLZs2WLaOIrDTN5CLJS3S/g/bpbOnTvnJD1UO3fudHfv3nWbN292y5cvd6Wlpa6+vt7t2rXLpdPpGd+/7/uT3j+1cKsYBD1HYSrf96edr3mdvC2ETCajWCwWdBt4jEL2X3BSnLz9n1CcvAXw5CFYAJgjWACYI1gAmCNYAJjjB8sQuPmuuDyOVaXpHoNVo3wcsQAwR7AAMEewADBHsAAwR7AAMEewADBHsAAwx3UsKKhi+OQy7HHEAsAcwQLAHMECwBzBAsAcwQLAHMECwBzBAsAcwQLAHBfIYUpc4PYpvshpdjhiAWCOYAFgjmABYI5gAWCOYAFgjmABYI5gAWBuVsHS1tamtWvXqry8XFVVVWpublZfX1/emNHRUSWTSS1dulTPPPOMduzYocHBQdOmMTPOuXnXkyISiUxZmJ1ZBUtnZ6eSyaS6u7t15swZjY2NafPmzRoZGcmN2bdvn959912dOHF
CnZ2d6u/v10svvWTeOIAQc/Nw69YtJ8l1dnY655wbGhpypaWl7sSJE7kxH374oZPkurq6ZnSfvu87SZRBYeaCfq6KqXzfn3Y+53WOxfd9SVJlZaUkqbe3V2NjY2pqasqNWblyperq6tTV1TXpfWSzWWUymbwCUNzmHCwTExPau3evNmzYoFWrVkmS0um0ysrKVFFRkTe2urpa6XR60vtpa2tTLBbLVW1t7VxbAhAScw6WZDKpK1eu6Pjx4/NqoKWlRb7v5yqVSs3r/gAEb06fbt6zZ4/ee+89XbhwQStWrMhtj8fjun//voaGhvKOWgYHBxWPxye9L8/z5HneXNoAEFKzOmJxzmnPnj1qb2/X2bNn1dDQkLd/zZo1Ki0tVUdHR25bX1+fbt68qcbGRpuOAYTerI5Yksmkjh49qlOnTqm8vDx33iQWi2nx4sWKxWL64Q9/qDfeeEOVlZWKRqN6/fXX1djYqG984xsF+QcsZO4Juo6kkLgOJQAWS3KHDx/Ojbl375577bXX3LPPPuuWLFniXnzxRTcwMDDjx2C5+X8FG0E/jwutZrLcHPls4kMjk8koFosF3UYohOypKVocsdjyfV/RaHTKMXxWCIA5ggWAOYIFgDmCBYA5ggWAOX5XqEBY0bHDqk7x4YgFgDmCBYA5ggWAOYIFgDmCBYA5ggWAOYIFgDmCBYA5LpB7BC5ws8HFbU8mjlgAmCNYAJgjWACYI1gAmCNYAJgjWACYI1gAmOM6FkyJ61AwFxyxADBHsAAwR7AAMEewADBHsAAwR7AAMEewADA3q2Bpa2vT2rVrVV5erqqqKjU3N6uvry9vzLe+9S1FIpG82r17t2nTj8OD/4YntYC5mFWwdHZ2KplMqru7W2fOnNHY2Jg2b96skZGRvHG7du3SwMBArg4ePGjaNIBwm9WVt6dPn867feTIEVVVVam3t1cbN27MbV+yZIni8bhNhwCKzrzOsfi+L0mqrKzM2/6HP/xBy5Yt06pVq9TS0qK7d+/O52EAFJk5f1ZoYmJCe/fu1YYNG7Rq1arc9u9973uqr69XIpHQ5cuX9dOf/lR9fX3605/+NOn9ZLNZZbPZ3O1MJjPXlgCEhZuj3bt3u/r6epdKpaYc19HR4SS5a9euTbq/tbXVSaIoqkjK9/1p82FOwZJMJt2KFSvcRx99NO3Y4eFhJ8mdPn160v2jo6PO9/1cpVKpwCeOoqhH10yCZVZvhZxzev3119Xe3q7z58+roaFh2r+5dOmSJKmmpmbS/Z7nyfO82bQBIORmFSzJZFJHjx7VqVOnVF5ernQ6LUmKxWJavHixrl+/rqNHj+rb3/62li5dqsuXL2vfvn3auHGjvva1rxXkHwAghGbzFkiPODQ6fPiwc865mzdvuo0bN7rKykrneZ57/vnn3Y9//OMZHTp9zvf9wA/1KIp6dM3k9Rz5LDBCI5PJKBaLBd0GgEfwfV/RaHTKMXxWCIA5ggWAOYIFgDmCBYA5ggWAOYIFgDmCBYA5ggWAOYIFgDmCBYA5ggWAOYIFgDmCBYC50AVLyD5sDeABM3mNhi5Y7ty5E3QLAKYwk9do6L6PZWJiQv39/SovL1ckElEmk1Ftba1SqdS03wGBqTGXNp7UeXTO6c6dO0okEiopmfqYZM4//1EoJSUlWrFixUPbo9HoE/UkFhJzaeNJnMeZfglb6N4KASh+BAsAc6EPFs/z1Nrayk+EGGAubTCP0wvdyVsAxS/0RywAig/BAsAcwQLAHMECwFzog+XQoUP64he/qC984Qtav369/vrXvwbdUuhduHBB27dvVyKRUCQS0cmTJ/P2O+d04MAB1dTUaPHixWpqatLVq1eDaTbE2tratHbtWpWXl6uqqkrNzc3q6+vLGzM6OqpkMqmlS5fqmWee0Y4dOzQ4OBhQx+ER6mD54x//qDfeeEOtra36+9//rtWrV2vLli26detW0K2F2sjIiFavXq1Dhw5Nuv/gwYN6++239c4776inp0dPP/20tmzZotHR0cfcabh1dnYqmUyqu7tbZ86c0djYmDZv3qyRkZHcmH379undd9/ViRMn1NnZqf7+fr300ksBdh0Ss/lR+Mdt3bp1LplM5m6Pj4+7RCLh2traAuyquEhy7e3tudsTExMuHo+7t956K7dtaGjIeZ7njh07FkCHxePWrVtOkuvs7HTOfTpvpaWl7sSJE7kxH374oZPkurq6gmozFEJ7xHL//n319vaqqakpt62kpERNTU3q6uoKsLPiduPGDaXT6bx5jcViWr9+PfM6Dd/3JUmVlZWSpN7eXo2NjeXN5cqVK1VXV/fEz2Vog+WTTz7R+Pi4qqur87ZXV1crnU4H1FXx+3zumNfZmZiY0N69e7VhwwatWrVK0qdzWVZWpoqKiryxzGUIP90MhFEymdSVK1f0l7/8JehWikJoj1iWLVumRYsWPXSGfXBwUPF4PKCuit/nc8e8ztyePXv03nvv6dy5c3lf6RGPx3X//n0NDQ3ljWcuQxwsZWVlWrNmjTo6OnLbJiYm1NHRocbGxgA7K24NDQ2Kx+N585rJZNTT08O8PsA5pz179qi9vV1nz55VQ0ND3v41a9aotLQ0by77+vp08+ZN5jLos8dTOX78uPM8zx05csR98MEH7pVXXnEVFRUunU4H3Vqo3blzx128eNFdvHjRSXK/+tWv3MWLF90///lP55xzv/jFL1xFRYU7deqUu3z5svvOd77jGhoa3L179wLuPFxeffVVF4vF3Pnz593AwECu7t69mxuze/duV1dX586ePev+9re/ucbGRtfY2Bhg1+EQ6mBxzrlf//rXrq6uzpWVlbl169a57u7uoFsKvXPnzjlJD9XOnTudc58uOe/fv99VV1c7z/Pcpk2bXF9fX7BNh9BkcyjJHT58ODfm3r177rXXXnPPPvusW7JkiXvxxRfdwMBAcE2HBF+bAMBcaM+xACheBAsAcwQLAHMECwBzBAsAcwQLAHMECwBzBAsAcwQLAHMECwBzBAsAcwQLAHP/BdmVef0zDuYfAAAAAElFTkSuQmCC",
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "5\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "((60000, 28, 28), (60000,), (10000, 28, 28), (10000,))"
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "plt.figure(figsize=(3, 3))\n",
+ "plt.imshow(X_train[0], cmap='gray')\n",
+ "plt.show()\n",
+ "print(Y_train[0])\n",
+ "X_train.shape, Y_train.shape, X_test.shape, Y_test.shape"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "23747623",
+ "metadata": {},
+ "source": [
+    "To make training convenient, we wrap the data in a `Dataset` class that performs the common processing in one place."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "5deb09a0",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class Dataset:\n",
+ " def __init__(self, X, Y, batch_size, shuffle=True):\n",
+ " self.X = X\n",
+ " self.Y = Y\n",
+ " self.batch_size = batch_size\n",
+ " self.shuffle = shuffle\n",
+ " self.indices = np.arange(len(X))\n",
+ " self.current_index = 0\n",
+ " if self.shuffle:\n",
+ " np.random.shuffle(self.indices)\n",
+ "\n",
+ " def __iter__(self):\n",
+ " self.current_index = 0\n",
+ " if self.shuffle:\n",
+ " np.random.shuffle(self.indices)\n",
+ " return self\n",
+ "\n",
+ " def __next__(self):\n",
+ " # Check if all samples have been processed\n",
+ " if self.current_index >= len(self.X):\n",
+ " raise StopIteration\n",
+ "\n",
+ " # Define the start and end of the current batch\n",
+ " start = self.current_index\n",
+ " end = start + self.batch_size\n",
+ " if end > len(self.X):\n",
+ " end = len(self.X)\n",
+ " \n",
+ " # Update current index\n",
+ " self.current_index = end\n",
+ "\n",
+ " # Select batch samples\n",
+ " batch_indices = self.indices[start:end]\n",
+ " batch_X = self.X[batch_indices]\n",
+ " batch_Y = self.Y[batch_indices]\n",
+ "\n",
+ " # Ensure batch has consistent shape\n",
+ " if batch_X.ndim == 1:\n",
+ " batch_X = np.expand_dims(batch_X, axis=0)\n",
+ "\n",
+ " return batch_X, batch_Y"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "9c426971",
+ "metadata": {},
+ "source": [
+    "During training, the data are generally split into batches before being fed to the model. The model computes the loss over all samples in a batch and backpropagates the gradients to update the parameters.\n",
+    "\n",
+    "The reason is that feeding the entire training set into the model at once would require too much memory, while feeding a single training sample at a time parallelizes poorly, wasting resources and making one pass over the training set very slow; each parameter update would also carry information from only one sample point, which hurts convergence over the whole dataset. Feeding data in batches is a good trade-off. Since the test set involves no parameter updates, its batch size can be set larger, memory permitting.\n",
+    "\n",
+    "The training set is usually shuffled (`shuffle=True`), which helps ensure that each pass over the training set produces batches with different sample combinations, and so helps the model converge on the whole dataset. The test set is not used to update parameters, so it does not need shuffling. (As a check, the cell after the next pulls one batch to inspect its shape.)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "8a44b4d2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize training and testing datasets\n",
+ "batch_size = 32\n",
+ "train_dataset = Dataset(X_train, Y_train, batch_size, shuffle=True)\n",
+ "test_dataset = Dataset(X_test, Y_test, batch_size, shuffle=False)"
+ ]
+ },
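+  {
+   "cell_type": "markdown",
+   "id": "b2d61c9f",
+   "metadata": {},
+   "source": [
+    "As a small sanity check (an extra cell on top of the tutorial), we can pull one batch from the training `Dataset` and inspect its shape: with `batch_size = 32` we expect inputs of shape `(32, 28, 28)` and labels of shape `(32,)`.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5a0e7d43",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "batch_X, batch_Y = next(iter(train_dataset))  # first batch of the (shuffled) training set\n",
+    "print(batch_X.shape, batch_Y.shape)  # expected: (32, 28, 28) (32,)"
+   ]
+  },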
+ {
+ "cell_type": "markdown",
+ "id": "54a36298",
+ "metadata": {},
+ "source": [
+    "## Defining the model architecture\n",
+    "\n",
+    "A classic multilayer perceptron has three linear layers: an input layer, a hidden layer, and an output layer.\n",
+    "\n",
+    "When defining each linear layer, we must give its input and output dimensions. (This is easy to see: by equation (1), we must supply the size of $W$.)\n",
+    "\n",
+    "- A linear layer only accepts one-dimensional input, so we use `flatten()` to unroll each two-dimensional image into a vector. Here $28 \\times 28 = 784$ is the input size of the first linear layer.\n",
+    "- Handwritten digit recognition is a ten-way classification task. For multi-class classification, the model output is normally a ten-dimensional vector whose components give the probability of each label, so the output dimension of the last linear layer is $10$.\n",
+    "- The hidden layer extracts features automatically: the more dimensions it has, the more features it can extract and the more expressive the model. For this simple task, the hidden dimension can be set to some number between $784$ and $10$, and increased if the results are poor. For harder tasks, one can add more hidden layers and let their widths exceed the input and output dimensions; typically the hidden widths first grow layer by layer and then shrink again.\n",
+    "- Note that for adjacent layers, the output dimension of one layer is the input dimension of the next.\n",
+ "\n",
+ "\n",
+    "<!-- figure: dimensions of the three linear layers -->\n",
+ "\n",
+ "\n",
+    "As discussed above, activation functions must be inserted between the linear layers; otherwise the stack amounts to a single linear layer. Here we use the ReLU (rectified linear unit) activation, which sets negative values to 0 and so adds nonlinearity:\n",
+ "$$\n",
+ "\\text{ReLU}(x) = \\max(0, x)\\tag{2}\n",
+ "$$\n",
+ "\n",
+ "\n",
+    "<!-- figure: the ReLU activation function -->\n",
+ "\n"
+ ]
+ },
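+  {
+   "cell_type": "markdown",
+   "id": "9f3b6c8d",
+   "metadata": {},
+   "source": [
+    "A one-line numerical check of equation (2) (an extra cell, reusing the `jax.numpy` imported above): ReLU zeroes out the negative entries of a vector and keeps the positive ones.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1d7a4e92",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "x = jnp.array([-2.0, -0.5, 0.0, 1.5, 3.0])\n",
+    "print(jnp.maximum(0.0, x))  # ReLU by hand: [0.  0.  0.  1.5 3. ]"
+   ]
+  },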
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "331a3ca7",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Define MLP model\n",
+ "class MLP(bst.nn.Module):\n",
+ " def __init__(self, din, dhidden, dout):\n",
+ " super().__init__()\n",
+ " self.linear1 = bst.nn.Linear(din, dhidden) # Define the first linear layer, input dimension is din, output dimension is dhidden \n",
+ " self.linear2 = bst.nn.Linear(dhidden, dhidden) # Define the second linear layer, input dimension is dhidden, output dimension is dhidden\n",
+ " self.linear3 = bst.nn.Linear(dhidden, dout) # Define the third linear layer, input dimension is dhidden, output dimension is dout (10 classes for MNIST)\n",
+ " self.flatten = bst.nn.Flatten(start_axis=1) # Flatten images to 1D\n",
+ " self.relu = bst.nn.ReLU() # ReLU activation function\n",
+ "\n",
+ " def __call__(self, x):\n",
+ " x = self.flatten(x) # Flatten the input image from 2D to 1D\n",
+ " x = self.linear1(x) # Pass the flattened input through the first linear layer\n",
+ " x = self.relu(x) # Alternatively, you can use jax's ReLU function: x = jax.nn.relu(x)\n",
+ " x = self.linear2(x) # Pass the result through the second linear layer\n",
+ " x = self.relu(x) # Apply the ReLU activation function\n",
+ " x = self.linear3(x) # Pass the result through the third linear layer to get the final output\n",
+ "\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "347dc916",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize model with input, hidden, and output layer sizes\n",
+ "model = MLP(din=28*28, dhidden=512, dout=10)"
+ ]
+ },
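+  {
+   "cell_type": "markdown",
+   "id": "c8e2f5a0",
+   "metadata": {},
+   "source": [
+    "Before training, it is worth checking (again as an extra step, assuming the model accepts a NumPy batch directly, just as the jitted training step below does) that a forward pass runs and produces one logit per class: the first four training images should yield an output of shape `(4, 10)`.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "4b9d0e67",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "logits = model(X_train[:4])  # forward pass on a small batch of images\n",
+    "print(logits.shape)  # expected: (4, 10), one logit per digit class"
+   ]
+  },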
+ {
+ "cell_type": "markdown",
+ "id": "b9c94f0e",
+ "metadata": {},
+ "source": [
+    "## Optimizing the model\n",
+    "\n",
+    "Given an image as input, the artificial neural network runs and outputs a classification. We need to compare this output with the reference answer and then tune the parameters so that the predicted class probabilities come as close as possible to the true classes.\n",
+    "\n",
+    "### Loss function\n",
+    "\n",
+    "In this process, a **loss function** measures how far the classification output is from the true class. There are many loss functions to choose from, matched to different meanings of the output. Here we use cross entropy, the standard loss for multi-class classification. For one sample, the cross-entropy loss is:\n",
+    "\n",
+    "$$\n",
+    "Loss(y, \\hat{y}) = - \\sum_{i=1}^{N} y_i \\log(\\hat{y}_i)\\tag{3}\n",
+    "$$\n",
+    "\n",
+    "where $\\hat{y}_i$ is the model's predicted probability for class $i$ (the probabilities over all classes sum to 1), $y_i$ is the true class label in one-hot encoding (1 for the correct class and 0 elsewhere), and $N$ is the number of classes. If the model predicts correctly (the probability of the true class is close to 1), the loss is very small; conversely, when the probability of the true class is close to 0, the loss is very large.\n",
+    "\n",
+    "Note that what we pass to the loss function here are the model's raw outputs, which are not probabilities (nothing constrains them to sum to 1). This is fine because the loss function `softmax_cross_entropy_with_integer_labels` from `braintools.metric` automatically passes the model outputs through the softmax (normalized exponential) activation, converting them into a probability distribution:\n",
+    "\n",
+    "$$\n",
+    "\\sigma(\\mathbf{z})_i = \\frac{e^{z_i}}{\\sum_{j=1}^{K} e^{z_j}}\\tag{4}\n",
+    "$$\n",
+    "\n",
+    "where $\\mathbf{z}$ is the input vector, $z_i$ is its $i$-th element, and $K$ is its dimension.\n",
+    "\n",
+    "`softmax_cross_entropy_with_integer_labels` also converts the one-dimensional integer class labels to one-hot encoding automatically.\n",
+    "\n",
+    "### Backpropagation\n",
+    "\n",
+    "**Backpropagation** is the key algorithm for optimizing the parameters during neural-network training. Its task is to compute, from the value of the loss function, the gradient of every parameter in the model (mainly the weights $W$ and biases $b$), thereby tracing the model's prediction error back to its sources for use in optimization. Given the loss value, backpropagation applies the chain rule layer by layer to compute the partial derivative of the loss with respect to each parameter. These partial derivatives (the gradients) describe the direction and magnitude in which the loss changes as the parameters change, and they are the basis of optimization.\n",
+    "\n",
+    "### Optimizer\n",
+    "An **optimizer** is an algorithm that decides, once the gradients are available, how to use them to update the network's parameters (mainly the weights and biases) so as to reduce the loss. The basic update rule is:\n",
+    "$$\n",
+    "w = w - \\eta \\cdot \\frac{\\partial L}{\\partial w}\\tag{5}\n",
+    "$$\n",
+    "where $w$ is a parameter, $\\eta$ is the learning rate, and $\\frac{\\partial L}{\\partial w}$ is the gradient.\n",
+    "\n",
+    "There are many optimizers to choose from; here we use the common SGD (stochastic gradient descent) optimizer.\n",
+    "\n",
+    "Below we instantiate the model's optimizer and tell it which parameters to update; a small hand-rolled sketch of equations (3)-(5) follows that cell."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "375b3e54",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Initialize optimizer and register model parameters\n",
+    "optimizer = bst.optim.SGD(lr=1e-3)  # Initialize SGD optimizer with learning rate\n",
+    "optimizer.register_trainable_weights(model.states(bst.ParamState))  # Register parameters for optimization"
+ ]
+ },
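+  {
+   "cell_type": "markdown",
+   "id": "6f2c8b3a",
+   "metadata": {},
+   "source": [
+    "To make equations (3)-(5) concrete, here is a hand-rolled NumPy sketch on made-up numbers (the logits, label, parameter `w`, gradient `g`, and learning rate are all illustrative); the actual training below relies on `braintools` and `brainstate` instead.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "e05d9a74",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import numpy as np\n",
+    "\n",
+    "logits = np.array([2.0, 0.5, -1.0])  # raw model outputs for 3 classes\n",
+    "label = 0  # integer label of the true class\n",
+    "\n",
+    "# Equation (4): softmax turns the logits into a probability distribution\n",
+    "e = np.exp(logits - logits.max())  # subtracting the max improves numerical stability\n",
+    "probs = e / e.sum()\n",
+    "\n",
+    "# Equation (3): with a one-hot target, cross entropy reduces to -log(p[label])\n",
+    "loss = -np.log(probs[label])\n",
+    "\n",
+    "# Equation (5): one SGD step on a scalar parameter w with gradient g\n",
+    "w, g, eta = 0.3, 0.12, 1e-3\n",
+    "w = w - eta * g\n",
+    "print(probs, loss, w)"
+   ]
+  },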
+ {
+ "cell_type": "markdown",
+ "id": "d5ecc46c",
+ "metadata": {},
+ "source": [
+    "## Training & testing the model\n",
+    "Each iteration over a batch of training data follows this procedure:\n",
+    "- feed the data into the model and obtain its output\n",
+    "- compute the loss\n",
+    "- compute the gradients\n",
+    "- hand the gradients to the optimizer, which updates the parameters"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "68c121df",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Training step function\n",
+ "@bst.compile.jit\n",
+ "def train_step(batch):\n",
+ " x, y = batch\n",
+ " # Define loss function\n",
+ " def loss_fn():\n",
+ " return softmax_cross_entropy_with_integer_labels(model(x), y).mean()\n",
+ " \n",
+ " # Compute gradients of the loss with respect to model parameters\n",
+ " grads = bst.augment.grad(loss_fn, model.states(bst.ParamState))()\n",
+ " optimizer.update(grads) # Update parameters using optimizer"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0fe15316",
+ "metadata": {},
+ "source": [
+    "Each iteration over a batch of test data requires no gradients and no parameter updates, but we can optionally compute the accuracy to gauge how well training went:\n",
+    "- feed the data into the model and obtain its output\n",
+    "- compute the loss\n",
+    "- compute the accuracy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "9a4df640",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Testing step function\n",
+ "@bst.compile.jit\n",
+ "def test_step(batch):\n",
+ " x, y = batch\n",
+ " y_pred = model(x) # Perform forward pass\n",
+ " loss = softmax_cross_entropy_with_integer_labels(y_pred, y).mean() # Compute loss\n",
+ " correct = (y_pred.argmax(1) == y).sum() # Count correct predictions\n",
+ "\n",
+ " return {'loss': loss, 'correct': correct}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "a0a91dfa",
+ "metadata": {},
+ "source": [
+    "A model is usually trained over the same training set several times (epochs), each time iterating through the whole set batch by batch; every epoch, or every few epochs, we can check performance on the test set.\n",
+    "\n",
+    "In the example below, as the number of passes over the training set grows, the training loss decreases and the test accuracy rises, showing that we have successfully trained a multilayer perceptron to classify handwritten digits."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "9964de29",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "epoch: 0, test loss: 1.2870731353759766, test accuracy: 0.8459000587463379\n",
+ "epoch: 1, test loss: 0.8965848088264465, test accuracy: 0.8851000666618347\n",
+ "epoch: 2, test loss: 0.7595921754837036, test accuracy: 0.8992000222206116\n",
+ "epoch: 3, test loss: 0.6877797842025757, test accuracy: 0.9066000580787659\n",
+ "epoch: 4, test loss: 0.6413018703460693, test accuracy: 0.9117000699043274\n",
+ "epoch: 5, test loss: 0.6088062524795532, test accuracy: 0.914400041103363\n",
+ "epoch: 6, test loss: 0.5789325833320618, test accuracy: 0.9187000393867493\n",
+ "epoch: 7, test loss: 0.5573971271514893, test accuracy: 0.921000063419342\n",
+ "epoch: 8, test loss: 0.5350563526153564, test accuracy: 0.9230000376701355\n",
+ "epoch: 9, test loss: 0.5217731595039368, test accuracy: 0.9251000285148621\n"
+ ]
+ }
+ ],
+ "source": [
+    "# Execute training and testing\n",
+ "for epoch in range(10):\n",
+ " for step, batch in enumerate(train_dataset):\n",
+ " train_step(batch) # Perform training step for each batch\n",
+ "\n",
+ " # Calculate test loss and accuracy\n",
+ " test_loss, correct = 0, 0\n",
+ " for step_, test_ in enumerate(test_dataset):\n",
+ " logs = test_step(test_)\n",
+ " test_loss += logs['loss']\n",
+ " correct += logs['correct']\n",
+ " test_loss = test_loss / (step_ + 1)\n",
+ " test_accuracy = correct / len(X_test)\n",
+ " print(f\"epoch: {epoch}, test loss: {test_loss}, test accuracy: {test_accuracy}\")"
+ ]
}
],
"metadata": {
@@ -20,14 +510,14 @@
"language_info": {
"codemirror_mode": {
"name": "ipython",
- "version": 2
+ "version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
- "pygments_lexer": "ipython2",
- "version": "2.7.6"
+ "pygments_lexer": "ipython3",
+ "version": "3.11.10"
}
},
"nbformat": 4,