Skip to content

Commit

Permalink
Fix loss casting
Browse files Browse the repository at this point in the history
  • Loading branch information
stavros11 committed Jun 10, 2021
1 parent bc0fce5 commit 2088340
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/qibo/models/variational.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,11 @@ def _loss(params, circuit, hamiltonian):
for gate in self.circuit.queue:
_ = gate.cache
loss = K.compile(_loss)
elif method != "sgd":
elif method == "cma":
dtype = getattr(K.np, K._dtypes.get("DTYPE"))
loss = lambda p, c, h: dtype(K.to_numpy(_loss(p, c, h)))
elif method != "sgd":
loss = lambda p, c, h: K.to_numpy(_loss(p, c, h))
else:
loss = _loss

Expand Down

0 comments on commit 2088340

Please sign in to comment.