Some questions about training a custom dataset? #43
Can you teach me how to create a support set for my dataset?
@hjh151220 You can refer to the following code:
from torch.utils.data import Dataset
import os
from torchvision import transforms
from PIL import Image
import torch


class MVTecSupportDataset(Dataset):
    def __init__(self,
                 dataset_path='../data/mvtec_anomaly_detection',
                 class_name='bottle',
                 is_train=True,
                 resize=224,
                 ):
        self.dataset_path = dataset_path
        self.class_name = class_name
        self.is_train = is_train
        self.resize = resize
        # load dataset: paths of the normal images in <dataset_path>/<class_name>/train/good
        self.support_dir = self.load_dataset_folder()
        # set transforms
        # Image.ANTIALIAS was removed in Pillow 10; it was an alias for the LANCZOS filter
        self.transform = transforms.Compose([
            transforms.Resize(resize, interpolation=transforms.InterpolationMode.LANCZOS),
            transforms.ToTensor(),
        ])

    def __getitem__(self, idx):
        support = self.support_dir[idx]
        support_img = Image.open(support).convert('RGB')
        support_img = self.transform(support_img)
        return support_img

    def __len__(self):
        return len(self.support_dir)

    def load_dataset_folder(self):
        # the support set is drawn from normal ("good") training images only
        phase = 'train'
        img_dir = os.path.join(self.dataset_path, self.class_name, phase, 'good')
        # MVTec AD images are .png; .jpg/.jpeg are also accepted for custom datasets
        img_fpath_list = sorted(
            [os.path.join(img_dir, f) for f in os.listdir(img_dir)
             if f.lower().endswith(('.png', '.jpg', '.jpeg'))])
        return img_fpath_list


data_path = "data"
obj = "bottle"
img_size = (224, 224)
shot = 8         # number of images in each support set
inferences = 1   # number of support sets to sample
support_dataset = MVTecSupportDataset(data_path, class_name=obj, is_train=True, resize=img_size)
support_data_loader = torch.utils.data.DataLoader(support_dataset, batch_size=shot, shuffle=True)
save_img_list = []
for i in range(inferences):
    # a fresh iterator reshuffles, so each round draws a new random batch of `shot` images
    support_img = next(iter(support_data_loader))
    save_img_list.append(support_img)
save_path = f'regAD/data/support_set/{obj}/{shot}_{inferences}.pt'
os.makedirs(os.path.dirname(save_path), exist_ok=True)  # torch.save does not create folders
torch.save(save_img_list, save_path)
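The `DataLoader(..., batch_size=shot, shuffle=True)` loop in the code above builds each support set from the first shuffled batch of `shot` normal images, and because `iter()` is called again in every round, the loader reshuffles and the `inferences` sets are drawn independently. If you want to inspect or pin down which files end up in each set, a close stdlib-only analogue is sketched below; `sample_support_paths` and its `seed` argument are hypothetical names, not part of RegAD.

```python
import random
from typing import List, Optional, Sequence


def sample_support_paths(paths: Sequence[str], shot: int, inferences: int,
                         seed: Optional[int] = None) -> List[List[str]]:
    """Draw `inferences` independent support sets of `shot` image paths each.

    Hypothetical helper. Within one set, paths are sampled without replacement
    (like one shuffled DataLoader batch); different sets may overlap.
    """
    if shot > len(paths):
        raise ValueError(f"shot={shot} but only {len(paths)} images are available")
    rng = random.Random(seed)  # a fixed seed makes the sampled sets reproducible
    return [rng.sample(list(paths), shot) for _ in range(inferences)]


if __name__ == "__main__":
    files = [f"train/good/{i:03d}.png" for i in range(20)]
    for support_set in sample_support_paths(files, shot=8, inferences=2, seed=0):
        print(len(support_set), support_set[:2])
```

One deliberate difference: when a class has fewer than `shot` images, the DataLoader silently returns a smaller batch, whereas this sketch raises `ValueError` so an undersized support set does not go unnoticed.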
Thank you for your reply. I will try it and contact you if I have any other problems. Thank you for your help.
-----Original Message-----
From: iangiu ***@***.***>
Sent: 2023-11-06 16:52:50 (Monday)
To: MediaBrain-SJTU/RegAD ***@***.***>
Cc: hjh151220 ***@***.***>, Mention ***@***.***>
Subject: Re: [MediaBrain-SJTU/RegAD] Some questions about training a custom dataset? (Issue #43)
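Do you know how to visualize the anomaly detection results?
Hi @iangiu, have you found answers to your questions? I'm at the same stage. Thanks.
Hello, during inference, did you use a model trained on another dataset and then run inference on your own dataset? The test code provided by the authors always requires ground truth; what should I do if I don't have any ground truth? Looking forward to your answer.
Sorry, I haven't figured out these questions. I gave up on the project.
You need to annotate it yourself. I'm no longer following this project; the results didn't seem quite right to me.
OK, thanks for your answer. I currently have only positive samples and unlabeled negative samples, and that is what my use case looks like; annotating them is simply not possible. Do you have any model recommendations for this kind of requirement?
You can refer to https://github.com/openvinotoolkit/anomalib. It contains implementations of many zero-shot algorithms, and you can use it to run some experiments.
Thank you very much!!
Sorry, I'm a beginner and this is my first time following this project. Why can't this code find the path to the corresponding dataset?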
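The support-set code earlier in this thread lists images from `os.path.join(data_path, obj, 'train', 'good')`, and a relative `data_path` such as `"data"` is resolved against the directory you run the script from, so it expects `./data/bottle/train/good/`. Also note that MVTec AD images are `.png`, so a filter that keeps only `.jpg` files returns an empty list for them. A quick way to see where the code is actually looking is a check like the sketch below; `check_support_dir` is a hypothetical helper, not part of RegAD, and the accepted extensions are an assumption about your files.

```python
import os
from typing import List

IMAGE_EXTS = ('.png', '.jpg', '.jpeg')  # assumed image types


def check_support_dir(data_path: str, class_name: str) -> List[str]:
    """Resolve <data_path>/<class_name>/train/good and list the images in it.

    Hypothetical helper: a relative data_path is resolved against the current
    working directory, the same way os.listdir() resolves it.
    """
    img_dir = os.path.join(data_path, class_name, 'train', 'good')
    abs_dir = os.path.abspath(img_dir)
    if not os.path.isdir(abs_dir):
        raise FileNotFoundError(
            f"expected normal training images in {abs_dir} (cwd is {os.getcwd()})")
    images = sorted(f for f in os.listdir(abs_dir) if f.lower().endswith(IMAGE_EXTS))
    if not images:
        raise FileNotFoundError(f"{abs_dir} exists but contains no {IMAGE_EXTS} files")
    return [os.path.join(abs_dir, f) for f in images]


if __name__ == "__main__":
    try:
        paths = check_support_dir("data", "bottle")
        print(f"found {len(paths)} images, e.g. {paths[0]}")
    except FileNotFoundError as err:
        print(err)
```

Running this from the same directory as the support-set script prints either the resolved folder it could not find or how many images it would use.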
Has this been resolved?
On 2024-03-24 13:23:54, "Zhang Lei" ***@***.***> wrote:
Sorry, I'm a beginner and this is my first time following this project. Why can't this code find the path to the corresponding dataset?
Dear author, thank you very much for your work! However, I am puzzled by the following questions that came up when I trained a model on a custom dataset.
Looking forward to your reply! Thanks a lot!