Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files · Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 10, 2024
1 parent 2697120 commit 4b94a98
Showing 1 changed file with 15 additions and 15 deletions.
30 changes: 15 additions & 15 deletions test/3x/torch/algorithms/fp8_quant/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
import os
import sys
import torch
import time

import habana_frameworks.torch.core as htcore

from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class Net(nn.Module):
    """Small MLP classifier for flattened 28x28 MNIST images.

    Architecture: 784 -> 256 -> 64 -> 10, ReLU activations, log-softmax output.
    """

    def __init__(self):
        super(Net, self).__init__()
        # Each layer defined exactly once (the scraped diff carried duplicate
        # pre-/post-format assignments for fc1/fc2/fc3; deduplicated here).
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Return per-class log-probabilities of shape (N, 10) for input x."""
        # Flatten any leading batch/image dims down to (N, 784).
        out = x.view(-1, 28 * 28)
        out = F.relu(self.fc1(out))
        out = F.relu(self.fc2(out))
        out = self.fc3(out)
        # log_softmax over the class dimension, as expected by NLL-style eval.
        out = F.log_softmax(out, dim=1)
        return out


# Instantiate the MLP; pretrained MNIST weights are fetched from the vault
# below (loading code is hidden by the diff view between these hunks).
model = Net()
# Remote location of the pretrained checkpoint.
model_link = "https://vault.habana.ai/artifactory/misc/inference/mnist/mnist-epoch_20.pth"
# Local cache path where the downloaded checkpoint is stored.
model_path = "/tmp/.neural_compressor/mnist-epoch_20.pth"
Expand All @@ -36,14 +38,12 @@ def forward(self, x):
# Move the model to the Habana Gaudi (HPU) device for inference.
model = model.to("hpu")


# Standard MNIST preprocessing: tensor conversion plus per-channel
# normalization with the dataset's canonical mean/std.
# (The scraped diff carried both the pre- and post-format versions of each
# statement below; this keeps each exactly once, in the post-commit form.)
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

data_path = "./data"
test_kwargs = {"batch_size": 32}
# MNIST test split, downloaded on first use into ./data.
dataset1 = datasets.MNIST(data_path, train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(dataset1, **test_kwargs)

# Running count of correctly classified test samples.
correct = 0
# Evaluate batch by batch over the test loader (loop body continues below;
# part of it is hidden by the collapsed diff view).
for batch_idx, (data, label) in enumerate(test_loader):
Expand All @@ -56,4 +56,4 @@ def forward(self, x):

correct += output.max(1)[1].eq(label).sum()

# Report top-1 accuracy over the full test set; len(test_loader) batches of 32.
# The scraped diff contained both the old and new print lines (which would
# print twice); only the post-commit line is kept.
print("Accuracy: {:.2f}%".format(100.0 * correct / (len(test_loader) * 32)))

0 comments on commit 4b94a98

Please sign in to comment.