Skip to content

Commit

Permalink
feat: enable initializing NativeGates from list[str]
Browse files Browse the repository at this point in the history
  • Loading branch information
changsookim committed Oct 16, 2024
1 parent 6c77fde commit e913307
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
7 changes: 1 addition & 6 deletions src/qibo/backends/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
from functools import reduce
from importlib import import_module
from operator import or_

import networkx as nx
import numpy as np
Expand Down Expand Up @@ -146,9 +144,6 @@ def _default_transpiler(cls):
and natives is not None
and connectivity_edges is not None
):
# natives_enum = NativeGates.from_gatelist(natives)
natives_enum = reduce(or_, [NativeGates[nat] for nat in natives])

# only for q{i} naming
node_mapping = {q: i for i, q in enumerate(qubits)}
edges = [
Expand All @@ -162,7 +157,7 @@ def _default_transpiler(cls):
Preprocessing(connectivity),
Trivial(connectivity),
Sabre(connectivity),
Unroller(natives_enum),
Unroller(NativeGates[natives]),
],
)

Expand Down
13 changes: 11 additions & 2 deletions src/qibo/transpiler/unroller.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from enum import Flag, auto
from enum import EnumMeta, Flag, auto
from functools import reduce
from operator import or_

from qibo import gates
from qibo.backends import _check_backend
Expand All @@ -15,7 +17,14 @@
)


class NativeGates(Flag):
class FlagMeta(EnumMeta):
def __getitem__(self, keys):
if isinstance(keys, str):
return super().__getitem__(keys)
return reduce(or_, [self[key] for key in keys])


class NativeGates(Flag, metaclass=FlagMeta):
"""Define native gates supported by the unroller. A native gate set should contain at least
one two-qubit gate (:class:`qibo.gates.gates.CZ` or :class:`qibo.gates.gates.iSWAP`),
and at least one single-qubit gate
Expand Down
11 changes: 8 additions & 3 deletions tests/test_transpiler_unroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@ def test_native_gates_from_gatelist_fail():
NativeGates.from_gatelist([gates.RZ, gates.X(0)])


def test_native_gate_str_from_gatelist_fail():
with pytest.raises(ValueError):
NativeGates.from_gatelist(["qibo"])
def test_native_gate_str_list():
testlist = ["I", "Z", "RZ", "M", "GPI2", "U3", "CZ", "iSWAP", "CNOT"]
natives = NativeGates[testlist]
for gate in testlist:
assert NativeGates[gate] in natives

with pytest.raises(KeyError):
NativeGates[["qi", "bo"]] # Invalid gate names


def test_translate_gate_error_1q():
Expand Down

0 comments on commit e913307

Please sign in to comment.