
Commit

Merge pull request #1433 from millie-v/sample-image-without-cuda
Generate sample images without having CUDA (such as on Macs)
kohya-ss authored Sep 7, 2024
2 parents d5c076c + 2e67978 commit 319e4d9
Showing 1 changed file with 8 additions and 5 deletions.
library/train_util.py (13 changes: 8 additions & 5 deletions)
@@ -5404,7 +5404,7 @@ def sample_images_common(
     clean_memory_on_device(accelerator.device)

     torch.set_rng_state(rng_state)
-    if cuda_rng_state is not None:
+    if torch.cuda.is_available() and cuda_rng_state is not None:
         torch.cuda.set_rng_state(cuda_rng_state)
     vae.to(org_vae_device)

@@ -5438,11 +5438,13 @@ def sample_image_inference(

     if seed is not None:
         torch.manual_seed(seed)
-        torch.cuda.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed(seed)
     else:
         # True random sample image generation
         torch.seed()
-        torch.cuda.seed()
+        if torch.cuda.is_available():
+            torch.cuda.seed()

     scheduler = get_my_scheduler(
         sample_sampler=sampler_name,
@@ -5477,8 +5479,9 @@ def sample_image_inference(
             controlnet_image=controlnet_image,
         )

-    with torch.cuda.device(torch.cuda.current_device()):
-        torch.cuda.empty_cache()
+    if torch.cuda.is_available():
+        with torch.cuda.device(torch.cuda.current_device()):
+            torch.cuda.empty_cache()

     image = pipeline.latents_to_image(latents)[0]

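The pattern applied throughout this change is to guard every CUDA-specific call behind torch.cuda.is_available(), so the sampling path also runs on CPU-only or MPS-only machines such as Macs. Below is a minimal standalone sketch of that pattern, not part of the commit; the helper names seed_rngs and free_cuda_cache are hypothetical.

import torch


def seed_rngs(seed=None):
    # Seed the CPU RNG unconditionally; touch the CUDA RNG only when CUDA exists.
    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
    else:
        # Nondeterministic seeding, mirroring the "true random" branch in the diff.
        torch.seed()
        if torch.cuda.is_available():
            torch.cuda.seed()


def free_cuda_cache():
    # Releasing cached CUDA memory is only meaningful when a CUDA device is present.
    if torch.cuda.is_available():
        with torch.cuda.device(torch.cuda.current_device()):
            torch.cuda.empty_cache()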
