From d310f7811dc17de38f5df9b48dfa8252a0f9241d Mon Sep 17 00:00:00 2001 From: Charles MOUSSA Date: Tue, 10 Dec 2024 09:01:12 +0100 Subject: [PATCH] add current implementation of channel apply --- horqrux/__init__.py | 2 +- horqrux/api.py | 7 ++++--- horqrux/apply.py | 27 ++++++++++++++++++++------- horqrux/shots.py | 21 +++++++++++++++------ 4 files changed, 40 insertions(+), 17 deletions(-) diff --git a/horqrux/__init__.py b/horqrux/__init__.py index 513c031..2569a54 100644 --- a/horqrux/__init__.py +++ b/horqrux/__init__.py @@ -1,6 +1,6 @@ from __future__ import annotations -from .api import expectation +from .api import expectation, run from .apply import apply_gate, apply_operator from .circuit import QuantumCircuit, sample from .parametric import PHASE, RX, RY, RZ diff --git a/horqrux/api.py b/horqrux/api.py index f4f7382..76990b6 100644 --- a/horqrux/api.py +++ b/horqrux/api.py @@ -19,8 +19,9 @@ def run( circuit: GateSequence, state: Array, values: dict[str, float] = dict(), + is_state_densitymat: bool = False, ) -> Array: - return apply_gate(state, circuit, values) + return apply_gate(state, circuit, values, is_state_densitymat=is_state_densitymat) def sample( @@ -121,8 +122,8 @@ def expectation( ) # Type checking is disabled because mypy doesn't parse checkify.check. # type: ignore - if is_state_densitymat: - raise NotImplementedError("Expectation from density matrices is not yet supported!") + # if is_state_densitymat: + # raise NotImplementedError("Expectation with density matrices is not yet supported!") return finite_shots_fwd( state, gates, diff --git a/horqrux/apply.py b/horqrux/apply.py index dd380f5..9e8de9f 100644 --- a/horqrux/apply.py +++ b/horqrux/apply.py @@ -20,6 +20,7 @@ def apply_operator( operator: Array, target: Tuple[int, ...], control: Tuple[int | None, ...], + is_state_densitymat: bool = False, ) -> State: """Applies an operator, i.e. a single array of shape [2, 2, ...], on a given state of shape [2 for _ in range(n_qubits)] for a given set of target and control qubits. @@ -37,6 +38,7 @@ def apply_operator( operator: Array to contract over 'state'. target: Tuple of target qubits on which to apply the 'operator' to. control: Tuple of control qubits. + is_state_densitymat: Whether the state is provided as a density matrix. Returns: State after applying 'operator'. @@ -45,12 +47,21 @@ def apply_operator( if is_controlled(control): operator = _controlled(operator, len(control)) state_dims = (*control, *target) # type: ignore[arg-type] - n_qubits = int(np.log2(operator.shape[1])) - operator = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits))) - op_dims = tuple(np.arange(operator.ndim // 2, operator.ndim, dtype=int)) - state = jnp.tensordot(a=operator, b=state, axes=(op_dims, state_dims)) + n_qubits_op = int(np.log2(operator.shape[1])) + operator_reshaped = operator.reshape(tuple(2 for _ in np.arange(2 * n_qubits_op))) + op_dims = tuple(np.arange(operator_reshaped.ndim // 2, operator_reshaped.ndim, dtype=int)) + # Apply operator + new_state_dims = tuple(i for i in range(len(state_dims))) - return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + state = jnp.tensordot(a=operator_reshaped, b=state, axes=(op_dims, state_dims)) + if not is_state_densitymat: + return jnp.moveaxis(a=state, source=new_state_dims, destination=state_dims) + operator_dagger = _dagger(operator_reshaped) + + # Apply operator to density matrix: ρ' = O ρ O† + + state = jnp.tensordot(a=operator_dagger, b=state, axes=(op_dims, state_dims)) + return state def apply_kraus_operator( @@ -81,8 +92,9 @@ def apply_operator_with_noise( target: Tuple[int, ...], control: Tuple[int | None, ...], noise: NoiseProtocol, + is_state_densitymat: bool = False, ) -> State: - state_gate = apply_operator(state, operator, target, control) + state_gate = apply_operator(state, operator, target, control, is_state_densitymat) if len(noise) == 0: return state_gate else: @@ -188,10 +200,11 @@ def apply_gate( has_noise = len(reduce(add, noise)) > 0 if has_noise and not is_state_densitymat: state = density_mat(state) + is_state_densitymat = True output_state = reduce( lambda state, gate: apply_operator_with_noise(state, *gate), - zip(operator, target, control, noise), + zip(operator, target, control, noise, (is_state_densitymat,) * len(target)), state, ) diff --git a/horqrux/shots.py b/horqrux/shots.py index 9069425..614dfd8 100644 --- a/horqrux/shots.py +++ b/horqrux/shots.py @@ -49,12 +49,21 @@ def finite_shots_fwd( """ state = apply_gate(state, gates, values, is_state_densitymat=is_state_densitymat) n_qubits = len(state.shape) - mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] - eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] - eigvecs, eigvals = align_eigenvectors(eigs) - inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) - probs = jnp.abs(inner_prod) ** 2 - return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) + if not is_state_densitymat: + mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] + eigs = [jnp.linalg.eigh(mat) for mat in mat_obs] + eigvecs, eigvals = align_eigenvectors(eigs) + inner_prod = jnp.matmul(jnp.conjugate(eigvecs.T), state.flatten()) + probs = jnp.abs(inner_prod) ** 2 + return jax.random.choice(key=key, a=eigvals, p=probs, shape=(n_shots,)).mean(axis=0) + else: + n_qubits = n_qubits // 2 + mat_obs = [observable_to_matrix(observable, n_qubits) for observable in observables] + mat_obs = jnp.stack(mat_obs) + dim = 2**n_qubits + rho = state.reshape((dim, dim)) + prod = jnp.matmul(mat_obs, rho) + return jnp.trace(prod, axis1=-2, axis2=-1).real def align_eigenvectors(eigs: list[tuple[Array, Array]]) -> tuple[Array, Array]: