diff --git a/pyzx/graph/base.py b/pyzx/graph/base.py index c27e1847..8f6cfec4 100644 --- a/pyzx/graph/base.py +++ b/pyzx/graph/base.py @@ -146,10 +146,11 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph': backend = type(self).backend g = Graph(backend = backend) g.track_phases = self.track_phases - g.scalar = self.scalar.copy() + g.scalar = self.scalar.copy(conjugate=adjoint) g.merge_vdata = self.merge_vdata mult:int = 1 - if adjoint: mult = -1 + if adjoint: + mult = -1 #g.add_vertices(self.num_vertices()) ty = self.types() diff --git a/pyzx/graph/scalar.py b/pyzx/graph/scalar.py index 3d377e48..28a8f4eb 100644 --- a/pyzx/graph/scalar.py +++ b/pyzx/graph/scalar.py @@ -79,15 +79,27 @@ def __str__(self) -> str: def __complex__(self) -> complex: return self.to_number() - def copy(self) -> 'Scalar': + def copy(self, conjugate: bool = False) -> 'Scalar': + """Create a copy of the Scalar. If ``conjugate`` is set, the copy will be complex conjugated. + + Args: + conjugate: set to True to return a complex-conjugated copy + + Returns: + A copy of the Scalar + """ s = Scalar() s.power2 = self.power2 - s.phase = self.phase - s.phasenodes = copy.copy(self.phasenodes) - s.floatfactor = self.floatfactor + s.phase = self.phase if not conjugate else -self.phase + s.phasenodes = copy.copy(self.phasenodes) if not conjugate else [-p for p in self.phasenodes] + s.floatfactor = self.floatfactor if not conjugate else self.floatfactor.conjugate() s.is_unknown = self.is_unknown s.is_zero = self.is_zero return s + + def conjugate(self) -> 'Scalar': + """Returns a new Scalar equal to the complex conjugate""" + return self.copy(conjugate=True) def to_number(self) -> complex: if self.is_zero: return 0 diff --git a/tests/test_graph.py b/tests/test_graph.py index a41fd8a0..b17e59e0 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -30,6 +30,7 @@ import numpy as np from pyzx.tensor import compare_tensors +from pyzx.graph.scalar import Scalar @@ -130,6 +131,17 @@ def test_copy(self): self.assertEqual(g.num_edges(),g2.num_edges()) v1, v2 = list(g2.vertices()) self.assertEqual(g.edge_type(g.edge(v1,v2)),EdgeType.HADAMARD) + + def test_adjoint_scalar(self): + g = Graph() + scalar = Scalar() + scalar.phase = Fraction(1, 4) + scalar.phasenodes = [Fraction(1, 2), Fraction(1, 4)] + scalar.floatfactor = 1.3 + 0.1*1j + scalar.power2 = 2 + g.scalar = scalar + g_adj = g.adjoint() + self.assertAlmostEqual(g_adj.scalar.to_number(), scalar.to_number().conjugate()) @unittest.skipUnless(np, "numpy needs to be installed for this to run") def test_remove_isolated_vertex_preserves_semantics(self): @@ -176,6 +188,7 @@ def setUp(self): g.set_outputs((o1, o2)) g.add_edges([(i1,v),(i2,w),(v,w),(v,o1),(w,o2)]) self.i1, self.i2, self.v, self.w, self.o1, self.o2 = i1, i2, v, w, o1, o2 + print(self.graph.scalar.to_number()) def test_qubit_index_and_depth(self): g = self.graph diff --git a/tests/test_scalar.py b/tests/test_scalar.py new file mode 100644 index 00000000..49433687 --- /dev/null +++ b/tests/test_scalar.py @@ -0,0 +1,78 @@ +# PyZX - Python library for quantum circuit rewriting +# and optimization using the ZX-calculus +# Copyright (C) 2018 - Aleks Kissinger and John van de Wetering + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import sys +import numpy as np +from fractions import Fraction +from pyzx.graph.scalar import Scalar + +if __name__ == '__main__': + sys.path.append('..') + sys.path.append('.') + + +class TestScalar(unittest.TestCase): + + def test_initialization(self): + scalar = Scalar() + self.assertEqual(scalar.power2, 0) + self.assertEqual(scalar.phase, Fraction(0)) + self.assertEqual(scalar.phasenodes, []) + self.assertEqual(scalar.floatfactor, 1.0) + self.assertFalse(scalar.is_unknown) + self.assertFalse(scalar.is_zero) + + def test_copy(self): + scalar = Scalar() + scalar.power2 = 2 + scalar.phase = Fraction(1, 3) + scalar.phasenodes = [Fraction(1, 2)] + scalar.floatfactor = 2.0 + copied_scalar = scalar.copy() + self.assertEqual(copied_scalar.power2, 2) + self.assertEqual(copied_scalar.phase, Fraction(1, 3)) + self.assertEqual(copied_scalar.phasenodes, [Fraction(1, 2)]) + self.assertEqual(copied_scalar.floatfactor, 2.0) + self.assertFalse(copied_scalar.is_unknown) + self.assertFalse(copied_scalar.is_zero) + + def test_conjugate(self): + scalar = Scalar() + scalar.phase = Fraction(1, 3) + scalar.phasenodes = [Fraction(1, 2), Fraction(3, 4)] + scalar.floatfactor = 2.0 + 1.0j + + conjugated_scalar = scalar.conjugate() + self.assertEqual(conjugated_scalar.phase, -Fraction(1, 3)) + self.assertEqual(conjugated_scalar.phasenodes, [-Fraction(1, 2), -Fraction(3, 4)]) + self.assertEqual(conjugated_scalar.floatfactor, 2.0 - 1.0j) + self.assertAlmostEqual(conjugated_scalar.to_number(), scalar.to_number().conjugate()) + + def test_to_number(self): + scalar = Scalar() + scalar.phase = Fraction(1, 4) + scalar.phasenodes = [Fraction(1, 2)] + scalar.floatfactor = 2.0 + scalar.power2 = 2 + number = scalar.to_number() + expected_number = (np.exp(1j * np.pi * 0.25) * (1 + np.exp(1j * np.pi * 0.5)) * (2 ** 0.5) ** 2) * 2.0 + self.assertAlmostEqual(number.real, expected_number.real, places=5) + self.assertAlmostEqual(number.imag, expected_number.imag, places=5) + + +if __name__ == '__main__': + unittest.main()