From 164635d824849cb2afaaf78a367b83ce6d86a180 Mon Sep 17 00:00:00 2001 From: Atharva Satpute <55058959+atharva-satpute@users.noreply.github.com> Date: Sat, 1 Jun 2024 01:09:05 +0530 Subject: [PATCH] fix: fix issue with kwargs while calling subroutine --- src/autoqasm/api.py | 15 ++++-- test/unit_tests/autoqasm/test_api.py | 70 ++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/src/autoqasm/api.py b/src/autoqasm/api.py index d4a567c..ce5cb33 100644 --- a/src/autoqasm/api.py +++ b/src/autoqasm/api.py @@ -339,13 +339,17 @@ def _convert_subroutine( for i, param in enumerate(inspect.signature(f).parameters.values()) if param.annotation == aq_types.QubitIdentifierType } + + # Map args and kwargs to function signature + bound_args = inspect.signature(oqpy_sub).bind(*[oqpy_program, *args], **kwargs) + args = [ (aq_instructions.qubits._qubit(arg) if i in quantum_indices else arg) - for i, arg in enumerate(args) + for i, arg in enumerate(bound_args.args[1:]) ] # Process the program - subroutine_function_call = oqpy_sub(oqpy_program, *args, **kwargs) + subroutine_function_call = oqpy_sub(oqpy_program, *args) program_conversion_context.register_args(args) # Mark that we are finished processing this function @@ -357,8 +361,13 @@ def _convert_subroutine( _wrap_for_oqpy_subroutine(_dummy_function(f), options) ) + # Map args and kwargs to function signature + bound_args = inspect.signature(oqpy_sub).bind(*((oqpy_program, *args)), **kwargs) + + args = bound_args.args[1:] + # Process the program - subroutine_function_call = oqpy_sub(oqpy_program, *args, **kwargs) + subroutine_function_call = oqpy_sub(oqpy_program, *args) # Add the subroutine invocation to the program ret_type = subroutine_function_call.subroutine_decl.return_type diff --git a/test/unit_tests/autoqasm/test_api.py b/test/unit_tests/autoqasm/test_api.py index aa932d2..575ef4b 100644 --- a/test/unit_tests/autoqasm/test_api.py +++ b/test/unit_tests/autoqasm/test_api.py @@ -1240,3 +1240,73 @@ def main(): h __qubits__[2]; h __qubits__[3];""" assert main.build().to_ir() == expected_ir + + +def test_subroutine_call_with_kwargs(): + """Test that subroutine call works with keyword arguments""" + + @aq.subroutine + def test(a: int, b: int) -> None: + aq.instructions.h(a) + aq.instructions.h(b) + + @aq.main(num_qubits=2) + def main(): + test(a=0, b=1) + + expected = """OPENQASM 3.0; +def test(int[32] a, int[32] b) { + h __qubits__[a]; + h __qubits__[b]; +} +qubit[2] __qubits__; +test(0, 1);""" + assert main.build().to_ir() == expected + + +def test_subroutine_call_with_one_arg_one_kwarg(): + """ + Test that subroutine call works with one positional arg and + one keyword argument + """ + + @aq.subroutine + def test(a: int, b: int) -> None: + aq.instructions.h(a) + aq.instructions.h(b) + + @aq.main(num_qubits=2) + def main(): + test(0, b=1) + + expected = """OPENQASM 3.0; +def test(int[32] a, int[32] b) { + h __qubits__[a]; + h __qubits__[b]; +} +qubit[2] __qubits__; +test(0, 1);""" + + assert main.build().to_ir() == expected + + +def test_subroutine_call_with_kwargs_in_any_order(): + """Test that subroutine calls work with keyword arguments placed in any order""" + + @aq.subroutine + def test(a: int, b: int) -> None: + aq.instructions.h(a) + aq.instructions.h(b) + + @aq.main(num_qubits=2) + def main(): + test(b=1, a=0) + + expected = """OPENQASM 3.0; +def test(int[32] a, int[32] b) { + h __qubits__[a]; + h __qubits__[b]; +} +qubit[2] __qubits__; +test(0, 1);""" + assert main.build().to_ir() == expected