-
Notifications
You must be signed in to change notification settings - Fork 1
/
pytorchUtility.py
87 lines (70 loc) · 2.83 KB
/
pytorchUtility.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import time
from builtins import range
import torch
import torch.nn as nn
import numpy as np
from collections import Counter
def calAccuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: 2-D tensor of per-class scores, one row per sample; the
            top-k search runs along dimension 1.
        target: tensor of ground-truth class indices with one entry per
            sample; its first dimension is used as the batch size.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 0-dim float tensor per entry of ``topk``: the
        percentage of samples whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # Indices of the maxk largest scores per sample, sorted descending.
    _, pred = output.topk(maxk, 1, True, True)
    # Transpose to (maxk, batch) so row r holds every sample's rank-r guess.
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` can inherit the non-contiguous layout
        # of the transposed `pred`, and view(-1) raises on it when k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def calAveragePredictionVectorAccuracy(predictionVectorsList, target, modelsList=None, topk=(1,)):
    """Compute the top-k accuracy of the averaged prediction vectors.

    Args:
        predictionVectorsList: sequence of equally-shaped tensors, one per
            model; they are stacked along a new leading dimension.
        target: ground-truth tensor forwarded to ``calAccuracy``.
        modelsList: optional indices selecting which models (entries of the
            stacked leading dimension) to average. ``None`` or an empty
            selection averages every model.
        topk: iterable of k values forwarded to ``calAccuracy``.

    Returns:
        Whatever ``calAccuracy`` returns for the averaged predictions.
    """
    predictionVectorsStack = torch.stack(predictionVectorsList)
    # The default is None, so check it explicitly before calling len();
    # len() rather than truthiness keeps array-like selections working.
    if modelsList is not None and len(modelsList) > 0:
        predictionVectorsStack = predictionVectorsStack[modelsList, ...]
    averagePrediction = torch.mean(predictionVectorsStack, dim=0)
    return calAccuracy(averagePrediction, target, topk)
def calNegativeSamplesSet(predictionVectorsList, target):
    """Collect, for each model, the indices of the samples it gets wrong.

    Args:
        predictionVectorsList: iterable of per-model score tensors; the
            predicted label of a sample is the argmax along dimension 1.
        target: tensor of ground-truth labels; its first dimension gives the
            number of samples inspected.

    Returns:
        A list with one ``set`` per model, holding the integer indices of
        the samples whose predicted label differs from the target.
    """
    numSamples = target.size(0)
    wrongIndicesPerModel = []
    for scores in predictionVectorsList:
        # max(dim=1) yields (values, indices); the indices are the labels.
        labels = scores.max(dim=1)[1]
        wrongIndicesPerModel.append(
            {idx for idx in range(numSamples) if labels[idx] != target[idx]}
        )
    return wrongIndicesPerModel
def calDisagreementSamplesOneTargetNegative(predictionVectorsList, target, oneTargetIdx):
    """Gather the samples that the model at ``oneTargetIdx`` misclassifies.

    Args:
        predictionVectorsList: indexable collection of per-model score
            tensors; a sample's predicted label is the argmax along dim 1.
        target: tensor of ground-truth labels; its first dimension gives the
            number of samples inspected.
        oneTargetIdx: index of the model whose mistakes select the samples.

    Returns:
        A 4-tuple of parallel lists ``(sampleID, sampleTarget, predictions,
        predVectors)``: the sample index, its target label, the labels
        predicted by every model, and every model's score row for it.
    """
    numSamples = target.size(0)
    # Predicted label per sample for each model.
    labelsPerModel = [scores.max(dim=1)[1] for scores in predictionVectorsList]

    sampleID, sampleTarget, predictions, predVectors = [], [], [], []
    for idx in range(numSamples):
        labelsHere = [labels[idx].item() for labels in labelsPerModel]
        vectorsHere = [
            predictionVectorsList[m][idx] for m, _ in enumerate(labelsPerModel)
        ]
        # Keep only the samples the reference model gets wrong.
        if labelsPerModel[oneTargetIdx][idx] == target[idx]:
            continue
        sampleID.append(idx)
        sampleTarget.append(target[idx].item())
        predictions.append(labelsHere)
        predVectors.append(vectorsHere)
    return sampleID, sampleTarget, predictions, predVectors
def filterModelsFixed(sampleID, sampleTarget, predictions, predVectors, selectModels):
    """Restrict the per-sample model outputs to a chosen subset of models.

    Args:
        sampleID: sample identifiers, returned unchanged.
        sampleTarget: sample targets, returned unchanged.
        predictions: container indexed as ``[:, selectModels]``, so it must
            support two-axis indexing (assumed array-like; confirm with the
            caller).
        predVectors: container indexed the same way as ``predictions``.
        selectModels: index or indices applied to the second axis.

    Returns:
        ``(sampleID, sampleTarget, filteredPredictions, filteredPredVectors)``.
    """
    # Same key as `[:, selectModels]`: every row, selected model columns.
    modelColumns = (slice(None), selectModels)
    return sampleID, sampleTarget, predictions[modelColumns], predVectors[modelColumns]