-
Notifications
You must be signed in to change notification settings - Fork 3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
165 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,165 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import argparse | ||
import copy | ||
import itertools | ||
import json | ||
import sys | ||
from collections import OrderedDict | ||
from pprint import pprint | ||
from typing import Any, Dict, List | ||
|
||
import onnx | ||
|
||
TuningResults = Dict[str, Any] | ||
|
||
_tuning_results_key = "tuning_results" | ||
|
||
|
||
def _find_tuning_results_in_props(metadata_props): | ||
for idx, prop in enumerate(metadata_props): | ||
if prop.key == _tuning_results_key: | ||
return idx | ||
return -1 | ||
|
||
|
||
def extract(onnx: onnx.ModelProto): | ||
idx = _find_tuning_results_in_props(onnx.metadata_props) | ||
if idx < 0: | ||
return None | ||
|
||
tuning_results_prop = onnx.metadata_props[idx] | ||
return json.loads(tuning_results_prop.value) | ||
|
||
|
||
def embed(model: onnx.ModelProto, tuning_results: List[TuningResults], overwrite=False): | ||
idx = _find_tuning_results_in_props(model.metadata_props) | ||
assert overwrite or idx <= 0, "the supplied onnx file already have tuning results embeded!" | ||
|
||
if idx >= 0: | ||
model.metadata_props.pop(idx) | ||
|
||
entry = model.metadata_props.add() | ||
entry.key = _tuning_results_key | ||
entry.value = json.dumps(tuning_results) | ||
return model | ||
|
||
|
||
class Merger: | ||
class EpAndValidators: | ||
def __init__(self, ep: str, validators: Dict[str, str]): | ||
self.ep = ep | ||
self.validators = copy.deepcopy(validators) | ||
self.key = (ep, tuple(sorted(validators.items()))) | ||
|
||
def __hash__(self): | ||
return hash(self.key) | ||
|
||
def __eq__(self, other): | ||
return self.ep == other.ep and self.key == other.key | ||
|
||
def __init__(self): | ||
self.ev_to_results = OrderedDict() | ||
|
||
def merge(self, tuning_results: List[TuningResults]): | ||
for trs in tuning_results: | ||
self._merge_one(trs) | ||
|
||
def get_merged(self): | ||
tuning_results = [] | ||
for ev, flat_results in self.ev_to_results.items(): | ||
results = {} | ||
trs = { | ||
"ep": ev.ep, | ||
"validators": ev.validators, | ||
"results": results, | ||
} | ||
for (op_sig, params_sig), kernel_id in flat_results.items(): | ||
kernel_map = results.setdefault(op_sig, {}) | ||
kernel_map[params_sig] = kernel_id | ||
tuning_results.append(trs) | ||
return tuning_results | ||
|
||
def _merge_one(self, trs: TuningResults): | ||
ev = Merger.EpAndValidators(trs["ep"], trs["validators"]) | ||
flat_results = self.ev_to_results.setdefault(ev, {}) | ||
for op_sig, kernel_map in trs["results"].items(): | ||
for params_sig, kernel_id in kernel_map.items(): | ||
if (op_sig, params_sig) not in flat_results: | ||
flat_results[(op_sig, params_sig)] = kernel_id | ||
|
||
|
||
def parse_args(): | ||
parser = argparse.ArgumentParser() | ||
sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd") | ||
|
||
extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.") | ||
extract_parser.add_argument("input_onnx") | ||
extract_parser.add_argument("output_json") | ||
|
||
embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.") | ||
embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.") | ||
embed_parser.add_argument("output_onnx", help="Path of the output onnx file.") | ||
embed_parser.add_argument("input_onnx", help="Path of the input onnx file.") | ||
embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.") | ||
|
||
merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.") | ||
merge_parser.add_argument("output_json", help="Path of the output tuning results file.") | ||
merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.") | ||
|
||
pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.") | ||
pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.") | ||
|
||
args = parser.parse_args() | ||
if len(vars(args)) == 0: | ||
parser.print_help() | ||
exit(-1) | ||
return args | ||
|
||
|
||
if __name__ == "__main__": | ||
args = parse_args() | ||
if args.cmd == "extract": | ||
tuning_results = extract(onnx.load_model(args.input_onnx)) | ||
if tuning_results is None: | ||
sys.stderr.write(f"{args.input_onnx} does not have tuning results embeded!\n") | ||
exit(-1) | ||
json.dump(tuning_results, open(args.output_json, "w")) | ||
elif args.cmd == "embed": | ||
model = onnx.load_model(args.input_onnx) | ||
merger = Merger() | ||
for tuning_results in [json.load(open(f)) for f in args.input_json]: | ||
merger.merge(tuning_results) | ||
model = embed(model, merger.get_merged(), args.force) | ||
onnx.save_model(model, args.output_onnx) | ||
elif args.cmd == "merge": | ||
merger = Merger() | ||
for tuning_results in [json.load(open(f)) for f in args.input_json]: | ||
merger.merge(tuning_results) | ||
json.dump(merger.get_merged(), open(args.output_json, "w")) | ||
elif args.cmd == "pprint": | ||
tuning_results = None | ||
try: | ||
tuning_results = json.load(open(args.json_or_onnx, "r")) | ||
except: | ||
pass | ||
|
||
if tuning_results is None: | ||
try: | ||
model = onnx.load_model(args.json_or_onnx) | ||
tuning_results = extract(model) | ||
if tuning_results is None: | ||
sys.stderr.write(f"{args.input_onnx} does not have tuning results embeded!\n") | ||
exit(-1) | ||
except: | ||
pass | ||
|
||
if tuning_results is None: | ||
sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!") | ||
exit(-1) | ||
|
||
pprint(tuning_results) | ||
else: | ||
# invalid choice will be handled by the parser | ||
pass |