Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code improvement #4

Open
shubham303 opened this issue Sep 29, 2022 · 0 comments
Open

Code improvement #4

shubham303 opened this issue Sep 29, 2022 · 0 comments

Comments

@shubham303
Copy link

shubham303 commented Sep 29, 2022

I have updated the optimize function to ensure it runs faster on GPU. Please take a look.

`

def optimize_new_2(classifier, generator, batch_size=512, image_size=None, num_classes=10):
    """Build a shuffled batch of generated images the classifier scores confidently.

    Classes are visited in a random order.  For the current class, 32 search
    slots each hold 30 candidate latent vectors.  Every iteration the
    candidates are rendered by ``generator``, scored by ``classifier``, and
    ranked by the squared error between the softmax output and the one-hot
    target.  The best third of each slot is kept and two noisy copies of it
    refill the slot.  A slot whose best candidate exceeds the slot's
    threshold contributes that image to the batch and is re-seeded with
    fresh uniform latents.

    Args:
        classifier: Callable applied to the generated images; its output is
            reshaped to ``(32, 30, num_classes)`` and soft-maxed here, so it
            is presumably expected to return logits -- confirm.
        generator: Callable mapping latents of shape ``(32 * 30,
            encoding_size)`` to images.  ``encoding_size`` is 128 when the
            type name contains ``'sngan'``, 512 when it contains
            ``"ProgressiveGAN"`` and 256 otherwise.  ProgressiveGAN
            generators are driven through ``buildNoiseData`` / ``test``.
        batch_size: Requested number of images.  It is incremented by one
            internally (the original note says this accounts for a dummy
            first element in ``Batch``).  Sampling stops only once the batch
            holds more than that, and several images can be accepted in one
            step, so the result may contain more rows than requested.
        image_size: Kept for backward compatibility.  Image dimensions are
            read from the generator output and no resizing is applied.
        num_classes: Number of classifier classes.

    Returns:
        The tensors collected in ``batch.batch`` concatenated along dim 0,
        with rows in random order.

    NOTE(review): ``Batch`` and ``setup`` are defined elsewhere in the
    project; ``Batch`` is assumed to expose ``size(0)``, ``append(tensor)``
    and a ``batch`` list of tensors.
    """
    batch = Batch()
    batch_size += 1  # room for the first dummy element in the batch
    mse_loss = torch.nn.MSELoss(reduction="none")
    # Every class label gets an (almost) equal share, e.g. 513 // 10 + 1 = 52.
    sample_per_class = batch_size // num_classes + 1

    n_random_samples = 30   # candidate latents per search slot
    n_sample_step = 32      # search slots processed in parallel
    n_elite = n_random_samples // 3

    # The generator type never changes, so inspect it once up front.
    generator_type = str(type(generator))
    is_progressive_gan = "ProgressiveGAN" in generator_type
    encoding_size = 256
    if 'sngan' in generator_type:
        encoding_size = 128
    if is_progressive_gan:
        encoding_size = 512

    random_labels = torch.randperm(num_classes)  # random class visiting order

    # Per-slot iteration counter and acceptance threshold.
    stall_count = torch.zeros(size=(n_sample_step,), device=setup.device)
    threshold = torch.zeros(size=(n_sample_step,), device=setup.device)

    label_tensor = torch.full(
        (n_sample_step, n_random_samples), 0, dtype=torch.long, device=setup.device
    )
    slot_ids = torch.arange(n_sample_step, device=setup.device)

    def new_bucket():
        # One allocation holds 32 steps' worth of uniform(-3, 3) latents.
        return torch.zeros(
            (n_sample_step * 32, n_random_samples, encoding_size),
            device=setup.device,
        ).uniform_(-3, 3)

    specimens_bucket = new_bucket()
    s_index = 0

    def get_speciemens(size):
        """Return the next ``size`` unused latent groups, refilling the bucket if needed."""
        nonlocal s_index
        nonlocal specimens_bucket

        if s_index + size > specimens_bucket.size(0):
            specimens_bucket = new_bucket()
            s_index = 0

        specimens = specimens_bucket[s_index:s_index + size]
        s_index += size
        return specimens

    # Reusable buffers for the mutation noise added to the elite latents.
    s1 = torch.zeros((n_sample_step, n_elite, encoding_size), device=setup.device)
    s2 = torch.zeros((n_sample_step, n_elite, encoding_size), device=setup.device)

    while batch.size(0) <= batch_size:
        stall_count.zero_()

        specimens = get_speciemens(n_sample_step)

        # sample_per_class images are collected for one label before moving on.
        current_label_index = batch.size(0) // sample_per_class
        # random_labels lives on the CPU, so int() costs no device sync.
        current_label_value = int(random_labels[current_label_index])

        label_tensor.fill_(current_label_value)
        # Cast to float so the target dtype matches the softmax output in MSELoss.
        label = torch.nn.functional.one_hot(label_tensor, num_classes=num_classes).float()

        threshold.fill_(0.9)

        # Keep searching until enough images were added for the current label.
        while batch.size(0) // sample_per_class == current_label_index:
            stall_count += 1

            # Every 100th iteration lower the threshold by 0.1 so that a slot
            # cannot stay stuck forever.
            threshold -= torch.where(stall_count % 100 != 0, 0.0, 0.1)

            with torch.no_grad():
                if is_progressive_gan:
                    # NOTE(review): this branch draws fresh noise and ignores
                    # `specimens`, so the latent search has no effect here.
                    noise, _ = generator.buildNoiseData(n_sample_step * n_random_samples)
                    images = generator.test(noise).to(setup.device)
                else:
                    images = generator(specimens.view(-1, encoding_size))

                output = classifier(images)
                output = output.view(n_sample_step, n_random_samples, num_classes)

            softmaxes = torch.nn.functional.softmax(output, dim=2)
            losses = torch.sum(mse_loss(softmaxes, label), dim=2)  # (slots, candidates)

            # Group candidates per slot; image dims come from the generator
            # output instead of the (possibly None) image_size argument.
            images = images.view(n_sample_step, n_random_samples, *images.shape[1:])

            indexes = torch.argsort(losses, dim=1)  # best candidate first
            best_index = indexes[:, 0]

            # Best image of every slot: (slots, C, H, W).
            image = images[slot_ids, best_index]

            # Keep the best third of each slot and refill with two noisy copies.
            elite_index = indexes[:, :n_elite].unsqueeze(-1).expand(-1, -1, encoding_size)
            best_specimens = torch.gather(specimens, 1, elite_index)
            specimens = torch.cat([
                best_specimens,
                best_specimens + s1.normal_(0.5, 0.5),
                best_specimens + s2.normal_(0.5, 0.5),
            ], dim=1)

            # Target-class probability of the best candidate.  It is read from
            # the softmax output because the threshold is a probability.
            c = softmaxes[slot_ids, best_index, current_label_value]

            # Evaluate acceptance once, against the threshold in force now.
            accepted = c > threshold
            index_list = accepted.nonzero().squeeze(1)

            accepted_images = torch.index_select(image, 0, index_list)

            # Re-seed accepted slots with fresh latents over the full encoding size.
            specimens[index_list] = get_speciemens(index_list.size(0))

            # Accepted slots start over: threshold back to 0.9, counter to 0.
            threshold = torch.where(accepted, 0.9, threshold)
            stall_count = torch.where(accepted, 0, stall_count)

            batch.append(accepted_images)

    batch = torch.cat(batch.batch)

    return torch.index_select(batch, 0, torch.randperm(len(batch), device=setup.device))

`
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant