From e2b190f39d45b5e4828b0f0a1429fa68f733d7bc Mon Sep 17 00:00:00 2001 From: nikhilkhatri Date: Sun, 24 Mar 2024 12:49:29 +0000 Subject: [PATCH 1/3] Add support for fixed params in GeneralEncoder --- torchquantum/encoding/encodings.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index f8d2056d..e73c526e 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -90,7 +90,11 @@ def __init__(self, func_list): def forward(self, qdev: tq.QuantumDevice, x): for info in self.func_list: if tq.op_name_dict[info["func"]].num_params > 0: - params = x[:, info["input_idx"]] + # If params are provided in encoder, use those, + # else use params from x + params = (torch.Tensor(info["params"]).repeat(x.shape[0], 1) + if info.get("params") + else x[:, info["input_idx"]]) else: params = None func_name_dict[info["func"]]( From a5d47d2744147d39e2586f536b1ecaeaf757545a Mon Sep 17 00:00:00 2001 From: nikhilkhatri Date: Mon, 25 Mar 2024 09:45:32 +0000 Subject: [PATCH 2/3] Add pre-parameterised example to GeneralEncoder --- torchquantum/encoding/encodings.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/torchquantum/encoding/encodings.py b/torchquantum/encoding/encodings.py index e73c526e..03010ade 100644 --- a/torchquantum/encoding/encodings.py +++ b/torchquantum/encoding/encodings.py @@ -80,6 +80,18 @@ class GeneralEncoder(Encoder, metaclass=ABCMeta): {'input_idx': [12, 13, 14], 'func': 'u3', 'wires': [3]}, {'input_idx': [15], 'func': 'u1', 'wires': [3]}, ] + + Example 3: + [ + {'params': [0.25], 'func': 'rx', 'wires': [0]}, + {'params': [0.25], 'func': 'rx', 'wires': [1]}, + {'params': [0.25], 'func': 'rx', 'wires': [2]}, + {'params': [0.25], 'func': 'rx', 'wires': [3]}, + {'input_idx': [0], 'func': 'ry', 'wires': [0]}, + {'input_idx': [1], 'func': 'ry', 'wires': [1]}, + {'input_idx': [2], 'func': 'ry', 'wires': [2]}, + {'input_idx': [3], 'func': 'ry', 'wires': [3]} + ] """ def __init__(self, func_list): From 69f06998273248dee6ffb4a6246ccf506bbd61a7 Mon Sep 17 00:00:00 2001 From: nikhilkhatri Date: Sun, 31 Mar 2024 12:45:56 +0100 Subject: [PATCH 3/3] Add unit tests for GeneralEncoder --- test/encoding/test_encodings.py | 81 +++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 test/encoding/test_encodings.py diff --git a/test/encoding/test_encodings.py b/test/encoding/test_encodings.py new file mode 100644 index 00000000..9ee64910 --- /dev/null +++ b/test/encoding/test_encodings.py @@ -0,0 +1,81 @@ +""" +MIT License + +Copyright (c) 2020-present TorchQuantum Authors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +# test the controlled unitary function + + +import torchquantum as tq +import torch +from test.utils import check_all_close + + +def test_GeneralEncoder(): + + parameterised_funclist = [ + {"input_idx": [0], "func": "crx", "wires": [1, 0]}, + {"input_idx": [1, 2, 3], "func": "u3", "wires": [1]}, + {"input_idx": [4], "func": "ry", "wires": [0]}, + {"input_idx": [5], "func": "ry", "wires": [1]}, + ] + + semiparam_funclist = [ + {"params": [0.2], "func": "crx", "wires": [1, 0]}, + {"params": [0.3, 0.4, 0.5], "func": "u3", "wires": [1]}, + {"input_idx": [0], "func": "ry", "wires": [0]}, + {"input_idx": [1], "func": "ry", "wires": [1]}, + ] + + expected_states = torch.complex( + torch.Tensor( + [[0.8423, 0.4474, 0.2605, 0.1384], [0.7649, 0.5103, 0.3234, 0.2157]] + ), + torch.Tensor( + [[-0.0191, 0.0522, -0.0059, 0.0162], [-0.0233, 0.0483, -0.0099, 0.0204]] + ), + ) + + parameterised_enc = tq.GeneralEncoder(parameterised_funclist) + semiparam_enc = tq.GeneralEncoder(semiparam_funclist) + + param_vec = torch.Tensor( + [[0.2, 0.3, 0.4, 0.5, 0.6, 0.7], [0.2, 0.3, 0.4, 0.5, 0.8, 0.9]] + ) + semiparam_vec = torch.Tensor([[0.6, 0.7], [0.8, 0.9]]) + + qd = tq.QuantumDevice(n_wires=2) + + qd.reset_states(bsz=2) + parameterised_enc(qd, param_vec) + state1 = qd.get_states_1d() + + qd.reset_states(bsz=2) + semiparam_enc(qd, semiparam_vec) + state2 = qd.get_states_1d() + + check_all_close(state1, state2) + check_all_close(state1, expected_states) + + +if __name__ == "__main__": + test_GeneralEncoder()