diff --git a/invokeai/backend/flux/math.py b/invokeai/backend/flux/math.py index 84e7c79be19..57ff8259932 100644 --- a/invokeai/backend/flux/math.py +++ b/invokeai/backend/flux/math.py @@ -32,4 +32,4 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso xk_ = xk.view(*xk.shape[:-1], -1, 1, 2) xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1] xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1] - return xq_out.view(*xq.shape), xk_out.view(*xk.shape) + return xq_out.view(*xq.shape).type_as(xq), xk_out.view(*xk.shape).type_as(xk)