diff --git a/src/qibo/models/variational.py b/src/qibo/models/variational.py index 8fabe5c46c..16952200b1 100644 --- a/src/qibo/models/variational.py +++ b/src/qibo/models/variational.py @@ -341,7 +341,11 @@ def execute(self, initial_state=None): Returns: State vector after applying the QAOA exponential gates. """ - state = self.get_initial_state(initial_state) + if initial_state is None: + state = self.hamiltonian.backend.plus_state(self.nqubits) + else: + state = self.hamiltonian.backend.cast(initial_state) + self.calculate_callbacks(state) n = int(self.params.shape[0]) for i in range(n // 2): @@ -355,19 +359,6 @@ def __call__(self, initial_state=None): """Equivalent to :meth:`qibo.models.QAOA.execute`.""" return self.execute(initial_state) - def get_initial_state(self, state=None): - """""" - #TODO: update this - # if self.accelerators is not None: - # c = self.hamiltonian.circuit(self.params[0]) - # if state is None: - # state = self.states.DistributedState.plus_state(c) - # return c.get_initial_state(state) - - if state is None: - return self.hamiltonian.backend.plus_state(self.nqubits) - return self.hamiltonian.backend.cast(state) - def minimize(self, initial_p, initial_state=None, method='Powell', jac=None, hess=None, hessp=None, bounds=None, constraints=(), tol=None, callback=None, options=None, compile=False, processes=None): diff --git a/src/qibo/tests/test_models_variational.py b/src/qibo/tests/test_models_variational.py index 8ebcea5454..06929f005d 100644 --- a/src/qibo/tests/test_models_variational.py +++ b/src/qibo/tests/test_models_variational.py @@ -126,15 +126,6 @@ def test_vqe(backend, method, options, compile, filename, skip_parallel): assert_regression_fixture(backend, params, filename) -def test_initial_state(backend, accelerators): - h = hamiltonians.TFIM(5, h=1.0, dense=False, backend=backend) - qaoa = models.QAOA(h, accelerators=accelerators) - qaoa.set_parameters(np.random.random(4)) - target_state = np.ones(2 ** 5) / np.sqrt(2 ** 5) - final_state = qaoa.get_initial_state() - backend.assert_allclose(final_state, target_state) - - @pytest.mark.parametrize("solver,dense", [("exp", False), ("exp", True), ("rk4", False), ("rk4", True), @@ -177,11 +168,7 @@ def test_qaoa_callbacks(backend, accelerators): from qibo import callbacks # use ``Y`` Hamiltonian so that there are no errors # in the Trotter decomposition - if accelerators: - with backend.on_cpu(): - h = hamiltonians.Y(5, backend=backend) - else: - h = hamiltonians.Y(5, backend=backend) + h = hamiltonians.Y(5, backend=backend) energy = callbacks.Energy(h) params = 0.1 * np.random.random(4) state = random_state(5)