Skip to content

Commit

Permalink
Add tool for offline tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhan committed Feb 3, 2023
1 parent 693df24 commit 9bf2d26
Showing 1 changed file with 165 additions and 0 deletions.
165 changes: 165 additions & 0 deletions tools/python/offline_tuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import copy
import itertools
import json
import sys
from collections import OrderedDict
from pprint import pprint
from typing import Any, Dict, List

import onnx

TuningResults = Dict[str, Any]

_tuning_results_key = "tuning_results"


def _find_tuning_results_in_props(metadata_props):
for idx, prop in enumerate(metadata_props):
if prop.key == _tuning_results_key:
return idx
return -1


def extract(onnx: onnx.ModelProto):
idx = _find_tuning_results_in_props(onnx.metadata_props)
if idx < 0:
return None

tuning_results_prop = onnx.metadata_props[idx]
return json.loads(tuning_results_prop.value)


def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False):
idx = _find_tuning_results_in_props(model.metadata_props)
assert overwrite or idx <= 0, "the supplied onnx file already have tuning results embeded!"

if idx >= 0:
model.metadata_props.pop(idx)

entry = model.metadata_props.add()
entry.key = _tuning_results_key
entry.value = json.dumps(tuning_results)
return model


class Merger:
class EpAndValidators:
def __init__(self, ep: str, validators: Dict[str, str]):
self.ep = ep
self.validators = copy.deepcopy(validators)
self.key = (ep, tuple(sorted(validators.items())))

def __hash__(self):
return hash(self.key)

def __eq__(self, other):
return self.ep == other.ep and self.key == other.key

def __init__(self):
self.ev_to_results = OrderedDict()

def merge(self, tuning_results: List[TuningResults]):
for trs in tuning_results:
self._merge_one(trs)

def get_merged(self):
tuning_results = []
for ev, flat_results in self.ev_to_results.items():
results = {}
trs = {
"ep": ev.ep,
"validators": ev.validators,
"results": results,
}
for (op_sig, params_sig), kernel_id in flat_results.items():
kernel_map = results.setdefault(op_sig, {})
kernel_map[params_sig] = kernel_id
tuning_results.append(trs)
return tuning_results

def _merge_one(self, trs: TuningResults):
ev = Merger.EpAndValidators(trs["ep"], trs["validators"])
flat_results = self.ev_to_results.setdefault(ev, {})
for op_sig, kernel_map in trs["results"].items():
for params_sig, kernel_id in kernel_map.items():
if (op_sig, params_sig) not in flat_results:
flat_results[(op_sig, params_sig)] = kernel_id


def parse_args():
parser = argparse.ArgumentParser()
sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd")

extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.")
extract_parser.add_argument("input_onnx")
extract_parser.add_argument("output_json")

embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.")
embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.")
embed_parser.add_argument("output_onnx", help="Path of the output onnx file.")
embed_parser.add_argument("input_onnx", help="Path of the input onnx file.")
embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.")

merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.")
merge_parser.add_argument("output_json", help="Path of the output tuning results file.")
merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.")

pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.")
pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.")

args = parser.parse_args()
if len(vars(args)) == 0:
parser.print_help()
exit(-1)
return args


if __name__ == "__main__":
args = parse_args()
if args.cmd == "extract":
tuning_results = extract(onnx.load_model(args.input_onnx))
if tuning_results is None:
sys.stderr.write(f"{args.input_onnx} does not have tuning results embeded!\n")
exit(-1)
json.dump(tuning_results, open(args.output_json, "w"))
elif args.cmd == "embed":
model = onnx.load_model(args.input_onnx)
merger = Merger()
for tuning_results in [json.load(open(f)) for f in args.input_json]:
merger.merge(tuning_results)
model = embed(model, merger.get_merged(), args.force)
onnx.save_model(model, args.output_onnx)
elif args.cmd == "merge":
merger = Merger()
for tuning_results in [json.load(open(f)) for f in args.input_json]:
merger.merge(tuning_results)
json.dump(merger.get_merged(), open(args.output_json, "w"))
elif args.cmd == "pprint":
tuning_results = None
try:
tuning_results = json.load(open(args.json_or_onnx, "r"))
except:
pass

if tuning_results is None:
try:
model = onnx.load_model(args.json_or_onnx)
tuning_results = extract(model)
if tuning_results is None:
sys.stderr.write(f"{args.input_onnx} does not have tuning results embeded!\n")
exit(-1)
except:
pass

if tuning_results is None:
sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!")
exit(-1)

pprint(tuning_results)
else:
# invalid choice will be handled by the parser
pass

0 comments on commit 9bf2d26

Please sign in to comment.