feat: allow selecting from multiple GPUs to run (#17)
* feat: added --gpu option at command line
* docs: added README
tnwei authored Mar 1, 2022
1 parent bbc50aa commit 82aff79
Showing 5 changed files with 81 additions and 3 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -47,6 +47,16 @@ In the web app, select settings on the sidebar, key in the text prompt, and clic

A one-time download of additional pre-trained weights will occur before generating the first image. This might take a few minutes depending on your internet connection.

If you have multiple GPUs, specify the GPU you want to use by adding `-- --gpu X`. An extra double dash is required to [bypass Streamlit argument parsing](https://github.com/streamlit/streamlit/issues/337). Example commands:

```bash
# Use 2nd GPU
streamlit run app.py -- --gpu 1

# Use 3rd GPU
streamlit run diffusion_app.py -- --gpu 2
```
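
If `--gpu` is not passed, the apps fall back to the default device selection (`cuda:0` when a GPU is available, otherwise CPU). A minimal sketch of that case, assuming the standard launch command:

```bash
# No --gpu flag: the app defaults to cuda:0 if available, otherwise CPU
streamlit run app.py
```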

See: [tips and tricks](docs/tips-n-tricks.md)

## Output and gallery viewer
30 changes: 30 additions & 0 deletions app.py
@@ -10,9 +10,13 @@
import sys
import datetime
import shutil
import torch
import json
import os
import base64
import traceback

import argparse

sys.path.append("./taming-transformers")

@@ -61,6 +65,7 @@ def generate_image(
zoom_factor: float = 1,
transform_interval: int = 10,
use_cutout_augmentations: bool = True,
device: Optional[torch.device] = None,
) -> None:

### Init -------------------------------------------------------------------
@@ -85,6 +90,7 @@ def generate_image(
zoom_factor=zoom_factor,
transform_interval=transform_interval,
use_cutout_augmentations=use_cutout_augmentations,
device=device,
)

### Load model -------------------------------------------------------------
@@ -323,6 +329,29 @@ def generate_image(


if __name__ == "__main__":

# Argparse to capture GPU num
parser = argparse.ArgumentParser()

parser.add_argument(
"--gpu", type=str, default=None, help="Specify GPU number. Defaults to None."
)
args = parser.parse_args()

# Select specific GPU if chosen
if args.gpu is not None:
for i in args.gpu.split(","):
assert (
int(i) < torch.cuda.device_count()
), f"You specified --gpu {args.gpu} but torch.cuda.device_count() returned {torch.cuda.device_count()}"

try:
device = torch.device(f"cuda:{args.gpu}")
except RuntimeError:
print(traceback.format_exc())
else:
device = None

defaults = OmegaConf.load("defaults.yaml")
outputdir = Path("output")
if not outputdir.exists():
@@ -629,6 +658,7 @@ def generate_image(
zoom_factor=zoom_factor,
transform_interval=transform_interval,
use_cutout_augmentations=use_cutout_augmentations,
device=device,
)

vid_display_slot.video("temp.mp4")
28 changes: 28 additions & 0 deletions diffusion_app.py
@@ -5,9 +5,12 @@
import shutil
import json
import os
import torch
import traceback
import base64
from PIL import Image
from typing import Optional
import argparse

sys.path.append("./taming-transformers")

Expand All @@ -31,6 +34,7 @@ def generate_image(
init_image: Optional[Image.Image] = None,
skip_timesteps: int = 0,
use_cutout_augmentations: bool = False,
device: Optional[torch.device] = None,
) -> None:

### Init -------------------------------------------------------------------
@@ -42,6 +46,7 @@ def generate_image(
continue_prev_run=continue_prev_run,
skip_timesteps=skip_timesteps,
use_cutout_augmentations=use_cutout_augmentations,
device=device,
)

# Generate random run ID
@@ -228,6 +233,28 @@ def generate_image(


if __name__ == "__main__":
# Argparse to capture GPU num
parser = argparse.ArgumentParser()

parser.add_argument(
"--gpu", type=str, default=None, help="Specify GPU number. Defaults to None."
)
args = parser.parse_args()

# Select specific GPU if chosen
if args.gpu is not None:
for i in args.gpu.split(","):
assert (
int(i) < torch.cuda.device_count()
), f"You specified --gpu {args.gpu} but torch.cuda.device_count() returned {torch.cuda.device_count()}"

try:
device = torch.device(f"cuda:{args.gpu}")
except RuntimeError:
print(traceback.format_exc())
else:
device = None

outputdir = Path("output")
if not outputdir.exists():
outputdir.mkdir()
@@ -398,6 +425,7 @@ def generate_image(
init_image=reference_image,
skip_timesteps=skip_timesteps,
use_cutout_augmentations=use_cutout_augmentations,
device=device,
)
vid_display_slot.video("temp.mp4")
# debug_slot.write(st.session_state) # DEBUG
8 changes: 7 additions & 1 deletion diffusion_logic.py
@@ -8,6 +8,7 @@
import lpips
from PIL import Image
import kornia.augmentation as K
from typing import Optional

sys.path.append("./guided-diffusion")

@@ -95,6 +96,7 @@ def __init__(
continue_prev_run: bool = True,
skip_timesteps: int = 0,
use_cutout_augmentations: bool = False,
device: Optional[torch.device] = None,
) -> None:

assert ckpt in DIFFUSION_METHODS_AND_WEIGHTS.keys()
@@ -160,7 +162,11 @@

self.use_cutout_augmentations = use_cutout_augmentations

self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device is None:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
self.device = device

print("Using device:", self.device)

def load_model(
8 changes: 6 additions & 2 deletions logic.py
@@ -86,6 +86,7 @@ def __init__(
rotation_angle: float = 0,
zoom_factor: float = 1,
transform_interval: int = 10,
device: Optional[torch.device] = None,
) -> None:
super().__init__()
self.text_input = text_input
@@ -126,8 +127,11 @@ def __init__(
seed=seed,
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.device = device
if device is None:
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
else:
self.device = device

print("Using device:", device)

self.iterate_counter = 0
