diff --git a/src/qibo/models/variational.py b/src/qibo/models/variational.py index de6cfbdc65..1fa040780b 100644 --- a/src/qibo/models/variational.py +++ b/src/qibo/models/variational.py @@ -78,12 +78,12 @@ def _loss(params, circuit, hamiltonian): for gate in self.circuit.queue: _ = gate.cache loss = K.compile(_loss) + elif method != "sgd": + dtype = getattr(K.np, K._dtypes.get("DTYPE")) + loss = lambda p, c, h: dtype(K.to_numpy(_loss(p, c, h))) else: loss = _loss - if method != "sgd": - loss = lambda p, c, h: K.to_numpy(_loss(p, c, h)) - result, parameters, extra = self.optimizers.optimize(loss, initial_state, args=(self.circuit, self.hamiltonian), method=method, jac=jac, hess=hess, hessp=hessp,