Skip to content

Commit

Permalink
fix bug in autograd example
Browse files Browse the repository at this point in the history
  • Loading branch information
lhlich committed Jun 29, 2023
1 parent bfeb647 commit c3e8315
Showing 1 changed file with 45 additions and 13 deletions.
58 changes: 45 additions & 13 deletions 06_rnns/Autograd_Simple.ipynb
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
Expand All @@ -9,7 +10,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -22,7 +23,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -31,7 +32,7 @@
"7"
]
},
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -43,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
Expand All @@ -66,7 +67,7 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -115,28 +116,29 @@
" else:\n",
" self.grad += backward_grad\n",
" \n",
" gradient_to_send = backward_grad if backward_grad is not None else 1\n",
" if self.creation_op == \"add\":\n",
" # Simply send backward self.grad, since increasing either of these \n",
" # Simply send backward backward_grad, since increasing either of these \n",
" # elements will increase the output by that same amount\n",
" self.depends_on[0].backward(self.grad)\n",
" self.depends_on[1].backward(self.grad) \n",
" self.depends_on[0].backward(gradient_to_send)\n",
" self.depends_on[1].backward(gradient_to_send) \n",
"\n",
" if self.creation_op == \"mul\":\n",
"\n",
" # Calculate the derivative with respect to the first element\n",
" new = self.depends_on[1] * self.grad\n",
" new = self.depends_on[1] * gradient_to_send\n",
" # Send backward the derivative with respect to that element\n",
" self.depends_on[0].backward(new.num)\n",
"\n",
" # Calculate the derivative with respect to the second element\n",
" new = self.depends_on[0] * self.grad\n",
" new = self.depends_on[0] * gradient_to_send\n",
" # Send backward the derivative with respect to that element\n",
" self.depends_on[1].backward(new.num)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -157,6 +159,31 @@
"print(b.grad) # as expected"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"24\n"
]
}
],
"source": [
"a = NumberWithGrad(3)\n",
"b = a * 4\n",
"c = b + 3\n",
"d = b * 5\n",
"e = c + d\n",
"\n",
"\n",
"e.backward()\n",
"print(a.grad)"
]
},
{
"cell_type": "code",
"execution_count": 6,
Expand Down Expand Up @@ -202,7 +229,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3.9.13 64-bit (microsoft store)",
"language": "python",
"name": "python3"
},
Expand All @@ -216,7 +243,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
"version": "3.9.13"
},
"vscode": {
"interpreter": {
"hash": "96f88a1d939096e74b5883cdeb3bbaf3df602d5ab14210a6f9e7d8e0ea241fea"
}
}
},
"nbformat": 4,
Expand Down

0 comments on commit c3e8315

Please sign in to comment.