Skip to content

Commit

Permalink
Add link function to Module
Browse files Browse the repository at this point in the history
  • Loading branch information
idavis committed Sep 16, 2024
1 parent 3ace643 commit 2d6c613
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 6 deletions.
10 changes: 10 additions & 0 deletions pyqir/pyqir/_native.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,16 @@ class Module:
"""Converts this module into an LLVM IR string."""
...

def link(self, other: Module) -> Optional[str]:
"""
Link the supplied module into the current module.
Destroys the supplied module.
:returns: An error message if linking failed or `None` if linking succeeded.
:rtype: typing.Optional[str]
"""
...

class ModuleFlagBehavior(Enum):
"""Module flag behavior choices"""

Expand Down
34 changes: 28 additions & 6 deletions pyqir/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,14 @@ use llvm_sys::{
bit_writer::LLVMWriteBitcodeToMemoryBuffer,
core::*,
ir_reader::LLVMParseIRInContext,
linker::LLVMLinkModules2,
LLVMLinkage, LLVMModule,
};
use pyo3::{exceptions::PyValueError, prelude::*, pyclass::CompareOp, types::PyBytes};
use qirlib::module::FlagBehavior;
use core::mem::forget;
use std::{
collections::hash_map::DefaultHasher,
ffi::CString,
hash::{Hash, Hasher},
ops::Deref,
ptr::{self, NonNull},
str,
collections::hash_map::DefaultHasher, ffi::CString, hash::{Hash, Hasher}, ops::Deref, ptr::{self, NonNull}, str
};

/// A module is a collection of global values.
Expand Down Expand Up @@ -263,6 +260,31 @@ impl Module {
.to_string()
}
}

/// Link the supplied module into the current module.
/// Destroys the supplied module.
///
/// :returns: An error message if linking failed or `None` if linking succeeded.
/// :rtype: typing.Optional[str]
pub fn link(&self, other: Py<Module>, py: Python) -> Option<String> {
if self.context.borrow(py).as_ptr() != other.borrow(py).context.borrow(py).as_ptr() {
return Some("Cannot link modules from different contexts. Modules are untouched.".to_string());
}
unsafe {
let result = LLVMLinkModules2(self.module.as_ptr(), other.borrow(py).module.as_ptr());
// `forget` the other module. LLVM has destroyed it
// and we'll get a segfault if we drop it.
forget(other);
if result == 0 {
return None;
} else {
// in the future we need to return a proper error message
// using `LLVMContextSetDiagnosticHandler`. This is a lot of work
// to get right, so we'll leave it for now.
Some("Failed to link modules".to_string())
}
}
}
}

impl Deref for Module {
Expand Down
44 changes: 44 additions & 0 deletions pyqir/tests/5_bit_random_number.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
; ModuleID = '5_bit_random_number'
%Result = type opaque
%Qubit = type opaque

define void @five_bit_random_number() #0 {
block_0:
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 4 to %Qubit*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 0 to %Qubit*), %Result* inttoptr (i64 0 to %Result*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 2 to %Result*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Result* inttoptr (i64 3 to %Result*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 4 to %Qubit*), %Result* inttoptr (i64 4 to %Result*))
call void @__quantum__rt__array_record_output(i64 5, i8* null)
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null)
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* null)
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 2 to %Result*), i8* null)
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 3 to %Result*), i8* null)
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 4 to %Result*), i8* null)
ret void
}

declare void @__quantum__qis__h__body(%Qubit*)

declare void @__quantum__rt__array_record_output(i64, i8*)

declare void @__quantum__rt__result_record_output(%Result*, i8*)

declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1

attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="5" "required_num_results"="5" }
attributes #1 = { "irreversible" }

; module flags

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
56 changes: 56 additions & 0 deletions pyqir/tests/combined_module.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@

%Qubit = type opaque
%Result = type opaque

define void @random_bit() #0 {
block_0:
call void @__quantum__qis__h__body(%Qubit* null)
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* null)
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* null)
call void @__quantum__rt__result_record_output(%Result* null, i8* null)
ret void
}

declare void @__quantum__qis__h__body(%Qubit*)

declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*)

declare void @__quantum__rt__result_record_output(%Result*, i8*)

declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1

define void @five_bit_random_number() #2 {
block_0:
call void @__quantum__qis__h__body(%Qubit* null)
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 2 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 3 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 4 to %Qubit*))
call void @__quantum__qis__m__body(%Qubit* null, %Result* null)
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 1 to %Result*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 2 to %Qubit*), %Result* inttoptr (i64 2 to %Result*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 3 to %Qubit*), %Result* inttoptr (i64 3 to %Result*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 4 to %Qubit*), %Result* inttoptr (i64 4 to %Result*))
call void @__quantum__rt__array_record_output(i64 5, i8* null)
call void @__quantum__rt__result_record_output(%Result* null, i8* null)
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 1 to %Result*), i8* null)
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 2 to %Result*), i8* null)
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 3 to %Result*), i8* null)
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 4 to %Result*), i8* null)
ret void
}

declare void @__quantum__rt__array_record_output(i64, i8*)

attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" }
attributes #1 = { "irreversible" }
attributes #2 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="5" "required_num_results"="5" }

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
34 changes: 34 additions & 0 deletions pyqir/tests/random_bit.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
; ModuleID = 'random_bit'
%Result = type opaque
%Qubit = type opaque

define void @random_bit() #0 {
block_0:
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*))
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null)
ret void
}

declare void @__quantum__qis__h__body(%Qubit*)

declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*)

declare void @__quantum__rt__result_record_output(%Result*, i8*)

declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1

attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" }
attributes #1 = { "irreversible" }

; module flags

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
31 changes: 31 additions & 0 deletions pyqir/tests/test_module_linking.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

from pathlib import Path

from pyqir import (
Context,
Module,
)


def test_link_modules_with_same_context() -> None:
context = Context()
ir = Path("tests/random_bit.ll").read_text()
dest = Module.from_ir(context, ir)
ir = Path("tests/5_bit_random_number.ll").read_text()
src = Module.from_ir(context, ir)
assert dest.link(src) is None
assert dest.verify() is None
actual_ir = str(dest)
expected_ir = str(Path("tests/combined_module.ll").read_text())
assert actual_ir == expected_ir

def test_link_modules_with_different_contexts() -> None:
ir = Path("tests/random_bit.ll").read_text()
dest = Module.from_ir(Context(), ir)
ir = Path("tests/5_bit_random_number.ll").read_text()
src = Module.from_ir(Context(), ir)
message = dest.link(src)
assert message == "Cannot link modules from different contexts. Modules are untouched."

0 comments on commit 2d6c613

Please sign in to comment.