diff --git a/06_rnns/Autograd_Simple.ipynb b/06_rnns/Autograd_Simple.ipynb index cce30d9..cea323e 100644 --- a/06_rnns/Autograd_Simple.ipynb +++ b/06_rnns/Autograd_Simple.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -9,7 +10,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -22,7 +23,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -31,7 +32,7 @@ "7" ] }, - "execution_count": 2, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -43,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -66,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -115,28 +116,29 @@ " else:\n", " self.grad += backward_grad\n", " \n", + " gradient_to_send = backward_grad if backward_grad is not None else 1\n", " if self.creation_op == \"add\":\n", - " # Simply send backward self.grad, since increasing either of these \n", + " # Simply send backward backward_grad, since increasing either of these \n", " # elements will increase the output by that same amount\n", - " self.depends_on[0].backward(self.grad)\n", - " self.depends_on[1].backward(self.grad) \n", + " self.depends_on[0].backward(gradient_to_send)\n", + " self.depends_on[1].backward(gradient_to_send) \n", "\n", " if self.creation_op == \"mul\":\n", "\n", " # Calculate the derivative with respect to the first element\n", - " new = self.depends_on[1] * self.grad\n", + " new = self.depends_on[1] * gradient_to_send\n", " # Send backward the derivative with respect to that element\n", " self.depends_on[0].backward(new.num)\n", "\n", " # Calculate the derivative with respect to the second element\n", - " new = self.depends_on[0] * self.grad\n", + " new = self.depends_on[0] * gradient_to_send\n", " # Send backward the derivative with respect to that element\n", " self.depends_on[1].backward(new.num)" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -157,6 +159,31 @@ "print(b.grad) # as expected" ] }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "24\n" + ] + } + ], + "source": [ + "a = NumberWithGrad(3)\n", + "b = a * 4\n", + "c = b + 3\n", + "d = b * 5\n", + "e = c + d\n", + "\n", + "\n", + "e.backward()\n", + "print(a.grad)" + ] + }, { "cell_type": "code", "execution_count": 6, @@ -202,7 +229,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3.9.13 64-bit (microsoft store)", "language": "python", "name": "python3" }, @@ -216,7 +243,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.4" + "version": "3.9.13" + }, + "vscode": { + "interpreter": { + "hash": "96f88a1d939096e74b5883cdeb3bbaf3df602d5ab14210a6f9e7d8e0ea241fea" + } } }, "nbformat": 4,