Commit

Add create_graph=True to the original file
in order to make second-order derivatives available.
dragen1860#32
Now the regularizer coefficient makes a difference.
hummarow committed Sep 26, 2022
1 parent 38ee1a3 commit 9143862
Showing 2 changed files with 130 additions and 35 deletions.
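
Background for the change: torch.autograd.grad discards the graph of the returned gradient unless create_graph=True is passed, so without it the meta-update only ever sees a first-order approximation of the MAML objective. A minimal sketch of the difference, with made-up tensors rather than code from this repository:

import torch

# one parameter, one data point; values are illustrative only
w = torch.tensor([1.0], requires_grad=True)
x = torch.tensor([2.0])

# inner-loop step: keep the graph of the gradient so the outer loss can
# differentiate through the update (second-order MAML)
inner_loss = (w * x).pow(2).sum()
(grad,) = torch.autograd.grad(inner_loss, w, create_graph=True)
fast_w = w - 0.1 * grad

# outer (query) loss evaluated at the adapted weight
outer_loss = (fast_w * x - 1.0).pow(2).sum()
outer_loss.backward()
print(w.grad)  # includes the second-order term that flows through grad

Without create_graph=True the same code still runs, but grad is treated as a constant during the backward pass, which is exactly the first-order (FOMAML-style) approximation this commit moves away from.
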
117 changes: 91 additions & 26 deletions meta.py
@@ -1,14 +1,15 @@
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch import optim
import numpy as np

from learner import Learner
from copy import deepcopy
import torch
import numpy as np
import os
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch import linalg as LA

from learner import Learner
from copy import deepcopy


class Meta(nn.Module):
@@ -30,12 +31,14 @@ def __init__(self, args, config):
self.task_num = args.task_num
self.update_step = args.update_step
self.update_step_test = args.update_step_test

self.reg = args.reg
self.ord = args.ord

self.net = Learner(config, args.imgc, args.imgsz)
self.meta_optim = optim.Adam(self.net.parameters(), lr=self.meta_lr)


log_path = os.path.join('logs', 'L' + str(args.ord) + '_Reg' + str(args.reg) + args.log_dir)
self.writer = SummaryWriter(log_path)


def clip_grad_by_norm_(self, grad, max_norm):
@@ -61,9 +64,8 @@ def clip_grad_by_norm_(self, grad, max_norm):

return total_norm/counter

def forward(self, x_spt, y_spt, x_qry, y_qry):
def forward(self, x_spt, y_spt, x_qry, y_qry, t):
"""
:param x_spt: [b, setsz, c_, h, w]
:param y_spt: [b, setsz]
:param x_qry: [b, querysz, c_, h, w]
@@ -76,14 +78,19 @@ def forward(self, x_spt, y_spt, x_qry, y_qry):
losses_q = [0 for _ in range(self.update_step + 1)] # losses_q[i] is the loss on step i
corrects = [0 for _ in range(self.update_step + 1)]

phi = self.net.parameters()

fast_weights = [phi] * task_num
# fast_weights = torch.empty(len(self.net.parameters()), task_num)

for i in range(task_num):

# 1. run the i-th task and compute loss for k=0
# self.net = Learner()
logits = self.net(x_spt[i], vars=None, bn_training=True)
loss = F.cross_entropy(logits, y_spt[i])
grad = torch.autograd.grad(loss, self.net.parameters())
fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))
grad = torch.autograd.grad(loss, self.net.parameters(), retain_graph=True, create_graph=True)
fast_weights[i] = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, self.net.parameters())))

# this is the loss and accuracy before first update
with torch.no_grad():
@@ -99,7 +106,7 @@ def forward(self, x_spt, y_spt, x_qry, y_qry):
# this is the loss and accuracy after the first update
with torch.no_grad():
# [setsz, nway]
logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
logits_q = self.net(x_qry[i], fast_weights[i], bn_training=True)
loss_q = F.cross_entropy(logits_q, y_qry[i])
losses_q[1] += loss_q
# [setsz]
@@ -109,14 +116,14 @@ def forward(self, x_spt, y_spt, x_qry, y_qry):

for k in range(1, self.update_step):
# 1. run the i-th task and compute loss for k=1~K-1
logits = self.net(x_spt[i], fast_weights, bn_training=True)
logits = self.net(x_spt[i], fast_weights[i], bn_training=True)
loss = F.cross_entropy(logits, y_spt[i])
# 2. compute grad on theta_pi
grad = torch.autograd.grad(loss, fast_weights)
grad = torch.autograd.grad(loss, fast_weights[i], create_graph=True, retain_graph=True)
# 3. theta_pi = theta_pi - train_lr * grad
fast_weights = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights)))
fast_weights[i] = list(map(lambda p: p[1] - self.update_lr * p[0], zip(grad, fast_weights[i])))

logits_q = self.net(x_qry[i], fast_weights, bn_training=True)
logits_q = self.net(x_qry[i], fast_weights[i], bn_training=True)
# loss_q will be overwritten and just keep the loss_q on last update step.
loss_q = F.cross_entropy(logits_q, y_qry[i])
losses_q[k + 1] += loss_q
@@ -125,23 +132,82 @@ def forward(self, x_spt, y_spt, x_qry, y_qry):
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
correct = torch.eq(pred_q, y_qry[i]).sum().item() # convert to numpy
corrects[k + 1] = corrects[k + 1] + correct
############################################################################
# Reptile
##########################################################################
# state_dict = self.net.state_dict()
# key_list = list(state_dict.keys())
# for key in state_dict.keys():
# if 'bn' in key:
# key_list.remove(key)
# for j, key in enumerate(key_list):
# state_dict[key] = state_dict[key] - self.reg * (state_dict[key] - fast_weights[i][j])
#
# self.net.load_state_dict(state_dict)

# print(losses_q)
# input()

# end of all tasks
# sum over all losses on query set across all tasks
loss_q = losses_q[-1] / task_num

####################################################################################
# Weight Clustering
####################################################################################

weight_flat = []

for i in range(len(fast_weights)):
w = None
for fw in fast_weights[i]:
if not torch.is_tensor(w):
w = torch.flatten(fw)
else:
w = torch.cat([w, torch.flatten(fw)], dim=0)
weight_flat.append(w)

weight_flat = torch.stack(weight_flat, axis=1)

average_weight = torch.mean(weight_flat, dim=1, keepdim=True)

# # Reset origin
weight_flat = weight_flat - average_weight
# norm = torch.norm(weight_flat, p='fro')

norm = LA.vector_norm(weight_flat, ord=self.ord)
# norm = norm * norm

self.writer.add_scalar("Distance", norm, t)
self.writer.add_scalar("loss", loss_q, t)

loss_q += self.reg * norm

###############################################################################
# End Weight Clustering
####################################################################################

self.writer.add_scalar("loss+Distance", loss_q, t)
# MAML
# optimize theta parameters
self.meta_optim.zero_grad()
loss_q.backward()

# print('meta update')
# for p in self.net.parameters()[:5]:
# print(torch.norm(p).item())
self.meta_optim.step()

self.meta_optim.step()

# Reptile
# state_dict = self.net.state_dict()
# key_list = list(state_dict.keys())
# for key in state_dict.keys():
# if 'bn' in key:
# key_list.remove(key)
# for fw in fast_weights:
# for i, key in enumerate(key_list):
# state_dict[key] = state_dict[key] - fw[i]
# self.net.load_state_dict(state_dict)

accs = np.array(corrects) / (querysz * task_num)

return accs
@@ -218,11 +284,10 @@ def finetunning(self, x_spt, y_spt, x_qry, y_qry):
return accs




def main():
pass


if __name__ == '__main__':
main()
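
The other substantive change in meta.py keeps one set of fast weights per task and adds a "weight clustering" penalty: each task's adapted weights are flattened, stacked as columns, centred on their mean, and the order-args.ord vector norm of the centred matrix is scaled by args.reg and added to the query loss. A standalone sketch of just that penalty, assuming same-shaped parameter lists per task (the helper name and toy tensors are not from the repository):

import torch
from torch import linalg as LA

def clustering_penalty(fast_weights, ord=2):
    # fast_weights: list over tasks, each a list of per-layer parameter tensors
    flat = [torch.cat([fw.flatten() for fw in task]) for task in fast_weights]
    weight_flat = torch.stack(flat, dim=1)                # [n_params, task_num]
    centred = weight_flat - weight_flat.mean(dim=1, keepdim=True)
    return LA.vector_norm(centred, ord=ord)               # scalar spread measure

# toy usage: two tasks, two parameter tensors each
tasks = [[torch.randn(3, 3), torch.randn(3)] for _ in range(2)]
print(clustering_penalty(tasks))

In the diff itself this quantity is logged as "Distance" and the meta-loss becomes loss_q + self.reg * norm, which is why the coefficient now changes the optimisation.
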

48 changes: 39 additions & 9 deletions miniimagenet_train.py
@@ -1,17 +1,29 @@
import torch, os
import numpy as np
from MiniImagenet import MiniImagenet
import scipy.stats
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
import random, sys, pickle
import argparse
import torch, os
import numpy as np
import scipy.stats
import random, sys, pickle
import argparse
import tensorboard
import matplotlib.pyplot as plt
import os
from MiniImagenet import MiniImagenet
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim import lr_scheduler
from collections import defaultdict

from meta import Meta


imagenet_path = '/home/bjk/Datasets/mini-imagenet/'

# Tensorboard custom scalar layout
layout = {
"Accuracy": {
"Accuracy": ["Multiline", ["Accuracy/Train", "Accuracy/Test"]]
}
}


def mean_confidence_interval(accs, confidence=0.95):
n = accs.shape[0]
@@ -67,6 +79,10 @@ def main():
k_query=args.k_qry,
batchsz=100, resize=args.imgsz)

log_path = os.path.join('logs', 'L' + str(args.ord) + '_Reg' + str(args.reg) + args.log_dir)
writer = SummaryWriter(log_path)
writer.add_custom_scalars(layout)

for epoch in range(args.epoch//10000):
# fetch meta_batchsz num of episode each time
db = DataLoader(mini, args.task_num, shuffle=True, num_workers=1, pin_memory=True)
@@ -75,10 +91,13 @@ def main():

x_spt, y_spt, x_qry, y_qry = x_spt.to(device), y_spt.to(device), x_qry.to(device), y_qry.to(device)

accs = maml(x_spt, y_spt, x_qry, y_qry)
accs = maml(x_spt, y_spt, x_qry, y_qry, step+len(db)*epoch)

if step % 30 == 0:
print('step:', step, '\ttraining acc:', accs)
writer.add_scalar('Accuracy/Train',
accs[-1],
step + epoch*len(db))

if step % 500 == 0 or step == (len(db) - 1): # evaluation
db_test = DataLoader(mini_test, 1, shuffle=True, num_workers=1, pin_memory=True)
@@ -94,6 +113,12 @@ def main():
# [b, update_step+1]
accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
print('Test acc:', accs)
writer.add_scalar('Accuracy',
accs[-1],
step + epoch*len(db))
writer.add_scalar('Accuracy/Test',
accs[-1],
step + epoch*len(db))


if __name__ == '__main__':
@@ -110,7 +135,12 @@ def main():
argparser.add_argument('--update_lr', type=float, help='task-level inner update learning rate', default=0.01)
argparser.add_argument('--update_step', type=int, help='task-level inner update steps', default=5)
argparser.add_argument('--update_step_test', type=int, help='update steps for finetunning', default=10)
argparser.add_argument('--reg', type=float, help='coefficient for regularizer', default=1.0)
argparser.add_argument('--log_dir', type=str, help='log directory for tensorboard', default='')
argparser.add_argument('--ord', type=int, help='order of norms among fine-tuned weights', default=2)

args = argparser.parse_args()
if args.log_dir != '':
args.log_dir = '_' + args.log_dir

main()
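
For the training script, the new logging puts each run under logs/ with the norm order and regularizer coefficient in the directory name, registers a custom-scalars layout so train and test accuracy share one chart, and records both at the usual reporting steps. A rough sketch of that convention with placeholder values (not the script's parsed arguments):

import os
from torch.utils.tensorboard import SummaryWriter

ord_, reg, log_dir = 2, 1.0, '_demo'   # stand-ins for args.ord, args.reg, args.log_dir
log_path = os.path.join('logs', 'L' + str(ord_) + '_Reg' + str(reg) + log_dir)

writer = SummaryWriter(log_path)
writer.add_custom_scalars({
    "Accuracy": {
        "Accuracy": ["Multiline", ["Accuracy/Train", "Accuracy/Test"]],
    }
})
writer.add_scalar('Accuracy/Train', 0.42, 0)   # dummy values at step 0
writer.add_scalar('Accuracy/Test', 0.40, 0)
writer.close()

Under these assumptions a run started with something like python miniimagenet_train.py --reg 1.0 --ord 2 --log_dir run1 would write to logs/L2_Reg1.0_run1, since the script prefixes a non-empty --log_dir with an underscore.
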
