I have updated the optimize function to ensure it runs faster on GPU. Please take a look.
```python
import torch
import torchvision

# `Batch` (tensor accumulator) and `setup` (holds the target device) come from
# the surrounding project.


def optimize_new_2(classifier, generator, batch_size=512, image_size=None, num_classes=10):
    batch = Batch()
    batch_size += 1  # one extra slot for the dummy first element held by Batch (avoids a None-value exception)
    mse_loss = torch.nn.MSELoss(reduction="none")
    sample_per_class = batch_size // num_classes + 1  # e.g. 513 // 10 + 1 = 52; each class label is equally represented
    n_random_samples = 30  # candidate latent vectors kept per image slot
    n_sample_step = 32     # number of images evolved in parallel at one step
    encoding_size = 256
    if "sngan" in str(type(generator)):
        encoding_size = 128
    if "ProgressiveGAN" in str(type(generator)):
        encoding_size = 512
    resize = None
    if image_size is not None:
        resize = torchvision.transforms.Resize(image_size)

    random_labels = torch.randperm(num_classes)  # visit class labels in random order
    # Per-slot iteration counter; every 100 iterations the acceptance threshold
    # is lowered so that no slot stays stuck in the loop.
    x = torch.zeros(size=(n_sample_step,), device=setup.device)
    threshold = torch.zeros(size=(n_sample_step,), device=setup.device)  # shape (n_sample_step,)
    label_tensor = torch.full((n_sample_step, n_random_samples), 0, dtype=torch.long, device=setup.device)

    # Pre-allocated pool of random latent vectors, refilled in bulk so the hot
    # loop never allocates fresh GPU memory.
    specimens_bucket = torch.zeros((n_sample_step * 32, n_random_samples, encoding_size), device=setup.device).uniform_(-3, 3)
    s_index = 0

    def get_specimens(size):
        nonlocal s_index, specimens_bucket
        if s_index + size > specimens_bucket.size(0):  # pool exhausted: refill it
            specimens_bucket = torch.zeros((n_sample_step * 32, n_random_samples, encoding_size), device=setup.device).uniform_(-3, 3)
            s_index = 0
        specimens = specimens_bucket[s_index:s_index + size]
        s_index += size
        return specimens

    # Scratch buffers for the mutation noise, reused across iterations.
    s1 = torch.zeros((n_sample_step, n_random_samples // 3, encoding_size), device=setup.device)
    s2 = torch.zeros((n_sample_step, n_random_samples // 3, encoding_size), device=setup.device)

    while batch.size(0) <= batch_size:
        x *= 0  # reset all iteration counters
        specimens = get_specimens(n_sample_step)
        # We collect sample_per_class images for one label before moving to the next.
        current_label_index = batch.size(0) // sample_per_class
        current_label_value = int(random_labels[current_label_index])
        label_tensor.fill_(current_label_value)
        label = torch.nn.functional.one_hot(label_tensor, num_classes=num_classes).float()  # float target for MSE
        threshold.fill_(0.9)

        while batch.size(0) // sample_per_class == current_label_index:  # until enough images are collected for this label
            x += 1
            # Every 100 iterations, lower the threshold for slots that are stuck.
            threshold -= (x % 100 == 0) * 0.1
            with torch.no_grad():
                flat_specimens = specimens.reshape(-1, encoding_size)  # (n_sample_step * n_random_samples, encoding_size)
                if "ProgressiveGAN" in str(type(generator)):
                    noise, _ = generator.buildNoiseData(n_sample_step * n_random_samples)
                    images = generator.test(noise).to(setup.device)
                else:
                    images = generator(flat_specimens)  # (n_sample_step * n_random_samples, C, H, W)
                # if resize:
                #     images = resize(images)
                output = classifier(images)  # (n_sample_step * n_random_samples, num_classes)
                output = output.view(n_sample_step, n_random_samples, num_classes)
                softmaxes = torch.nn.functional.softmax(output, dim=2)
                losses = torch.sum(mse_loss(softmaxes, label), dim=2)  # per-candidate loss, shape (n_sample_step, n_random_samples)
                # Reshape using the generator's actual output size, so image_size=None also works.
                images = images.view(n_sample_step, n_random_samples, *images.shape[1:])
                indexes = torch.argsort(losses, dim=1)  # candidates sorted by loss for every slot

                # Select, per slot, the candidate image with the lowest loss ...
                first_index = indexes[:, 0:1].view(n_sample_step, 1, 1, 1, 1).expand(-1, -1, *images.shape[2:])
                image = torch.gather(images, 1, first_index).squeeze(1)  # (n_sample_step, C, H, W)
                # ... and the top third of specimens, which seed the next generation.
                top_k = n_random_samples // 3
                top_index = indexes[:, 0:top_k].view(n_sample_step, top_k, 1).expand(-1, -1, encoding_size)
                best_specimens = torch.gather(specimens, 1, top_index)  # (n_sample_step, top_k, encoding_size)
                specimens = torch.cat([
                    best_specimens,
                    best_specimens + s1.normal_(0.5, 0.5),  # two mutated copies of the survivors
                    best_specimens + s2.normal_(0.5, 0.5),
                ], dim=1)

                # Softmax score of the best candidate for the target class
                # (gathered from softmaxes, since the threshold is a probability).
                first_index_softmax = indexes[:, 0:1].view(n_sample_step, 1, 1).expand(-1, -1, num_classes)
                c = torch.gather(softmaxes, 1, first_index_softmax).squeeze(1)  # (n_sample_step, num_classes)
                c = c[:, current_label_value]

                accepted = c > threshold  # slots whose score exceeds the (possibly lowered) threshold
                index_list = accepted.nonzero().squeeze(1)
                accepted_images = torch.index_select(image, 0, index_list)
                # Replace the specimens of accepted slots with fresh random ones.
                scatter_index = index_list.view(-1, 1, 1).expand(-1, n_random_samples, encoding_size)
                new_specimens = get_specimens(index_list.size(0))
                specimens.scatter_(dim=0, index=scatter_index, src=new_specimens)
                # Reset threshold and iteration counter for the accepted slots,
                # reusing the mask computed before the reset.
                threshold = torch.where(accepted, torch.full_like(threshold, 0.9), threshold)
                x = torch.where(accepted, torch.zeros_like(x), x)
                batch.append(accepted_images)

    batch = torch.cat(batch.batch)
    return torch.index_select(batch, 0, torch.randperm(len(batch), device=setup.device))
```
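For reviewers: the function leans on the project's `Batch` accumulator and `setup.device`. A minimal sketch of the interface it relies on (`size(0)`, `append()`, and the raw `.batch` list) might look like the following; this is hypothetical, and the project's real class may differ:

```python
import torch


class Batch:
    """Hypothetical sketch of the accumulator used by optimize_new_2.

    It seeds itself with one dummy image so .batch is never empty and
    torch.cat never fails, which is why optimize_new_2 adds 1 to batch_size.
    """

    def __init__(self, device="cuda"):
        self.batch = [torch.zeros((1, 3, 32, 32), device=device)]  # dummy first element

    def size(self, dim=0):
        # Total number of images accumulated so far, dummy included.
        return sum(t.size(0) for t in self.batch)

    def append(self, images):
        if images.numel() > 0:  # skip iterations where nothing was accepted
            self.batch.append(images)
```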
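And a quick smoke test of the call, assuming a CIFAR-10-style setup where `classifier` maps `(N, 3, 32, 32)` images to `(N, 10)` logits and `generator` maps `(N, encoding_size)` latents to 32x32 images (both names are placeholders for whatever the project loads):

```python
# Hypothetical usage; model loading is project-specific.
with torch.no_grad():
    synthetic = optimize_new_2(classifier, generator, batch_size=512, image_size=32, num_classes=10)
print(synthetic.shape)  # at least batch_size + 1 shuffled images, e.g. (513, 3, 32, 32)
```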