Skip to content

Commit

Permalink
Fix loss casting for cma
Browse files Browse the repository at this point in the history
  • Loading branch information
stavros11 committed Jun 10, 2021
1 parent d09f8f8 commit bc0fce5
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions src/qibo/models/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,12 @@ def _loss(params, circuit, hamiltonian):
for gate in self.circuit.queue:
_ = gate.cache
loss = K.compile(_loss)
elif method != "sgd":
dtype = getattr(K.np, K._dtypes.get("DTYPE"))
loss = lambda p, c, h: dtype(K.to_numpy(_loss(p, c, h)))
else:
loss = _loss

if method != "sgd":
loss = lambda p, c, h: K.to_numpy(_loss(p, c, h))

result, parameters, extra = self.optimizers.optimize(loss, initial_state,
args=(self.circuit, self.hamiltonian),
method=method, jac=jac, hess=hess, hessp=hessp,
Expand Down

0 comments on commit bc0fce5

Please sign in to comment.