diff --git a/cldm/ddim_hacked.py b/cldm/ddim_hacked.py index 25b1bc9472..7472e46578 100644 --- a/cldm/ddim_hacked.py +++ b/cldm/ddim_hacked.py @@ -87,6 +87,7 @@ def sample(self, elif isinstance(conditioning, list): for ctmp in conditioning: + cbs = ctmp.shape[0] if ctmp.shape[0] != batch_size: print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")