-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a7a3b46
commit 4600d35
Showing
2 changed files
with
406 additions
and
0 deletions.
There are no files selected for viewing
274 changes: 274 additions & 0 deletions
274
competition/科大讯飞AI开发者大赛-脑PET图像分析和疾病预测挑战赛算法挑战大赛/1_train.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
# -*- coding: utf-8 -*- | ||
import os, sys, glob, argparse | ||
import pandas as pd | ||
import numpy as np | ||
from tqdm import tqdm | ||
|
||
import time, datetime | ||
import pdb, traceback | ||
|
||
import cv2 | ||
# import imagehash | ||
from PIL import Image | ||
|
||
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold | ||
|
||
# from efficientnet_pytorch import EfficientNet | ||
# model = EfficientNet.from_pretrained('efficientnet-b4') | ||
|
||
import torch | ||
torch.manual_seed(0) | ||
torch.backends.cudnn.deterministic = False | ||
torch.backends.cudnn.benchmark = True | ||
|
||
import torchvision.models as models | ||
import torchvision.transforms as transforms | ||
import torchvision.datasets as datasets | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
from torch.autograd import Variable | ||
from torch.utils.data.dataset import Dataset | ||
|
||
# input dataset | ||
train_jpg = np.array(glob.glob('../初赛数据/train/*/*.png')) | ||
|
||
class QRDataset(Dataset): | ||
def __init__(self, train_jpg, transform=None): | ||
self.train_jpg = train_jpg | ||
if transform is not None: | ||
self.transform = transform | ||
else: | ||
self.transform = None | ||
|
||
def __getitem__(self, index): | ||
start_time = time.time() | ||
img = Image.open(self.train_jpg[index]).convert('RGB') | ||
|
||
if self.transform is not None: | ||
img = self.transform(img) | ||
|
||
return img,torch.from_numpy(np.array(int('AD' in self.train_jpg[index]))) | ||
|
||
def __len__(self): | ||
return len(self.train_jpg) | ||
|
||
def accuracy(output, target, topk=(1,)): | ||
"""Computes the accuracy over the k top predictions for the specified values of k""" | ||
with torch.no_grad(): | ||
maxk = max(topk) | ||
batch_size = target.size(0) | ||
|
||
_, pred = output.topk(maxk, 1, True, True) | ||
pred = pred.t() | ||
correct = pred.eq(target.view(1, -1).expand_as(pred)) | ||
|
||
res = [] | ||
for k in topk: | ||
correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) | ||
res.append(correct_k.mul_(100.0 / batch_size)) | ||
return res | ||
|
||
class AverageMeter(object): | ||
"""Computes and stores the average and current value""" | ||
def __init__(self, name, fmt=':f'): | ||
self.name = name | ||
self.fmt = fmt | ||
self.reset() | ||
|
||
def reset(self): | ||
self.val = 0 | ||
self.avg = 0 | ||
self.sum = 0 | ||
self.count = 0 | ||
|
||
def update(self, val, n=1): | ||
self.val = val | ||
self.sum += val * n | ||
self.count += n | ||
self.avg = self.sum / self.count | ||
|
||
def __str__(self): | ||
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})' | ||
return fmtstr.format(**self.__dict__) | ||
|
||
class ProgressMeter(object): | ||
def __init__(self, num_batches, *meters): | ||
self.batch_fmtstr = self._get_batch_fmtstr(num_batches) | ||
self.meters = meters | ||
self.prefix = "" | ||
|
||
|
||
def pr2int(self, batch): | ||
entries = [self.prefix + self.batch_fmtstr.format(batch)] | ||
entries += [str(meter) for meter in self.meters] | ||
print('\t'.join(entries)) | ||
|
||
def _get_batch_fmtstr(self, num_batches): | ||
num_digits = len(str(num_batches // 1)) | ||
fmt = '{:' + str(num_digits) + 'd}' | ||
return '[' + fmt + '/' + fmt.format(num_batches) + ']' | ||
|
||
class VisitNet(nn.Module): | ||
def __init__(self): | ||
super(VisitNet, self).__init__() | ||
|
||
model = models.resnet34(True) | ||
model.avgpool = nn.AdaptiveAvgPool2d(1) | ||
model.fc = nn.Linear(512, 2) | ||
self.resnet = model | ||
|
||
def forward(self, img): | ||
out = self.resnet(img) | ||
return out | ||
|
||
def validate(val_loader, model, criterion): | ||
batch_time = AverageMeter('Time', ':6.3f') | ||
losses = AverageMeter('Loss', ':.4e') | ||
top1 = AverageMeter('Acc@1', ':6.2f') | ||
top5 = AverageMeter('Acc@2', ':6.2f') | ||
progress = ProgressMeter(len(val_loader), batch_time, losses, top1, top5) | ||
|
||
# switch to evaluate mode | ||
model.eval() | ||
|
||
with torch.no_grad(): | ||
end = time.time() | ||
for i, (input, target) in enumerate(val_loader): | ||
input = input.cuda() | ||
target = target.cuda() | ||
|
||
# compute output | ||
output = model(input) | ||
loss = criterion(output, target) | ||
|
||
# measure accuracy and record loss | ||
acc1, acc5 = accuracy(output, target, topk=(1, 2)) | ||
losses.update(loss.item(), input.size(0)) | ||
top1.update(acc1[0], input.size(0)) | ||
top5.update(acc5[0], input.size(0)) | ||
|
||
# measure elapsed time | ||
batch_time.update(time.time() - end) | ||
end = time.time() | ||
|
||
# TODO: this should also be done with the ProgressMeter | ||
print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}' | ||
.format(top1=top1, top5=top5)) | ||
return top1 | ||
|
||
def predict(test_loader, model, tta=10): | ||
# switch to evaluate mode | ||
model.eval() | ||
|
||
test_pred_tta = None | ||
for _ in range(tta): | ||
test_pred = [] | ||
with torch.no_grad(): | ||
end = time.time() | ||
for i, (input, target) in enumerate(test_loader): | ||
input = input.cuda() | ||
target = target.cuda() | ||
|
||
# compute output | ||
output = model(input, path) | ||
output = output.data.cpu().numpy() | ||
|
||
test_pred.append(output) | ||
test_pred = np.vstack(test_pred) | ||
|
||
if test_pred_tta is None: | ||
test_pred_tta = test_pred | ||
else: | ||
test_pred_tta += test_pred | ||
|
||
return test_pred_tta | ||
|
||
def train(train_loader, model, criterion, optimizer, epoch): | ||
batch_time = AverageMeter('Time', ':6.3f') | ||
# data_time = AverageMeter('Data', ':6.3f') | ||
losses = AverageMeter('Loss', ':.4e') | ||
top1 = AverageMeter('Acc@1', ':6.2f') | ||
# top5 = AverageMeter('Acc@5', ':6.2f') | ||
progress = ProgressMeter(len(train_loader), batch_time, losses, top1) | ||
|
||
# switch to train mode | ||
model.train() | ||
|
||
end = time.time() | ||
for i, (input, target) in enumerate(train_loader): | ||
input = input.cuda(non_blocking=True) | ||
target = target.cuda(non_blocking=True) | ||
|
||
# compute output | ||
output = model(input) | ||
loss = criterion(output, target) | ||
|
||
# measure accuracy and record loss | ||
acc1, acc5 = accuracy(output, target, topk=(1, 2)) | ||
losses.update(loss.item(), input.size(0)) | ||
top1.update(acc1[0], input.size(0)) | ||
# top5.update(acc5[0], input.size(0)) | ||
|
||
# compute gradient and do SGD step | ||
optimizer.zero_grad() | ||
loss.backward() | ||
optimizer.step() | ||
|
||
# measure elapsed time | ||
batch_time.update(time.time() - end) | ||
end = time.time() | ||
|
||
if i % 100 == 0: | ||
progress.pr2int(i) | ||
|
||
skf = KFold(n_splits=10, random_state=233, shuffle=True) | ||
for flod_idx, (train_idx, val_idx) in enumerate(skf.split(train_jpg, train_jpg)): | ||
train_loader = torch.utils.data.DataLoader( | ||
QRDataset(train_jpg[train_idx], | ||
transforms.Compose([ | ||
# transforms.RandomGrayscale(), | ||
transforms.Resize((512, 512)), | ||
transforms.RandomAffine(10), | ||
# transforms.ColorJitter(hue=.05, saturation=.05), | ||
# transforms.RandomCrop((450, 450)), | ||
transforms.RandomHorizontalFlip(), | ||
transforms.RandomVerticalFlip(), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
), batch_size=10, shuffle=True, num_workers=20, pin_memory=True | ||
) | ||
|
||
val_loader = torch.utils.data.DataLoader( | ||
QRDataset(train_jpg[val_idx], | ||
transforms.Compose([ | ||
transforms.Resize((512, 512)), | ||
# transforms.Resize((124, 124)), | ||
# transforms.RandomCrop((450, 450)), | ||
# transforms.RandomCrop((88, 88)), | ||
transforms.ToTensor(), | ||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) | ||
]) | ||
), batch_size=10, shuffle=False, num_workers=10, pin_memory=True | ||
) | ||
|
||
|
||
model = VisitNet().cuda() | ||
# model = nn.DataParallel(model).cuda() | ||
criterion = nn.CrossEntropyLoss().cuda() | ||
optimizer = torch.optim.SGD(model.parameters(), 0.01) | ||
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=4, gamma=0.85) | ||
best_acc = 0.0 | ||
for epoch in range(10): | ||
scheduler.step() | ||
print('Epoch: ', epoch) | ||
|
||
train(train_loader, model, criterion, optimizer, epoch) | ||
val_acc = validate(val_loader, model, criterion) | ||
|
||
if val_acc.avg.item() > best_acc: | ||
best_acc = val_acc.avg.item() | ||
torch.save(model.state_dict(), './resnet18_fold{0}.pt'.format(flod_idx)) | ||
|
||
break |
Oops, something went wrong.