From 294a362d8b094f68af6ed1f481e8814fe980dff2 Mon Sep 17 00:00:00 2001 From: PENGUINLIONG Date: Wed, 15 Mar 2023 15:36:02 +0800 Subject: [PATCH] [aot] Load GfxRuntime140 module from TCM (#7539) Issue: # ### Brief Summary --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/_ti_module/__init__.py | 10 +++------- python/taichi/_ti_module/cppgen.py | 4 +--- .../aot/conventions/gfxruntime140/__init__.py | 19 +++++++++++++++++++ 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/python/taichi/_ti_module/__init__.py b/python/taichi/_ti_module/__init__.py index b9c7737d627ee..cea14f1335b75 100644 --- a/python/taichi/_ti_module/__init__.py +++ b/python/taichi/_ti_module/__init__.py @@ -38,12 +38,6 @@ def module_cppgen_impl(a): f"Generating C++ header for Taichi module: {Path(module_path).absolute()}" ) - with open(f"{module_path}/metadata.json") as f: - metadata_json = json.load(f) - - with open(f"{module_path}/graphs.json") as f: - graphs_json = json.load(f) - if a.module_name: module_name = a.module_name else: @@ -51,7 +45,9 @@ def module_cppgen_impl(a): if module_name.endswith(".tcm"): module_name = module_name[:-4] - out = generate_header(metadata_json, graphs_json, module_name, a.namespace) + m = GfxRuntime140.from_module(module_path) + + out = generate_header(m, module_name, a.namespace) with open(a.output, "w") as f: f.write('\n'.join(out)) diff --git a/python/taichi/_ti_module/cppgen.py b/python/taichi/_ti_module/cppgen.py index 8ecfd862e729c..9d9da1b23ca68 100644 --- a/python/taichi/_ti_module/cppgen.py +++ b/python/taichi/_ti_module/cppgen.py @@ -280,12 +280,10 @@ def generate_module_content(m: GfxRuntime140, module_name: str) -> List[str]: return out -def generate_header(metadata_json: str, graphs_json: str, module_name: str, +def generate_header(m: GfxRuntime140, module_name: str, namespace: str) -> List[str]: out = [] - m = GfxRuntime140(metadata_json, graphs_json) - out += [ "// THIS IS A GENERATED HEADER; PLEASE DO NOT MODIFY.", "#pragma once", diff --git a/python/taichi/aot/conventions/gfxruntime140/__init__.py b/python/taichi/aot/conventions/gfxruntime140/__init__.py index 555ab61ca91d1..dc3ea82acb175 100644 --- a/python/taichi/aot/conventions/gfxruntime140/__init__.py +++ b/python/taichi/aot/conventions/gfxruntime140/__init__.py @@ -1,3 +1,6 @@ +import json +import zipfile +from pathlib import Path from typing import Any, List from taichi.aot.conventions.gfxruntime140 import dr, sr @@ -10,6 +13,22 @@ def __init__(self, metadata_json: Any, graphs_json: Any) -> None: self.metadata = sr.from_dr_metadata(metadata) self.graphs = [sr.from_dr_graph(self.metadata, x) for x in graphs] + @staticmethod + def from_module(module_path: str) -> 'GfxRuntime140': + if Path(module_path).is_file(): + with zipfile.ZipFile(module_path) as z: + with z.open('metadata.json') as f: + metadata_json = json.load(f) + with z.open('graphs.json') as f: + graphs_json = json.load(f) + else: + with open(f"{module_path}/metadata.json") as f: + metadata_json = json.load(f) + with open(f"{module_path}/graphs.json") as f: + graphs_json = json.load(f) + + return GfxRuntime140(metadata_json, graphs_json) + def to_metadata_json(self) -> Any: return dr.to_json_metadata(sr.to_dr_metadata(self.metadata))