Skip to content

Commit

Permalink
Add diagnostic handler
Browse files Browse the repository at this point in the history
  • Loading branch information
idavis committed Sep 16, 2024
1 parent 3c1bc83 commit 730ee5a
Show file tree
Hide file tree
Showing 8 changed files with 218 additions and 23 deletions.
5 changes: 2 additions & 3 deletions pyqir/pyqir/_native.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -681,13 +681,12 @@ class Module:
"""Converts this module into an LLVM IR string."""
...

def link(self, other: Module) -> Optional[str]:
def link(self, other: Module) -> None:
"""
Link the supplied module into the current module.
Destroys the supplied module.
:returns: An error message if linking failed or `None` if linking succeeded.
:rtype: typing.Optional[str]
:raises: An error if linking failed.
"""
...

Expand Down
28 changes: 16 additions & 12 deletions pyqir/src/module.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use llvm_sys::{
LLVMLinkage, LLVMModule,
};
use pyo3::{exceptions::PyValueError, prelude::*, pyclass::CompareOp, types::PyBytes};
use qirlib::module::FlagBehavior;
use qirlib::{context::set_diagnostic_handler, module::FlagBehavior};
use std::{
collections::hash_map::DefaultHasher,
ffi::CString,
Expand Down Expand Up @@ -269,26 +269,30 @@ impl Module {
/// Link the supplied module into the current module.
/// Destroys the supplied module.
///
/// :returns: An error message if linking failed or `None` if linking succeeded.
/// :rtype: typing.Optional[str]
pub fn link(&self, other: Py<Module>, py: Python) -> Option<String> {
if self.context.borrow(py).as_ptr() != other.borrow(py).context.borrow(py).as_ptr() {
return Some(
/// :raises: An error if linking failed.
pub fn link(&self, other: Py<Module>, py: Python) -> PyResult<()> {
let context = self.context.borrow(py).as_ptr();
if context != other.borrow(py).context.borrow(py).as_ptr() {
return Err(PyValueError::new_err(
"Cannot link modules from different contexts. Modules are untouched.".to_string(),
);
));
}
unsafe {
let mut char_ptr: *mut ::core::ffi::c_char = ptr::null_mut();
let char_ptr_ptr = &mut char_ptr as *mut *mut ::core::ffi::c_char
as *mut *mut ::core::ffi::c_void
as *mut ::core::ffi::c_void;

set_diagnostic_handler(context, char_ptr_ptr);
let result = LLVMLinkModules2(self.module.as_ptr(), other.borrow(py).module.as_ptr());
// `forget` the other module. LLVM has destroyed it
// and we'll get a segfault if we drop it.
forget(other);
if result == 0 {
None
Ok(())
} else {
// in the future we need to return a proper error message
// using `LLVMContextSetDiagnosticHandler`. This is a lot of work
// to get right, so we'll leave it for now.
Some("Failed to link modules".to_string())
let error = Message::from_raw(char_ptr);
return Err(PyValueError::new_err(error.to_str().unwrap().to_string()));
}
}
}
Expand Down
34 changes: 34 additions & 0 deletions pyqir/tests/profile_v1.0_compat.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
; ModuleID = 'OneDotZero'
%Result = type opaque
%Qubit = type opaque

define void @OneDotZero() #0 {
block_0:
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*))
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null)
ret void
}

declare void @__quantum__qis__h__body(%Qubit*)

declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*)

declare void @__quantum__rt__result_record_output(%Result*, i8*)

declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1

attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" }
attributes #1 = { "irreversible" }

; module flags

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
34 changes: 34 additions & 0 deletions pyqir/tests/profile_v1.1_compat.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
; ModuleID = 'OneDotOne'
%Result = type opaque
%Qubit = type opaque

define void @OneDotOne() #0 {
block_0:
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*))
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null)
ret void
}

declare void @__quantum__qis__h__body(%Qubit*)

declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*)

declare void @__quantum__rt__result_record_output(%Result*, i8*)

declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1

attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" }
attributes #1 = { "irreversible" }

; module flags

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 1}
!1 = !{i32 7, !"qir_minor_version", i32 1}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
34 changes: 34 additions & 0 deletions pyqir/tests/profile_v2.0_compat.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
; ModuleID = 'TwoDotZero'
%Result = type opaque
%Qubit = type opaque

define void @TwoDotZero() #0 {
block_0:
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__cz__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Qubit* inttoptr (i64 0 to %Qubit*))
call void @__quantum__qis__h__body(%Qubit* inttoptr (i64 1 to %Qubit*))
call void @__quantum__qis__m__body(%Qubit* inttoptr (i64 1 to %Qubit*), %Result* inttoptr (i64 0 to %Result*))
call void @__quantum__rt__result_record_output(%Result* inttoptr (i64 0 to %Result*), i8* null)
ret void
}

declare void @__quantum__qis__h__body(%Qubit*)

declare void @__quantum__qis__cz__body(%Qubit*, %Qubit*)

declare void @__quantum__rt__result_record_output(%Result*, i8*)

declare void @__quantum__qis__m__body(%Qubit*, %Result*) #1

attributes #0 = { "entry_point" "output_labeling_schema" "qir_profiles"="base_profile" "required_num_qubits"="2" "required_num_results"="1" }
attributes #1 = { "irreversible" }

