diff --git a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py index 755788483390..784fc6160f7e 100644 --- a/compiler/bindings/python/iree/compiler/api/ctypes_dl.py +++ b/compiler/bindings/python/iree/compiler/api/ctypes_dl.py @@ -179,6 +179,16 @@ def _initializeGlobalCL(*cl_args: str): _dylib.ireeCompilerSetupGlobalCL(len(cl_args), arg_pointers, b"ctypes", False) +def _is_null_terminated(view: memoryview): + return view.nbytes > 0 and view[-1] == 0 + + +def _is_mlir_bytecode(view: memoryview): + """Compares the first 4 bytes of the view against the magic number 4d4cef52. + See https://mlir.llvm.org/docs/BytecodeFormat/#magic-number for more info.""" + return len(view) >= 4 and view[:4].hex() == "4d4cef52" + + class Session: def __init__(self): self._global_init = _global_init @@ -339,7 +349,7 @@ def wrap_buffer( buffer, buffer_len, # Detect if nul terminated. - True if buffer_len > 0 and view[-1] == 0 else False, + _is_null_terminated(view) and not _is_mlir_bytecode(view), byref(source_p), ) ) diff --git a/compiler/bindings/python/test/api/api_test.py b/compiler/bindings/python/test/api/api_test.py index ab70fcd0858e..0d52bd3ba004 100644 --- a/compiler/bindings/python/test/api/api_test.py +++ b/compiler/bindings/python/test/api/api_test.py @@ -17,7 +17,11 @@ import tempfile import unittest - from iree.compiler.api import * + from iree.compiler.api import ( + Session, + Source, + Output, + ) from iree.compiler import ir class DlFlagsTest(unittest.TestCase): @@ -81,6 +85,69 @@ def testInputBuffer(self): self.assertIn(b"module", bytes(mem)) out.close() + def testInputBytecode(self): + this_dir = os.path.dirname(__file__) + with open( + os.path.join(this_dir, "testdata", "bytecode_testfile.bc"), "rb" + ) as f: + bytecode = f.read() + session = Session() + inv = session.invocation() + source = Source.wrap_buffer(session, bytecode) + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir(out) + mem = out.map_memory() + self.assertIn(b"module", bytes(mem)) + out.close() + + def testInputZeroTerminatedBytecode(self): + this_dir = os.path.dirname(__file__) + with open( + os.path.join( + this_dir, "testdata", "bytecode_zero_terminated_testfile.bc" + ), + "rb", + ) as f: + bytecode = f.read() + session = Session() + inv = session.invocation() + source = Source.wrap_buffer(session, bytecode) + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir(out) + mem = out.map_memory() + self.assertIn(b"module", bytes(mem)) + out.close() + + def testInputRoundtrip(self): + test_ir = b"builtin.module {}" + session = Session() + inv = session.invocation() + source = Source.wrap_buffer( + session, + bytes(test_ir), + ) + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir_bytecode(out) + mem = out.map_memory() + bytecode = bytes(mem) + out.close() + session = Session() + inv = session.invocation() + source = Source.wrap_buffer( + session, + bytecode, + ) + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir(out) + mem = out.map_memory() + text_out = bytes(mem) + out.close() + self.assertIn(b"module", text_out) + def testOutputBytecode(self): session = Session() inv = session.invocation() diff --git a/compiler/bindings/python/test/api/testdata/bytecode_testfile.bc b/compiler/bindings/python/test/api/testdata/bytecode_testfile.bc new file mode 100644 index 000000000000..0ee6e0115747 Binary files /dev/null and b/compiler/bindings/python/test/api/testdata/bytecode_testfile.bc differ diff --git a/compiler/bindings/python/test/api/testdata/bytecode_zero_terminated_testfile.bc b/compiler/bindings/python/test/api/testdata/bytecode_zero_terminated_testfile.bc new file mode 100644 index 000000000000..500820fe9cca Binary files /dev/null and b/compiler/bindings/python/test/api/testdata/bytecode_zero_terminated_testfile.bc differ diff --git a/compiler/bindings/python/test/api/testdata/generate_mlir_bytecode.py b/compiler/bindings/python/test/api/testdata/generate_mlir_bytecode.py new file mode 100644 index 000000000000..5e26d00426c3 --- /dev/null +++ b/compiler/bindings/python/test/api/testdata/generate_mlir_bytecode.py @@ -0,0 +1,47 @@ +# Copyright 2024 The IREE Authors +# +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +from iree.compiler.api import ( + Session, + Source, + Output, +) +import os +import iree.compiler.tools.tflite + + +def generate_test_bytecode(): + session = Session() + inv = session.invocation() + source = Source.wrap_buffer(session, b"builtin.module {}") + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir_bytecode(out) + mem = out.map_memory() + + this_dir = os.path.dirname(__file__) + with open(os.path.join(this_dir, "bytecode_testfile.bc"), "wb") as file: + file.write(bytes(mem)) + + +def generate_zero_terminated_bytecode(): + """MLIR Bytecode can also be zero terminated. I couldn't find a way to generate zero terminated + bytecode apart from this. Printing as textual IR and then reparsing and printing as bytecode + removes the zero termination on this IR. This might very well be an odity of TF.""" + if not iree.compiler.tools.tflite.is_available(): + return + this_dir = os.path.dirname(__file__) + path = os.path.join(this_dir, "..", "..", "tools", "testdata", "tflite_sample.fb") + bytecode = iree.compiler.tools.tflite.compile_file(path, import_only=True) + with open( + os.path.join(this_dir, "bytecode_zero_terminated_testfile.bc"), "wb" + ) as file: + file.write(bytecode) + + +if __name__ == "__main__": + generate_test_bytecode() + generate_zero_terminated_bytecode() diff --git a/compiler/bindings/python/test/tools/compiler_tflite_test.py b/compiler/bindings/python/test/tools/compiler_tflite_test.py index 2dec1b36052c..867b4ae88b37 100644 --- a/compiler/bindings/python/test/tools/compiler_tflite_test.py +++ b/compiler/bindings/python/test/tools/compiler_tflite_test.py @@ -10,7 +10,11 @@ import tempfile import unittest -from iree.compiler.tools.ir_tool import __main__ as ir_tool +from iree.compiler.api import ( + Session, + Source, + Output, +) # TODO: No idea why pytype cannot find names from this module. # pytype: disable=name-error @@ -24,18 +28,16 @@ sys.exit(0) -def mlir_bytecode_file_to_text(bytecode_file): - with tempfile.NamedTemporaryFile() as temp_file: - args = ir_tool.parse_arguments(["copy", bytecode_file, "-o", temp_file.name]) - ir_tool.main(args) - return str(temp_file.read()) - - def mlir_bytecode_to_text(bytecode): - with tempfile.NamedTemporaryFile("wb") as temp_bytecode_file: - temp_bytecode_file.write(bytecode) - temp_bytecode_file.flush() - return mlir_bytecode_file_to_text(temp_bytecode_file.name) + session = Session() + inv = session.invocation() + source = Source.wrap_buffer(session, bytecode) + inv.parse_source(source) + out = Output.open_membuffer() + inv.output_ir(out) + text_ir = str(bytes(out.map_memory())) + out.close() + return text_ir class CompilerTest(unittest.TestCase):