From 4b94a9878db7e2b538f7008f82b88cff9efaf6b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Jul 2024 05:22:11 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/fp8_quant/test_basic.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test/3x/torch/algorithms/fp8_quant/test_basic.py b/test/3x/torch/algorithms/fp8_quant/test_basic.py index ae59de65da1..f6167f81ab6 100644 --- a/test/3x/torch/algorithms/fp8_quant/test_basic.py +++ b/test/3x/torch/algorithms/fp8_quant/test_basic.py @@ -1,29 +1,31 @@ import os import sys -import torch import time import habana_frameworks.torch.core as htcore - -from torch.utils.data import DataLoader -from torchvision import transforms, datasets +import torch import torch.nn as nn import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision import datasets, transforms + class Net(nn.Module): def __init__(self): super(Net, self).__init__() - self.fc1 = nn.Linear(784, 256) - self.fc2 = nn.Linear(256, 64) - self.fc3 = nn.Linear(64, 10) + self.fc1 = nn.Linear(784, 256) + self.fc2 = nn.Linear(256, 64) + self.fc3 = nn.Linear(64, 10) + def forward(self, x): - out = x.view(-1,28*28) + out = x.view(-1, 28 * 28) out = F.relu(self.fc1(out)) out = F.relu(self.fc2(out)) out = self.fc3(out) out = F.log_softmax(out, dim=1) return out + model = Net() model_link = "https://vault.habana.ai/artifactory/misc/inference/mnist/mnist-epoch_20.pth" model_path = "/tmp/.neural_compressor/mnist-epoch_20.pth" @@ -36,14 +38,12 @@ def forward(self, x): model = model.to("hpu") -transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,))]) +transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) -data_path = './data' -test_kwargs = {'batch_size': 32} +data_path = "./data" +test_kwargs = {"batch_size": 32} dataset1 = datasets.MNIST(data_path, train=False, download=True, transform=transform) -test_loader = torch.utils.data.DataLoader(dataset1,**test_kwargs) +test_loader = torch.utils.data.DataLoader(dataset1, **test_kwargs) correct = 0 for batch_idx, (data, label) in enumerate(test_loader): @@ -56,4 +56,4 @@ def forward(self, x): correct += output.max(1)[1].eq(label).sum() -print('Accuracy: {:.2f}%'.format(100. * correct / (len(test_loader) * 32))) \ No newline at end of file +print("Accuracy: {:.2f}%".format(100.0 * correct / (len(test_loader) * 32)))