Skip to content

Commit

Permalink
add 4.5
Browse files Browse the repository at this point in the history
  • Loading branch information
ShusenTang committed Mar 16, 2019
1 parent c0bc547 commit ad26727
Show file tree
Hide file tree
Showing 2 changed files with 408 additions and 0 deletions.
254 changes: 254 additions & 0 deletions code/chapter04_DL_computation/4.5_read-write.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,254 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 4.5 读取和存储"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0.4.1\n"
]
}
],
"source": [
"import torch\n",
"from torch import nn\n",
"\n",
"print(torch.__version__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4.5.1 读写`Tensor`"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": [
"x = torch.ones(3)\n",
"torch.save(x, 'x.pt')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([1., 1., 1.])"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x2 = torch.load('x.pt')\n",
"x2"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[tensor([1., 1., 1.]), tensor([0., 0., 0., 0.])]"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"y = torch.zeros(4)\n",
"torch.save([x, y], 'xy.pt')\n",
"xy_list = torch.load('xy.pt')\n",
"xy_list"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'x': tensor([1., 1., 1.]), 'y': tensor([0., 0., 0., 0.])}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.save({'x': x, 'y': y}, 'xy_dict.pt')\n",
"xy = torch.load('xy_dict.pt')\n",
"xy"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4.5.2 读写模型\n",
"### 4.5.2.1 `state_dict`"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('hidden.weight', tensor([[ 0.1836, -0.1812, -0.1681],\n",
" [ 0.0406, 0.3061, 0.4599]])),\n",
" ('hidden.bias', tensor([-0.3384, 0.1910])),\n",
" ('output.weight', tensor([[0.0380, 0.4919]])),\n",
" ('output.bias', tensor([0.1451]))])"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"class MLP(nn.Module):\n",
" def __init__(self):\n",
" super(MLP, self).__init__()\n",
" self.hidden = nn.Linear(3, 2)\n",
" self.act = nn.ReLU()\n",
" self.output = nn.Linear(2, 1)\n",
"\n",
" def forward(self, x):\n",
" a = self.act(self.hidden(x))\n",
" return self.output(a)\n",
"\n",
"net = MLP()\n",
"net.state_dict()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'param_groups': [{'dampening': 0,\n",
" 'lr': 0.001,\n",
" 'momentum': 0.9,\n",
" 'nesterov': False,\n",
" 'params': [4624483024, 4624484608, 4624484680, 4624484752],\n",
" 'weight_decay': 0}],\n",
" 'state': {}}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)\n",
"optimizer.state_dict()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4.5.2.2 保存和加载模型"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1],\n",
" [1]], dtype=torch.uint8)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X = torch.randn(2, 3)\n",
"Y = net(X)\n",
"\n",
"PATH = \"./net.pt\"\n",
"torch.save(net.state_dict(), PATH)\n",
"\n",
"net2 = MLP()\n",
"net2.load_state_dict(torch.load(PATH))\n",
"Y2 = net2(X)\n",
"Y2 == Y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": true
},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [default]",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit ad26727

Please sign in to comment.