diff --git a/src/__pycache__/model.cpython-39.pyc b/src/__pycache__/model.cpython-39.pyc new file mode 100644 index 0000000..cc2c942 Binary files /dev/null and b/src/__pycache__/model.cpython-39.pyc differ diff --git a/src/train.py b/src/train.py index 36dfe33..6b96980 100644 --- a/src/train.py +++ b/src/train.py @@ -1,6 +1,3 @@ -# -*- coding: utf-8 -*- -# 作者:小土堆 -# 公众号:土堆碎念 import torchvision from torch.utils.tensorboard import SummaryWriter @@ -10,6 +7,8 @@ from torch import nn from torch.utils.data import DataLoader +# 使用GPU +device = torch.device("cuda") train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(), download=True) test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(), @@ -29,9 +28,11 @@ # 创建网络模型 tudui = Tudui() +tudui = tudui.to(device) # 损失函数 loss_fn = nn.CrossEntropyLoss() +loss_fn = loss_fn.to(device) # 优化器 # learning_rate = 0.01 @@ -57,6 +58,8 @@ tudui.train() for data in train_dataloader: imgs, targets = data + imgs = imgs.to(device) + targets = targets.to(device) outputs = tudui(imgs) loss = loss_fn(outputs, targets) @@ -77,6 +80,8 @@ with torch.no_grad(): for data in test_dataloader: imgs, targets = data + imgs = imgs.to(device) + targets = targets.to(device) outputs = tudui(imgs) loss = loss_fn(outputs, targets) total_test_loss = total_test_loss + loss.item() diff --git a/src/tudui_0.pth b/src/tudui_0.pth index 94d0669..be9094c 100644 Binary files a/src/tudui_0.pth and b/src/tudui_0.pth differ diff --git a/src/tudui_1.pth b/src/tudui_1.pth new file mode 100644 index 0000000..8e83136 Binary files /dev/null and b/src/tudui_1.pth differ diff --git a/src/tudui_2.pth b/src/tudui_2.pth new file mode 100644 index 0000000..5c798c2 Binary files /dev/null and b/src/tudui_2.pth differ diff --git a/src/tudui_3.pth b/src/tudui_3.pth new file mode 100644 index 0000000..1df3c78 Binary files /dev/null and b/src/tudui_3.pth differ diff --git a/src/tudui_4.pth b/src/tudui_4.pth new file mode 100644 index 0000000..f91d29d Binary files /dev/null and b/src/tudui_4.pth differ diff --git a/src/tudui_5.pth b/src/tudui_5.pth new file mode 100644 index 0000000..533be31 Binary files /dev/null and b/src/tudui_5.pth differ diff --git a/src/tudui_6.pth b/src/tudui_6.pth new file mode 100644 index 0000000..dd403f5 Binary files /dev/null and b/src/tudui_6.pth differ diff --git a/src/tudui_7.pth b/src/tudui_7.pth new file mode 100644 index 0000000..7e1ef89 Binary files /dev/null and b/src/tudui_7.pth differ diff --git a/src/tudui_8.pth b/src/tudui_8.pth new file mode 100644 index 0000000..f9004c7 Binary files /dev/null and b/src/tudui_8.pth differ diff --git a/src/tudui_9.pth b/src/tudui_9.pth new file mode 100644 index 0000000..dbe575b Binary files /dev/null and b/src/tudui_9.pth differ