From 3a336f32df89cce89f611827b31315d8cb7583c9 Mon Sep 17 00:00:00 2001 From: lhlich Date: Thu, 29 Jun 2023 00:25:54 -0700 Subject: [PATCH] Fix a bug in autograd example --- 06_rnns/Autograd_Simple.ipynb | 36 ++++++++++++++++++++++++++++++----- 1 file changed, 31 insertions(+), 5 deletions(-) diff --git a/06_rnns/Autograd_Simple.ipynb b/06_rnns/Autograd_Simple.ipynb index cce30d9..1d04ab2 100644 --- a/06_rnns/Autograd_Simple.ipynb +++ b/06_rnns/Autograd_Simple.ipynb @@ -115,21 +115,22 @@ " else:\n", " self.grad += backward_grad\n", " \n", + " gradient_to_send = backward_grad if backward_grad is not None else 1\n", " if self.creation_op == \"add\":\n", - " # Simply send backward self.grad, since increasing either of these \n", + " # Simply send backward backward_grad, since increasing either of these \n", " # elements will increase the output by that same amount\n", - " self.depends_on[0].backward(self.grad)\n", - " self.depends_on[1].backward(self.grad) \n", + " self.depends_on[0].backward(gradient_to_send)\n", + " self.depends_on[1].backward(gradient_to_send) \n", "\n", " if self.creation_op == \"mul\":\n", "\n", " # Calculate the derivative with respect to the first element\n", - " new = self.depends_on[1] * self.grad\n", + " new = self.depends_on[1] * gradient_to_send\n", " # Send backward the derivative with respect to that element\n", " self.depends_on[0].backward(new.num)\n", "\n", " # Calculate the derivative with respect to the second element\n", - " new = self.depends_on[0] * self.grad\n", + " new = self.depends_on[0] * gradient_to_send\n", " # Send backward the derivative with respect to that element\n", " self.depends_on[1].backward(new.num)" ] @@ -157,6 +158,31 @@ "print(b.grad) # as expected" ] }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "24\n" + ] + } + ], + "source": [ + "a = NumberWithGrad(3)\n", + "b = a * 4\n", + "c = b + 3\n", + "d = b * 5\n", + "e = c + d\n", + "\n", + "\n", + "e.backward()\n", + "print(a.grad)" + ] + }, { "cell_type": "code", "execution_count": 6,