diff --git a/imaginairy/api.py b/imaginairy/api.py index 0aaafa82..68f8c931 100755 --- a/imaginairy/api.py +++ b/imaginairy/api.py @@ -23,7 +23,7 @@ from imaginairy.utils import ( fix_torch_nn_layer_norm, get_device, - img_path_to_torch_image, + img_path_or_url_to_torch_image, instantiate_from_config, ) @@ -204,7 +204,7 @@ def imagine( ddim_steps = int(prompt.steps / generation_strength) sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=ddim_eta) - init_image, w, h = img_path_to_torch_image(prompt.init_image) + init_image, w, h = img_path_or_url_to_torch_image(prompt.init_image) init_image = init_image.to(get_device()) init_latent = model.get_first_stage_encoding( model.encode_first_stage(init_image) diff --git a/imaginairy/utils.py b/imaginairy/utils.py index 7de7accf..6f00f10f 100644 --- a/imaginairy/utils.py +++ b/imaginairy/utils.py @@ -2,6 +2,7 @@ import logging import os.path import platform +import urllib.parse from contextlib import contextmanager from functools import lru_cache from typing import List, Optional @@ -99,10 +100,13 @@ def fix_torch_nn_layer_norm(): finally: functional.layer_norm = orig_function - -def img_path_to_torch_image(path, max_height=512, max_width=512): - image = Image.open(path).convert("RGB") - logger.info(f"Loaded input 🖼 of size {image.size} from {path}") +def img_path_or_url_to_torch_image(path_or_url, max_height=512, max_width=512): + is_url = urllib.parse.urlparse(path_or_url).scheme in ('http', 'https',) + if (is_url): + image = Image.open(requests.get(path_or_url, stream=True).raw).convert("RGB") + else: + image = Image.open(path_or_url).convert("RGB") + logger.info(f"Loaded input 🖼 of size {image.size} from {path_or_url}") return pillow_img_to_torch_image(image, max_height=max_height, max_width=max_width)