; module flags

!llvm.module.flags = !{!0, !1, !2, !3}

!0 = !{i32 1, !"qir_major_version", i32 2}
!1 = !{i32 7, !"qir_minor_version", i32 0}
!2 = !{i32 1, !"dynamic_qubit_management", i1 false}
!3 = !{i32 1, !"dynamic_result_management", i1 false}
78 changes: 70 additions & 8 deletions pyqir/tests/test_module_linking.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,93 @@

from pathlib import Path

import pytest

current_file_path = Path(__file__)
# Get the directory of the current file
current_dir = current_file_path.parent

from pyqir import (
Context,
Module,
)


def read_file(file_name: str) -> str:
return Path(current_dir / file_name).read_text(encoding="utf-8")


def test_link_modules_with_same_context() -> None:
context = Context()
ir = Path("tests/random_bit.ll").read_text()
ir = read_file("random_bit.ll")
dest = Module.from_ir(context, ir)
ir = Path("tests/5_bit_random_number.ll").read_text()
ir = read_file("5_bit_random_number.ll")
src = Module.from_ir(context, ir)
assert dest.link(src) is None
dest.link(src)
assert dest.verify() is None
actual_ir = str(dest)
expected_ir = str(Path("tests/combined_module.ll").read_text())
expected_ir = str(read_file("combined_module.ll"))
assert actual_ir == expected_ir


def test_link_modules_with_different_contexts() -> None:
ir = Path("tests/random_bit.ll").read_text()
ir = read_file("random_bit.ll")
dest = Module.from_ir(Context(), ir)
ir = Path("tests/5_bit_random_number.ll").read_text()
ir = read_file("5_bit_random_number.ll")
src = Module.from_ir(Context(), ir)
message = dest.link(src)
with pytest.raises(ValueError) as ex:
dest.link(src)
assert (
str(ex.value)
== "Cannot link modules from different contexts. Modules are untouched."
)


def test_link_module_with_src_minor_version_less() -> None:
context = Context()
ir = read_file("profile_v1.0_compat.ll")
dest = Module.from_ir(context, ir)
ir = read_file("profile_v1.1_compat.ll")
src = Module.from_ir(context, ir)
dest.link(src)
assert dest.get_flag("qir_minor_version").value.value == 1


def test_link_module_with_src_minor_version_greater() -> None:
context = Context()
ir = read_file("profile_v1.1_compat.ll")
dest = Module.from_ir(context, ir)
ir = read_file("profile_v1.0_compat.ll")
src = Module.from_ir(context, ir)
dest.link(src)
assert dest.get_flag("qir_minor_version").value.value == 1


def test_link_module_with_src_major_version_less() -> None:
context = Context()
ir = read_file("profile_v2.0_compat.ll")
dest = Module.from_ir(context, ir)
ir = read_file("profile_v1.0_compat.ll")
src = Module.from_ir(context, ir)
with pytest.raises(ValueError) as ex:
dest.link(src)

assert (
"linking module flags 'qir_major_version': IDs have conflicting values"
in str(ex)
)


def test_link_module_with_src_major_version_greater() -> None:
context = Context()
ir = read_file("profile_v1.0_compat.ll")
dest = Module.from_ir(context, ir)
ir = read_file("profile_v2.0_compat.ll")
src = Module.from_ir(context, ir)
with pytest.raises(ValueError) as ex:
dest.link(src)

assert (
message == "Cannot link modules from different contexts. Modules are untouched."
"linking module flags 'qir_major_version': IDs have conflicting values"
in str(ex)
)
26 changes: 26 additions & 0 deletions qirlib/src/context.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

use llvm_sys::{
core::{LLVMContextSetDiagnosticHandler, LLVMGetDiagInfoDescription, LLVMGetDiagInfoSeverity},
prelude::{LLVMContextRef, LLVMDiagnosticInfoRef},
LLVMDiagnosticSeverity,
};

pub fn set_diagnostic_handler(context: LLVMContextRef, output_ptr: *mut core::ffi::c_void) {
unsafe { LLVMContextSetDiagnosticHandler(context, Some(diagnostic_handler), output_ptr) };
}

pub(crate) extern "C" fn diagnostic_handler(
diagnostic_info: LLVMDiagnosticInfoRef,
output: *mut ::core::ffi::c_void,
) {
unsafe {
let severity = LLVMGetDiagInfoSeverity(diagnostic_info);
if severity == LLVMDiagnosticSeverity::LLVMDSError {
let c_char_output =
output as *mut *mut ::core::ffi::c_void as *mut *mut ::core::ffi::c_char;
*c_char_output = LLVMGetDiagInfoDescription(diagnostic_info)
}
}
}
2 changes: 2 additions & 0 deletions qirlib/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ extern crate llvm_sys_140 as llvm_sys;
#[cfg(not(feature = "no-llvm-linking"))]
pub mod builder;
#[cfg(not(feature = "no-llvm-linking"))]
pub mod context;
#[cfg(not(feature = "no-llvm-linking"))]
pub(crate) mod llvm_wrapper;
#[cfg(not(feature = "no-llvm-linking"))]
pub mod metadata;
Expand Down

0 comments on commit 730ee5a

Please sign in to comment.