-
Notifications
You must be signed in to change notification settings - Fork 26
/
convert.py
executable file
·144 lines (106 loc) · 5.08 KB
/
convert.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import future
import builtins
import past
import six
from timeit import default_timer as timer
from datetime import datetime
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, utils
from torch.utils.data import Dataset
import inspect
from inspect import getargspec
import os
import helpers as h
from helpers import Timer
import copy
import random
from itertools import count
from components import *
import models
import goals
from goals import *
import math
from torch.serialization import SourceChangeWarning
import warnings
# Command-line interface. This file is a straight-line script, so options are parsed
# immediately at module level.
parser = argparse.ArgumentParser(
    description='Convert a pickled PyTorch DiffAI net to an abstract ONNX net which returns the interval concretization around the final logits. Along dimension 1 of the output, index 0 is the natural center, index 1 is the lb, and index 2 is the ub',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-n', '--net', type=str, default=None, metavar='N', help='Saved and pickled net to use, in pynet format', required=True)
# NOTE: the --domain string is later evaluated with eval(); only supply trusted expressions.
parser.add_argument('-d', '--domain', type=str, default="Point()", help='picks which abstract goals to use for testing. Uses box. Doesn\'t use time, so don\'t use Lin. Unless point, should specify a width w.')
parser.add_argument('-b', '--batch-size', type=int, default=1, help='The batch size to export. Not sure this matters.')
parser.add_argument('-o', '--out', type=str, default="convert_out/", metavar='F', help='Where to save the net.')
parser.add_argument('--update-net', type=h.str2bool, nargs='?', const=True, default=False, help="rebuild the net from its model definition, using the name stored on the loaded net when --net-name is not given")
parser.add_argument('--net-name', type=str, choices=h.getMethodNames(models), default=None, help="name of the model definition to rebuild the loaded net as")
parser.add_argument('--save-name', type=str, default=None, help="name to save the net with. Defaults to <domain>___<netfile-.pynet>.onyx")
# Every Dataset subclass exposed by torchvision.datasets is accepted as a choice.
parser.add_argument('-D', '--dataset',
                    choices=[n for (n, k) in inspect.getmembers(datasets, inspect.isclass) if issubclass(k, Dataset)],
                    default="MNIST", help='picks which dataset to use.')
parser.add_argument('--map-to-cpu', type=h.str2bool, nargs='?', const=True, default=False, help="map cuda operations in save back to cpu; enables running on a computer without a GPU")
parser.add_argument('--tf-input', type=h.str2bool, nargs='?', const=True, default=False, help="change the shape of the input data from batch-channels-height-width (standard in pytorch) to batch-height-width-channels (standard in tf)")
args = parser.parse_args()
# Make sure the destination directory for the exported model exists.
# (Nothing is created if any filesystem entry with this path already exists.)
out_dir = args.out
directory_missing = not os.path.exists(out_dir)
if directory_missing:
    os.makedirs(out_dir)
# Load the pickled net, optionally remapping CUDA storages onto the CPU.
# With record=True, warnings raised while loading are collected into `w` instead of
# being printed. The "always" filter makes every SourceChangeWarning get recorded.
# NOTE(review): newer torch releases default torch.load to weights_only=True, which
# may refuse a fully pickled module. Confirm against the torch version in use.
load_kwargs = {'map_location': 'cpu'} if args.map_to_cpu else {}
with warnings.catch_warnings(record=True) as w:
    warnings.simplefilter("always", SourceChangeWarning)
    net = torch.load(args.net, **load_kwargs)
# Decide which model definition (if any) the loaded net should be rebuilt as.
# An explicit --net-name wins. Otherwise, with --update-net, fall back to the name
# attribute stored on the loaded net. If neither applies, no rebuild happens.
if args.net_name is not None:
    net_name = args.net_name
elif args.update_net and 'name' in dir(net):
    net_name = net.name
else:
    net_name = None
# Per-channel (mean, std) normalization statistics keyed by dataset name.
# Datasets that are not listed get no Normalize layer prepended.
DATASET_NORMALIZATION = {
    "MNIST": ([0.1307], [0.3081]),
    "CIFAR10": ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    "CIFAR100": ([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010]),
    "SVHN": ([0.5, 0.5, 0.5], [0.2, 0.2, 0.2]),
    "Imagenet12": ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
}


def buildNet(n, input_dims, num_classes):
    """Instantiate a model constructor and prepare it for weight loading.

    Args:
        n: callable that is invoked with ``num_classes`` to create the network
            (the caller looks it up by name in ``models``).
        input_dims: shape handed to ``infer``; the caller passes the loaded
            net's ``inShape``.
        num_classes: forwarded to ``n``; the caller passes the loaded net's
            ``outShape``.

    Returns:
        The result of ``infer(input_dims)`` on the network, after ``clip_norm()``
        has been called on it. When ``args.dataset`` has an entry in
        ``DATASET_NORMALIZATION``, the network is first wrapped as
        ``Seq(Normalize(mean, std), net)``.
    """
    n = n(num_classes)
    # Bug fix: the SVHN / Imagenet12 branches used to read an undefined global
    # ``dataset`` (NameError). The lookup now always keys on ``args.dataset``.
    stats = DATASET_NORMALIZATION.get(args.dataset)
    if stats is not None:
        mean, std = stats
        # Pass fresh lists so the module-level table is never shared with
        # (and possibly mutated by) the layer.
        n = Seq(Normalize(list(mean), list(std)), n)
    n = n.infer(input_dims)
    n.clip_norm()
    return n
# When a model name is known, rebuild that architecture from `models` and copy the
# loaded weights into the fresh instance.
if net_name is not None:
    rebuilt = buildNet(getattr(models, net_name), net.inShape, net.outShape)
    rebuilt.load_state_dict(net.state_dict())
    net = rebuilt
net = net.to(h.device)
net.remove_norm()
# SECURITY NOTE: --domain is executed as Python via eval(); only pass trusted expressions.
domain = eval(args.domain)
# Output file name. The default is <domain>___<net file name without ".pynet">.onyx,
# as promised by the --save-name help text.
if args.save_name is None:
    net_file = os.path.basename(args.net)
    pynet_suffix = ".pynet"
    if net_file.endswith(pynet_suffix):
        net_stem = net_file[:-len(pynet_suffix)]
    else:
        # Previously the last 6 characters were chopped unconditionally, mangling
        # checkpoints saved with any other extension (e.g. "model.pt" -> "mo").
        net_stem = os.path.splitext(net_file)[0]
    save_name = h.prepareDomainNameForFile(args.domain) + "___" + net_stem + ".onyx"
else:
    save_name = args.save_name
def abstractNet(inpt):
    """Run the abstract forward pass and stack center, lower and upper bound.

    With --tf-input the input is permuted with (0, 3, 1, 2), turning a
    batch-height-width-channels tensor into batch-channels-height-width. The
    input is then wrapped with ``domain.box(..., w=None)`` and pushed through
    ``net`` with ``onyx=True``.

    Returns:
        A tensor concatenated along dimension 1 (after ``unsqueeze(1)`` on each
        part): index 0 is the output's vanilla tensor part, index 1 the lower
        bound's, index 2 the upper bound's.
    """
    x = inpt.permute(0, 3, 1, 2) if args.tf_input else inpt
    result = net(domain.box(x, w=None), onyx=True).unsqueeze(1)
    center = result.vanillaTensorPart()
    lower = result.lb().vanillaTensorPart()
    upper = result.ub().vanillaTensorPart()
    return torch.cat([center, lower, upper], dim=1)
# Build an all-zeros example input with the exported batch size. With --tf-input the
# leading entry of net.inShape (the channel axis) is moved to the end.
channels_first = list(net.inShape)
if args.tf_input:
    input_shape = [args.batch_size] + channels_first[1:] + [channels_first[0]]
else:
    input_shape = [args.batch_size] + channels_first
dummy = h.zeros(input_shape)
# Run the abstract pass once before exporting (presumably a sanity check that it works).
abstractNet(dummy)
class AbstractNet(nn.Module):
    """Module wrapper that lets ``torch.onnx.export`` trace ``abstractNet``.

    ``forward`` simply delegates to the wrapped callable. ``net`` is stored as an
    attribute, and so is ``domain.net`` when the domain has a non-None ``net``.
    When these are ``nn.Module`` instances, storing them registers them as
    submodules, so their parameters show up in ``self.parameters()``.
    """

    def __init__(self, domain, net, abstractNet):
        super(AbstractNet, self).__init__()
        # Registered before netDom, so its parameters are enumerated first.
        self.net = net
        self.abstractNet = abstractNet
        domain_net = getattr(domain, "net", None)
        if domain_net is not None:
            self.netDom = domain_net

    def forward(self, inpt):
        return self.abstractNet(inpt)
# Wrap the abstract pass in a module and export it to ONNX.
absNet = AbstractNet(domain, net, abstractNet)
out_path = os.path.join(out_dir, save_name)
print("Saving:", out_path)
# One "param<i>" name per parameter, following the single "actual_input".
# NOTE(review): this assumes parameters appear as graph inputs, which is older
# torch.onnx behaviour. Confirm with the torch version in use (keep_initializers_as_inputs).
param_list = ["param" + str(idx) for idx, _ in enumerate(absNet.parameters())]
input_names = ["actual_input"] + param_list
torch.onnx.export(absNet, dummy, out_path, verbose=False,
                  input_names=input_names, output_names=["output"])