From fec0f745e68374579d3556cec9514cdd75cc2b3c Mon Sep 17 00:00:00 2001 From: noopy Date: Sat, 4 Sep 2021 07:59:51 +0900 Subject: [PATCH] add testdataset class --- dataset.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/dataset.py b/dataset.py index a08edaf..0f48ec0 100644 --- a/dataset.py +++ b/dataset.py @@ -108,3 +108,21 @@ class GenderLabels: class AgeGroup: map_label = lambda x: 0 if int(x) < 30 else 1 if int(x) < 60 else 2 + + +class TestDataset(Dataset): + def __init__(self, img_paths, transform, device): + self.img_paths = img_paths + self.device = device + self.transform = transform + + def __getitem__(self, index): + image = Image.open(self.img_paths[index]) + + if self.transform: + image = self.transform(image=np.array(image))['image'] + + return image.to(self.device) + + def __len__(self): + return len(self.img_paths)