Skip to content

Commit

Permalink
Merge pull request #4 from pytorch/numpy
Browse files Browse the repository at this point in the history
fix ToTensor to handle numpy
  • Loading branch information
soumith authored Nov 12, 2016
2 parents 63dabca + e659e27 commit 44da562
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 11 deletions.
35 changes: 29 additions & 6 deletions test/cifar.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,35 @@
import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms

print('\n\nCifar 10')
a = dset.CIFAR10(root="abc/def/ghi", download=True)
# print('\n\nCifar 10')
# a = dset.CIFAR10(root="abc/def/ghi", download=True)

print(a[3])
# print(a[3])

print('\n\nCifar 100')
a = dset.CIFAR100(root="abc/def/ghi", download=True)
# print('\n\nCifar 100')
# a = dset.CIFAR100(root="abc/def/ghi", download=True)

print(a[3])
# print(a[3])


dataset = dset.CIFAR10(root='cifar', download=True, transform=transforms.ToTensor())
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32,
shuffle=True, num_workers=2)


# miter = dataloader.__iter__()
# def getBatch():
# global miter
# try:
# return miter.next()
# except StopIteration:
# miter = dataloader.__iter__()
# return miter.next()

# i=0
# while True:
# print(i)
# img, target = getBatch()
# i+=1

16 changes: 11 additions & 5 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
import random
from PIL import Image
import numpy as np


class Compose(object):
Expand All @@ -16,11 +17,16 @@ def __call__(self, img):

class ToTensor(object):
def __call__(self, pic):
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
img = img.view(pic.size[0], pic.size[1], 3)
# put it in CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 2).transpose(1, 2).contiguous()
if isinstance(pic, np.ndarray):
# handle numpy array
img = torch.from_numpy(pic)
else:
# handle PIL Image
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
img = img.view(pic.size[0], pic.size[1], 3)
# put it in CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 2).transpose(1, 2).contiguous()
return img.float()

class Normalize(object):
Expand Down

0 comments on commit 44da562

Please sign in to comment.