From 72038728411ff6b2a3320581b5dd08b3442b197d Mon Sep 17 00:00:00 2001 From: messense Date: Fri, 3 Dec 2021 22:45:30 +0800 Subject: [PATCH] Add a Python import hook --- maturin/import_hook.py | 110 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) create mode 100644 maturin/import_hook.py diff --git a/maturin/import_hook.py b/maturin/import_hook.py new file mode 100644 index 000000000..916dd201e --- /dev/null +++ b/maturin/import_hook.py @@ -0,0 +1,110 @@ +import importlib +import importlib.util +from importlib import abc +import os +import pathlib +import sys +import subprocess +from typing import Optional + +import toml + + +class Importer(abc.MetaPathFinder): + """A meta-path importer for the maturin based packages""" + + def __init__(self, bindings: Optional[str] = None): + self.bindings = bindings + + def find_spec(self, fullname, path, target=None): + if fullname in sys.modules: + return + mod_parts = fullname.split(".") + module_name = mod_parts[-1] + + # Full Cargo project + cargo_toml = pathlib.Path(os.getcwd()) / "Cargo.toml" + if os.path.exists(cargo_toml): + with open(cargo_toml) as f: + cargo = toml.load(f) + package_name = cargo.get("package", {}).get("name") + if ( + package_name == module_name + or package_name.replace("-", "_") == module_name + ): + build_module(cargo_toml, bindings=self.bindings) + loader = Loader(fullname) + return importlib.util.spec_from_loader(fullname, loader) + + # Single .rs file + rust_file = pathlib.Path(os.getcwd()) / (module_name + ".rs") + if os.path.exists(rust_file): + project_dir = generate_project(rust_file, bindings=self.bindings or "pyo3") + cargo_toml = project_dir / "Cargo.toml" + build_module(cargo_toml, bindings=self.bindings) + loader = Loader(fullname) + return importlib.util.spec_from_loader(fullname, loader) + + +class Loader(abc.Loader): + def __init__(self, fullname): + self.fullname = fullname + + def load_module(self, fullname): + return importlib.import_module(self.fullname) + + +def generate_project(rust_file: pathlib.Path, bindings: str = "pyo3") -> pathlib.Path: + build_dir = pathlib.Path(os.getcwd()) / "build" + project_dir = build_dir / rust_file.stem + command = ["maturin", "new", "-b", bindings, project_dir] + result = subprocess.run(command, stdout=subprocess.PIPE) + if result.returncode != 0: + sys.stderr.write( + f"Error: command {command} returned non-zero exit status {result.returncode}\n" + ) + raise ImportError("Failed to generate cargo project") + return project_dir + + +def build_module(manifest_path: pathlib.Path, bindings: Optional[str] = None): + command = ["maturin", "develop", "-m", manifest_path] + if bindings: + command.append("-b") + command.append(bindings) + result = subprocess.run(command, stdout=subprocess.PIPE) + sys.stdout.buffer.write(result.stdout) + sys.stdout.flush() + if result.returncode != 0: + sys.stderr.write( + f"Error: command {command} returned non-zero exit status {result.returncode}\n" + ) + raise ImportError("Failed to build module with maturin") + + +def _have_importer() -> bool: + for importer in sys.meta_path: + if isinstance(importer, Importer): + return True + return False + + +def install(bindings: Optional[str] = None): + """ + Install the import hook. + """ + if _have_importer(): + return + importer = Importer(bindings=bindings) + sys.meta_path.append(importer) + return importer + + +def uninstall(importer: Importer): + """ + Uninstall the import hook. + """ + try: + sys.meta_path.remove(importer) + except ValueError: + pass