From 79fff0eb57e60b9e66fc85ce0e1b0c0ac443cbf4 Mon Sep 17 00:00:00 2001 From: Chris Saunders Date: Thu, 31 Oct 2024 01:05:38 +1000 Subject: [PATCH] Workaround for Do not promote FP8 error Make sure both tensors are on the same device and cast to FP32 to avoid the do not promote FP8 error and allow FP8 models to work --- modules/models/sd35/mmditx.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/modules/models/sd35/mmditx.py b/modules/models/sd35/mmditx.py index d558bc333..b1b043a2f 100644 --- a/modules/models/sd35/mmditx.py +++ b/modules/models/sd35/mmditx.py @@ -904,7 +904,10 @@ def forward( hw = x.shape[-2:] # The line below should be unnecessary when full integrated. x = x[:1,:16,:,:] - x = self.x_embedder(x) + self.cropped_pos_embed(hw).to("cuda") + # Workaround for unable to promote FP8 error with FP8 models + x_embed = self.x_embedder(x).to(torch.float32) + pos_embed = self.cropped_pos_embed(hw).to(torch.float32).to("cuda") + x = x_embed + pos_embed c = self.t_embedder(t, dtype=x.dtype) # (N, D) if y is not None: y = self.y_embedder(y) # (N, D)