forked from apache/tvm
Commit cd4b887 (1 parent: 74fc376)

* Changes for TF/PT Rn50
* Refactoring
* Comments

Showing 9 changed files with 520 additions and 1 deletion.
@@ -0,0 +1,98 @@
import mxnet as mx
import tvm
from tvm import relay
from tvm import hago
from tvm.contrib import graph_runtime
from mxnet import gluon

import logging
logging.basicConfig(level=logging.DEBUG)


def get_calibration_dataset(dataset, batch_fn, var_name, num_samples=100):
    """Collect roughly `num_samples` labelled samples for calibration."""
    dataset.reset()
    batches = []
    for i, batch in enumerate(dataset):
        if i * dataset.batch_size > num_samples:
            break
        data, label = batch_fn(batch, [mx.cpu(0)])
        batches.append({var_name: tvm.nd.array(data[0].asnumpy()),
                        'label': tvm.nd.array(label[0].asnumpy())})
    return hago.CalibrationDataset(batches)


##################
# Evaluation infra
##################
def eval_acc(func, dataset, batch_fn, args, var_name, target='cuda', ctx=tvm.gpu(), postprocess=None, log_interval=100):
    with relay.build_config(opt_level=3):
        graph, lib, params = relay.build(func, target)
    # create runtime module
    m = graph_runtime.create(graph, lib, ctx)
    m.set_input(**params)

    # set up evaluation metrics
    dataset.reset()
    batch_size = dataset.batch_size
    acc_top1 = mx.metric.Accuracy()
    acc_top5 = mx.metric.TopKAccuracy(5)
    acc_top1.reset()
    acc_top5.reset()

    # Execute; with --soundness_check only ~100 samples are evaluated.
    if args.soundness_check:
        exit_at_batch = (100 + batch_size - 1) // batch_size
    else:
        exit_at_batch = -1

    for i, batch in enumerate(dataset):
        data, label = batch_fn(batch, [mx.cpu(0)])
        m.set_input(var_name, data[0].asnumpy())
        m.run()
        out_arr = m.get_output(0).asnumpy()
        if postprocess is not None:
            out_arr = postprocess(out_arr)
        acc_top1.update(label, [mx.nd.array(out_arr)])
        acc_top5.update(label, [mx.nd.array(out_arr)])

        if not (i + 1) % log_interval or i == exit_at_batch:
            _, top1 = acc_top1.get()
            _, top5 = acc_top5.get()
            nsamples = (i + 1) * batch_size
            logging.info('[%d samples] validation: acc-top1=%f acc-top5=%f', nsamples, top1, top5)

        if i == exit_at_batch:
            break

    # Refresh the metrics so the final numbers cover every evaluated batch.
    _, top1 = acc_top1.get()
    _, top5 = acc_top5.get()
    logging.info('[final] validation: acc-top1=%f acc-top5=%f', top1, top5)
    return top1


#################
# Quantize helper
#################
def quantize_hago(mod, params, calib_dataset):
    qconfig = hago.qconfig(skip_conv_layers=[0],
                           log_file='temp.log')

    with qconfig:
        graph = hago.prerequisite_optimize(mod['main'], params=params)
        logging.debug('current quantize config')
        logging.debug(hago.current_qconfig())
        hardware = hago.create_accelerator_description()
        space = hago.generate_search_space(graph, hardware)
        # tuner = hago.BatchedGreedySearchTuner(space, 'accuracy')
        tuner = hago.DefaultSetting(space, 'accuracy')
        ctx = tvm.cpu()
        strategy, result = hago.search_quantize_strategy(graph, hardware, calib_dataset, tuner, ctx,
                                                         target='llvm')

        quantizer = hago.create_quantizer(graph, hardware, strategy)
        simulated_graph = quantizer.simulate()
        quantized_graph = quantizer.quantize()
        logging.debug('simulated graph')
        logging.debug(simulated_graph.astext(show_meta_data=False))
        logging.debug('quantized graph')
        logging.debug(quantized_graph.astext(show_meta_data=False))
        # hago.inspect_graph_statistic(graph, hardware, strategy, dataset, ctx, target='llvm')
        return tvm.IRModule.from_expr(quantized_graph)
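A note on the postprocess hook: eval_acc applies it to the raw network output before the accuracy metrics are updated. The driver scripts below leave it unset; a minimal hypothetical example (the helper name is illustrative and not part of this commit) could look like:

    def strip_background_class(out_arr):
        # Hypothetical postprocess: drop an extra leading "background" column so
        # the remaining 1000 scores line up with the ImageNet validation labels.
        return out_arr[:, 1:]

    # acc = eval_acc(func, dataset, batch_fn, args, var_name='data',
    #                postprocess=strip_background_class)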
@@ -0,0 +1,95 @@
import tvm
from tvm import relay

import numpy as np
import argparse
import os

import mxnet as mx
from tvm import hago
from mxnet import gluon

from common_hago import *


parser = argparse.ArgumentParser()
parser.add_argument("--model", default="resnet50_v1", help="model to quantize")
parser.add_argument("--soundness_check", default=False, action='store_true')
parser.add_argument("--skip_fp32", default=False, action='store_true')
parser.add_argument("--run_all", default=False, action='store_true')
args = parser.parse_args()

batch_size = 32
target = 'llvm -mcpu=core-avx2'
ctx = tvm.context(target)

#####################
# Dataset preparation
#####################

def get_val_data(img_size,
                 rec_val,
                 batch_size,
                 num_workers=4):
    rec_val = os.path.expanduser(rec_val)
    mean_rgb = [123.68, 116.779, 103.939]
    std_rgb = [58.393, 57.12, 57.375]

    def batch_fn(batch, ctx):
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        return data, label

    val_data = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        preprocess_threads=num_workers,
        shuffle=True,
        seed=0,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, img_size, img_size),
        mean_r=mean_rgb[0],
        mean_g=mean_rgb[1],
        mean_b=mean_rgb[2],
        std_r=std_rgb[0],
        std_g=std_rgb[1],
        std_b=std_rgb[2],
    )
    return val_data, batch_fn


###############################################################################
# Load the model
# ----------------
def get_model(model_name):
    gluon_model = gluon.model_zoo.vision.get_model(model_name, pretrained=True)
    img_size = 299 if model_name == 'inceptionv3' else 224
    data_shape = (batch_size, 3, img_size, img_size)
    mod, params = relay.frontend.from_mxnet(gluon_model, {"data": data_shape})
    return mod, params


def main():
    val_path = '/home/ubuntu/tensorflow_datasets/downloads/manual/imagenet2012/val.rec'
    if args.run_all:
        models = ['resnet50_v1', 'inceptionv3', 'mobilenetv2_1.0', 'mobilenet1.0', 'resnet18_v1',
                  'densenet161', 'vgg16']
    else:
        models = [args.model]
    for model_name in models:
        img_size = 299 if model_name == 'inceptionv3' else 224
        val_data, batch_fn = get_val_data(img_size, val_path, batch_size)

        # Original fp32 model
        if not args.skip_fp32:
            fp32_mod, params = get_model(model_name)
            func = hago.prerequisite_optimize(fp32_mod['main'], params=params)
            acc = eval_acc(func, val_data, batch_fn, args, var_name='data', target=target, ctx=ctx)
            print("fp32_accuracy", model_name, acc, sep=',')

        # Quantize
        calib_dataset = get_calibration_dataset(val_data, batch_fn, var_name='data')
        fp32_mod, params = get_model(model_name)
        quantized_func = quantize_hago(fp32_mod, params, calib_dataset)
        acc = eval_acc(quantized_func, val_data, batch_fn, args, var_name='data', target=target, ctx=ctx)
        print("quantized_accuracy", model_name, acc, sep=',')


if __name__ == '__main__':
    main()
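eval_acc compiles each module internally and only reports accuracy. If one also wanted to keep the compiled artifacts of the quantized module, a minimal sketch following the same graph-runtime build flow could be added (the helper and output file names are illustrative and not part of this commit):

    def export_quantized(quantized_mod, out_dir='.'):
        # Compile the quantized IRModule with the same target used for evaluation.
        with relay.build_config(opt_level=3):
            graph, lib, build_params = relay.build(quantized_mod['main'], target)
        # Persist the graph JSON, the compiled operators, and the parameters.
        lib.export_library(os.path.join(out_dir, 'quantized_deploy.so'))
        with open(os.path.join(out_dir, 'quantized_deploy.json'), 'w') as f:
            f.write(graph)
        with open(os.path.join(out_dir, 'quantized_deploy.params'), 'wb') as f:
            f.write(relay.save_param_dict(build_params))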
@@ -0,0 +1,144 @@
import tvm
from tvm import relay

import numpy as np
import argparse

import torch
from torch.nn import Module
import torchvision
from torchvision import transforms
import os

import mxnet as mx
from tvm import hago
from mxnet import gluon

from common_hago import *

parser = argparse.ArgumentParser()
parser.add_argument("--model", default="resnet50", help="model to quantize")
parser.add_argument("--soundness_check", default=False, action='store_true')
parser.add_argument("--skip_fp32", default=False, action='store_true')
parser.add_argument("--run_all", default=False, action='store_true')
args = parser.parse_args()

batch_size = 32
target = 'llvm -mcpu=core-avx2'
ctx = tvm.context(target)

#####################
# Dataset preparation
#####################

def get_val_data(img_size,
                 rec_val,
                 batch_size,
                 num_workers=4):
    rec_val = os.path.expanduser(rec_val)
    # torchvision-style normalization, scaled to the 0-255 pixel range
    mean_rgb = [255 * x for x in [0.485, 0.456, 0.406]]
    std_rgb = [255 * x for x in [0.229, 0.224, 0.225]]

    def batch_fn(batch, ctx):
        data = gluon.utils.split_and_load(batch.data[0], ctx_list=ctx, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], ctx_list=ctx, batch_axis=0)
        return data, label

    val_data = mx.io.ImageRecordIter(
        path_imgrec=rec_val,
        preprocess_threads=num_workers,
        shuffle=True,
        seed=0,
        batch_size=batch_size,
        resize=256,
        data_shape=(3, img_size, img_size),
        mean_r=mean_rgb[0],
        mean_g=mean_rgb[1],
        mean_b=mean_rgb[2],
        std_r=std_rgb[0],
        std_g=std_rgb[1],
        std_b=std_rgb[2],
    )
    return val_data, batch_fn


###############################################################################
# Load the model from torchvision
# ----------------
def load_model(model_name):
    """Given a model name, returns a model as well as an example input."""
    if hasattr(torchvision.models, model_name):
        with torch.no_grad():
            if model_name.startswith("inception"):
                height = width = 299
                mean = [0.5, 0.5, 0.5]
                std = [0.5, 0.5, 0.5]
            else:
                height = width = 224
                mean = [0.485, 0.456, 0.406]
                std = [0.229, 0.224, 0.225]
            input_shape = [batch_size, 3, height, width]
            input_data = torch.randn(input_shape).float()
            for channel in range(3):
                input_data[:, channel] -= mean[channel]
                input_data[:, channel] /= std[channel]
            model = getattr(torchvision.models, model_name)(pretrained=True)
            model = model.float().eval()
            return model, [input_data]
    try:
        import pretrainedmodels
        if hasattr(pretrainedmodels, model_name):
            return load_pretrainedmodels(model_name)
    except ModuleNotFoundError:
        raise ModuleNotFoundError("Please install pretrainedmodels.pytorch")
    raise RuntimeError("Model not supported")


def get_model(model_name):
    torch.set_grad_enabled(False)
    baseline_model, baseline_input = load_model(model_name)

    trace = torch.jit.trace(baseline_model, baseline_input)
    if isinstance(baseline_model, torch.nn.Module):
        trace = trace.float().eval()
        trace = trace.cpu()

    global input_names
    input_names = ["input{}".format(idx) for idx, inp in enumerate(baseline_input)]
    input_shapes = list(zip(input_names,
                            [inp.shape for inp in baseline_input]))
    mod, params = relay.frontend.from_pytorch(trace, input_shapes)
    return mod, params


#############
# Test models
#############
def main():
    val_path = '/home/ubuntu/tensorflow_datasets/downloads/manual/imagenet2012/val.rec'
    if args.run_all:
        models = ['resnet50', 'inception_v3', 'mobilenet_v2', 'resnet18',
                  'densenet161', 'googlenet', 'vgg16']
    else:
        models = [args.model]
    for model_name in models:
        height = 224
        if model_name.startswith("inception"):
            height = 299

        val_data, batch_fn = get_val_data(height, val_path, batch_size)

        # Original fp32 model
        if not args.skip_fp32:
            fp32_mod, params = get_model(model_name)
            func = hago.prerequisite_optimize(fp32_mod['main'], params=params)
            acc = eval_acc(func, val_data, batch_fn, args, var_name=input_names[0], target=target, ctx=ctx)
            print("fp32_accuracy", model_name, acc, sep=',')

        # Quantize
        calib_dataset = get_calibration_dataset(val_data, batch_fn, var_name=input_names[0])
        fp32_mod, params = get_model(model_name)
        quantized_func = quantize_hago(fp32_mod, params, calib_dataset)
        acc = eval_acc(quantized_func, val_data, batch_fn, args, var_name=input_names[0], target=target, ctx=ctx)
        print("quantized_accuracy", model_name, acc, sep=',')


if __name__ == '__main__':
    main()
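Since the PyTorch models pass through torch.jit.trace before being imported into Relay, it can be worth confirming that the imported fp32 module reproduces the traced model's output before measuring accuracy. Such a check is not part of this commit; a rough sketch reusing the names defined above (the helper name and tolerances are illustrative):

    import numpy as np
    from tvm.contrib import graph_runtime

    def check_parity(model_name, rtol=1e-4, atol=1e-4):
        # Import the traced model and build the fp32 Relay module.
        fp32_mod, params = get_model(model_name)
        func = hago.prerequisite_optimize(fp32_mod['main'], params=params)
        with relay.build_config(opt_level=3):
            graph, lib, build_params = relay.build(func, target)
        m = graph_runtime.create(graph, lib, ctx)
        m.set_input(**build_params)

        # Run the PyTorch baseline and the TVM module on the same example input.
        baseline_model, baseline_input = load_model(model_name)
        with torch.no_grad():
            torch_out = baseline_model(baseline_input[0]).numpy()
        m.set_input(input_names[0], baseline_input[0].numpy())
        m.run()
        tvm_out = m.get_output(0).asnumpy()
        np.testing.assert_allclose(tvm_out, torch_out, rtol=rtol, atol=atol)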