Skip to content

Commit

Permalink
Merge pull request #230 from rafaelha/master
Browse files Browse the repository at this point in the history
Fix bug where scalar phase was not conjugated by BaseGraph.adjoint()
  • Loading branch information
jvdwetering authored May 31, 2024
2 parents efd400c + 944a0b6 commit 851fdf2
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 6 deletions.
5 changes: 3 additions & 2 deletions pyzx/graph/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,10 +146,11 @@ def copy(self, adjoint:bool=False, backend:Optional[str]=None) -> 'BaseGraph':
backend = type(self).backend
g = Graph(backend = backend)
g.track_phases = self.track_phases
g.scalar = self.scalar.copy()
g.scalar = self.scalar.copy(conjugate=adjoint)
g.merge_vdata = self.merge_vdata
mult:int = 1
if adjoint: mult = -1
if adjoint:
mult = -1

#g.add_vertices(self.num_vertices())
ty = self.types()
Expand Down
20 changes: 16 additions & 4 deletions pyzx/graph/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,27 @@ def __str__(self) -> str:
def __complex__(self) -> complex:
return self.to_number()

def copy(self) -> 'Scalar':
def copy(self, conjugate: bool = False) -> 'Scalar':
"""Create a copy of the Scalar. If ``conjugate`` is set, the copy will be complex conjugated.
Args:
conjugate: set to True to return a complex-conjugated copy
Returns:
A copy of the Scalar
"""
s = Scalar()
s.power2 = self.power2
s.phase = self.phase
s.phasenodes = copy.copy(self.phasenodes)
s.floatfactor = self.floatfactor
s.phase = self.phase if not conjugate else -self.phase
s.phasenodes = copy.copy(self.phasenodes) if not conjugate else [-p for p in self.phasenodes]
s.floatfactor = self.floatfactor if not conjugate else self.floatfactor.conjugate()
s.is_unknown = self.is_unknown
s.is_zero = self.is_zero
return s

def conjugate(self) -> 'Scalar':
"""Returns a new Scalar equal to the complex conjugate"""
return self.copy(conjugate=True)

def to_number(self) -> complex:
if self.is_zero: return 0
Expand Down
13 changes: 13 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import numpy as np
from pyzx.tensor import compare_tensors
from pyzx.graph.scalar import Scalar



Expand Down Expand Up @@ -130,6 +131,17 @@ def test_copy(self):
self.assertEqual(g.num_edges(),g2.num_edges())
v1, v2 = list(g2.vertices())
self.assertEqual(g.edge_type(g.edge(v1,v2)),EdgeType.HADAMARD)

def test_adjoint_scalar(self):
g = Graph()
scalar = Scalar()
scalar.phase = Fraction(1, 4)
scalar.phasenodes = [Fraction(1, 2), Fraction(1, 4)]
scalar.floatfactor = 1.3 + 0.1*1j
scalar.power2 = 2
g.scalar = scalar
g_adj = g.adjoint()
self.assertAlmostEqual(g_adj.scalar.to_number(), scalar.to_number().conjugate())

@unittest.skipUnless(np, "numpy needs to be installed for this to run")
def test_remove_isolated_vertex_preserves_semantics(self):
Expand Down Expand Up @@ -176,6 +188,7 @@ def setUp(self):
g.set_outputs((o1, o2))
g.add_edges([(i1,v),(i2,w),(v,w),(v,o1),(w,o2)])
self.i1, self.i2, self.v, self.w, self.o1, self.o2 = i1, i2, v, w, o1, o2
print(self.graph.scalar.to_number())

def test_qubit_index_and_depth(self):
g = self.graph
Expand Down
78 changes: 78 additions & 0 deletions tests/test_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# PyZX - Python library for quantum circuit rewriting
# and optimization using the ZX-calculus
# Copyright (C) 2018 - Aleks Kissinger and John van de Wetering

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import sys
import numpy as np
from fractions import Fraction
from pyzx.graph.scalar import Scalar

if __name__ == '__main__':
sys.path.append('..')
sys.path.append('.')


class TestScalar(unittest.TestCase):

def test_initialization(self):
scalar = Scalar()
self.assertEqual(scalar.power2, 0)
self.assertEqual(scalar.phase, Fraction(0))
self.assertEqual(scalar.phasenodes, [])
self.assertEqual(scalar.floatfactor, 1.0)
self.assertFalse(scalar.is_unknown)
self.assertFalse(scalar.is_zero)

def test_copy(self):
scalar = Scalar()
scalar.power2 = 2
scalar.phase = Fraction(1, 3)
scalar.phasenodes = [Fraction(1, 2)]
scalar.floatfactor = 2.0
copied_scalar = scalar.copy()
self.assertEqual(copied_scalar.power2, 2)
self.assertEqual(copied_scalar.phase, Fraction(1, 3))
self.assertEqual(copied_scalar.phasenodes, [Fraction(1, 2)])
self.assertEqual(copied_scalar.floatfactor, 2.0)
self.assertFalse(copied_scalar.is_unknown)
self.assertFalse(copied_scalar.is_zero)

def test_conjugate(self):
scalar = Scalar()
scalar.phase = Fraction(1, 3)
scalar.phasenodes = [Fraction(1, 2), Fraction(3, 4)]
scalar.floatfactor = 2.0 + 1.0j

conjugated_scalar = scalar.conjugate()
self.assertEqual(conjugated_scalar.phase, -Fraction(1, 3))
self.assertEqual(conjugated_scalar.phasenodes, [-Fraction(1, 2), -Fraction(3, 4)])
self.assertEqual(conjugated_scalar.floatfactor, 2.0 - 1.0j)
self.assertAlmostEqual(conjugated_scalar.to_number(), scalar.to_number().conjugate())

def test_to_number(self):
scalar = Scalar()
scalar.phase = Fraction(1, 4)
scalar.phasenodes = [Fraction(1, 2)]
scalar.floatfactor = 2.0
scalar.power2 = 2
number = scalar.to_number()
expected_number = (np.exp(1j * np.pi * 0.25) * (1 + np.exp(1j * np.pi * 0.5)) * (2 ** 0.5) ** 2) * 2.0
self.assertAlmostEqual(number.real, expected_number.real, places=5)
self.assertAlmostEqual(number.imag, expected_number.imag, places=5)


if __name__ == '__main__':
unittest.main()

0 comments on commit 851fdf2

Please sign in to comment.