# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Extract, embed, merge and pretty-print tuning results.

Tuning results are stored as a JSON string in an ONNX model's
``metadata_props`` under the key ``"tuning_results"``, or as standalone
JSON files.
"""

import argparse
import copy
import itertools
import json
import sys
from collections import OrderedDict
from pprint import pprint
from typing import Any, Dict, List

try:
    import onnx
except ImportError:
    # Only the commands that read or write model files need onnx; keeping the
    # import optional lets the JSON-only commands (e.g. ``merge``) still run.
    onnx = None

TuningResults = Dict[str, Any]

_tuning_results_key = "tuning_results"


def _find_tuning_results_in_props(metadata_props):
    """Return the index of the tuning results entry in *metadata_props*, or -1."""
    for idx, prop in enumerate(metadata_props):
        if prop.key == _tuning_results_key:
            return idx
    return -1


def extract(onnx: "onnx.ModelProto"):
    """Return the JSON-decoded tuning results embedded in the model.

    Returns ``None`` when the model's ``metadata_props`` has no tuning
    results entry.
    """
    idx = _find_tuning_results_in_props(onnx.metadata_props)
    if idx < 0:
        return None

    tuning_results_prop = onnx.metadata_props[idx]
    return json.loads(tuning_results_prop.value)


def embed(model: "onnx.ModelProto", tuning_results: List[TuningResults], overwrite=False):
    """Store *tuning_results* as JSON in ``model.metadata_props`` and return the model.

    Args:
        model: The model to modify in place.
        tuning_results: The tuning results to serialize with ``json.dumps``.
        overwrite: When true, replace an already embedded entry.

    Raises:
        AssertionError: If the model already has tuning results embedded and
            *overwrite* is false.
    """
    idx = _find_tuning_results_in_props(model.metadata_props)
    # Any index >= 0 (including 0) means an entry already exists. Raised
    # explicitly instead of using ``assert`` so the check survives ``python -O``;
    # the exception type is kept for backward compatibility.
    if idx >= 0 and not overwrite:
        raise AssertionError("the supplied onnx file already have tuning results embedded!")

    if idx >= 0:
        model.metadata_props.pop(idx)

    entry = model.metadata_props.add()
    entry.key = _tuning_results_key
    entry.value = json.dumps(tuning_results)
    return model


class Merger:
    """Merge several tuning results lists, grouping by (ep, validators).

    When the same (op signature, params signature) pair appears more than
    once within a group, the first kernel id seen is kept.
    """

    class EpAndValidators:
        """Hashable grouping key built from an ep name and its validators."""

        def __init__(self, ep: str, validators: Dict[str, str]):
            self.ep = ep
            # Copied so later mutation of the caller's dict cannot affect this key.
            self.validators = copy.deepcopy(validators)
            self.key = (ep, tuple(sorted(validators.items())))

        def __hash__(self):
            return hash(self.key)

        def __eq__(self, other):
            if not isinstance(other, Merger.EpAndValidators):
                return NotImplemented
            return self.ep == other.ep and self.key == other.key

    def __init__(self):
        # Maps EpAndValidators -> {(op_sig, params_sig): kernel_id}, in first-seen order.
        self.ev_to_results = OrderedDict()

    def merge(self, tuning_results: List[TuningResults]):
        """Merge every entry of *tuning_results* into the accumulated state."""
        for trs in tuning_results:
            self._merge_one(trs)

    def get_merged(self):
        """Return the accumulated results as a list of ep/validators/results dicts."""
        tuning_results = []
        for ev, flat_results in self.ev_to_results.items():
            results = {}
            for (op_sig, params_sig), kernel_id in flat_results.items():
                results.setdefault(op_sig, {})[params_sig] = kernel_id
            tuning_results.append(
                {
                    "ep": ev.ep,
                    "validators": ev.validators,
                    "results": results,
                }
            )
        return tuning_results

    def _merge_one(self, trs: TuningResults):
        """Merge a single tuning results dict; existing entries are not replaced."""
        ev = Merger.EpAndValidators(trs["ep"], trs["validators"])
        flat_results = self.ev_to_results.setdefault(ev, {})
        for op_sig, kernel_map in trs["results"].items():
            for params_sig, kernel_id in kernel_map.items():
                flat_results.setdefault((op_sig, params_sig), kernel_id)


def parse_args(argv=None):
    """Parse the command line and return the argparse namespace.

    Args:
        argv: Argument list to parse; ``None`` makes argparse use ``sys.argv[1:]``.

    Prints the help text and exits with status -1 when no command is given.
    """
    parser = argparse.ArgumentParser()
    sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd")

    extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.")
    extract_parser.add_argument("input_onnx")
    extract_parser.add_argument("output_json")

    embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.")
    embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if it existed.")
    embed_parser.add_argument("output_onnx", help="Path of the output onnx file.")
    embed_parser.add_argument("input_onnx", help="Path of the input onnx file.")
    embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.")

    merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files as a single one.")
    merge_parser.add_argument("output_json", help="Path of the output tuning results file.")
    merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.")

    pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.")
    pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.")

    args = parser.parse_args(argv)
    # ``dest="cmd"`` always creates the attribute, so a missing command shows up
    # as ``None`` rather than as an empty namespace.
    if args.cmd is None:
        parser.print_help()
        sys.exit(-1)
    return args


def _require_onnx():
    """Return the onnx module, or exit with an error when it is not installed."""
    if onnx is None:
        sys.stderr.write("the onnx package is required for this command but is not installed!\n")
        sys.exit(-1)
    return onnx


def _load_json(path):
    """Read and decode the JSON file at *path*."""
    with open(path) as f:
        return json.load(f)


def _dump_json(obj, path):
    """Write *obj* as JSON to *path*."""
    with open(path, "w") as f:
        json.dump(obj, f)


def _merge_files(paths):
    """Load every tuning results file in *paths* and return the merged results."""
    merger = Merger()
    for path in paths:
        merger.merge(_load_json(path))
    return merger.get_merged()


def _load_json_or_onnx(path):
    """Load tuning results from a JSON file or from an onnx file.

    Returns a ``(tuning_results, error_message)`` pair; exactly one is ``None``.
    """
    try:
        tuning_results = _load_json(path)
    except (OSError, ValueError):
        # Unreadable, undecodable or not JSON; fall through and try it as onnx.
        tuning_results = None
    if tuning_results is not None:
        return tuning_results, None

    onnx_module = _require_onnx()
    try:
        tuning_results = extract(onnx_module.load_model(path))
    except Exception:
        # CLI boundary: any failure to load or decode means the input is unusable.
        # SystemExit and KeyboardInterrupt are deliberately not caught here.
        return None, f"{path} is not a valid tuning results file or onnx file!\n"

    if tuning_results is None:
        return None, f"{path} does not have tuning results embedded!\n"
    return tuning_results, None


def main(argv=None):
    """Run the command selected on the command line and return the exit status."""
    args = parse_args(argv)

    if args.cmd == "extract":
        tuning_results = extract(_require_onnx().load_model(args.input_onnx))
        if tuning_results is None:
            sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
            return -1
        _dump_json(tuning_results, args.output_json)
    elif args.cmd == "embed":
        onnx_module = _require_onnx()
        model = onnx_module.load_model(args.input_onnx)
        merged = _merge_files(args.input_json)
        try:
            model = embed(model, merged, args.force)
        except AssertionError as err:
            sys.stderr.write(f"{err}\n")
            return -1
        onnx_module.save_model(model, args.output_onnx)
    elif args.cmd == "merge":
        _dump_json(_merge_files(args.input_json), args.output_json)
    elif args.cmd == "pprint":
        tuning_results, error = _load_json_or_onnx(args.json_or_onnx)
        if error is not None:
            sys.stderr.write(error)
            return -1
        pprint(tuning_results)
    # An invalid choice is rejected by the parser before reaching this point.
    return 0


if __name__ == "__main__":
    sys.exit(main())