Skip to content

Commit

Permalink
Fix a bug in autograd example
Browse files Browse the repository at this point in the history
  • Loading branch information
lhlich committed Jun 29, 2023
1 parent bfeb647 commit 3a336f3
Showing 1 changed file with 31 additions and 5 deletions.
36 changes: 31 additions & 5 deletions 06_rnns/Autograd_Simple.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -115,21 +115,22 @@
" else:\n",
" self.grad += backward_grad\n",
" \n",
" gradient_to_send = backward_grad if backward_grad is not None else 1\n",
" if self.creation_op == \"add\":\n",
" # Simply send backward self.grad, since increasing either of these \n",
" # Simply send backward backward_grad, since increasing either of these \n",
" # elements will increase the output by that same amount\n",
" self.depends_on[0].backward(self.grad)\n",
" self.depends_on[1].backward(self.grad) \n",
" self.depends_on[0].backward(gradient_to_send)\n",
" self.depends_on[1].backward(gradient_to_send) \n",
"\n",
" if self.creation_op == \"mul\":\n",
"\n",
" # Calculate the derivative with respect to the first element\n",
" new = self.depends_on[1] * self.grad\n",
" new = self.depends_on[1] * gradient_to_send\n",
" # Send backward the derivative with respect to that element\n",
" self.depends_on[0].backward(new.num)\n",
"\n",
" # Calculate the derivative with respect to the second element\n",
" new = self.depends_on[0] * self.grad\n",
" new = self.depends_on[0] * gradient_to_send\n",
" # Send backward the derivative with respect to that element\n",
" self.depends_on[1].backward(new.num)"
]
Expand Down Expand Up @@ -157,6 +158,31 @@
"print(b.grad) # as expected"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"24\n"
]
}
],
"source": [
"a = NumberWithGrad(3)\n",
"b = a * 4\n",
"c = b + 3\n",
"d = b * 5\n",
"e = c + d\n",
"\n",
"\n",
"e.backward()\n",
"print(a.grad)"
]
},
{
"cell_type": "code",
"execution_count": 6,
Expand Down

0 comments on commit 3a336f3

Please sign in to comment.