diff --git a/annotator/hed/__init__.py b/annotator/hed/__init__.py
index a6a8fc712f..7279d0e792 100644
--- a/annotator/hed/__init__.py
+++ b/annotator/hed/__init__.py
@@ -12,6 +12,7 @@
 
 from einops import rearrange
 from annotator.util import annotator_ckpts_path
+import config
 
 
 class DoubleConvBlock(torch.nn.Module):
@@ -60,14 +61,14 @@ def __init__(self):
         if not os.path.exists(modelpath):
             from basicsr.utils.download_util import load_file_from_url
             load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
-        self.netNetwork = ControlNetHED_Apache2().float().cuda().eval()
+        self.netNetwork = ControlNetHED_Apache2().float().to(config.device).eval()
         self.netNetwork.load_state_dict(torch.load(modelpath))
 
     def __call__(self, input_image):
         assert input_image.ndim == 3
         H, W, C = input_image.shape
         with torch.no_grad():
-            image_hed = torch.from_numpy(input_image.copy()).float().cuda()
+            image_hed = torch.from_numpy(input_image.copy()).float().to(config.device)
             image_hed = rearrange(image_hed, 'h w c -> 1 c h w')
             edges = self.netNetwork(image_hed)
             edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges]
diff --git a/annotator/midas/__init__.py b/annotator/midas/__init__.py
index 36789767f3..5132bdf411 100644
--- a/annotator/midas/__init__.py
+++ b/annotator/midas/__init__.py
@@ -8,17 +8,18 @@
 
 from einops import rearrange
 from .api import MiDaSInference
+import config
 
 
 class MidasDetector:
     def __init__(self):
-        self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
+        self.model = MiDaSInference(model_type="dpt_hybrid").to(config.device)
 
     def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1):
         assert input_image.ndim == 3
         image_depth = input_image
         with torch.no_grad():
-            image_depth = torch.from_numpy(image_depth).float().cuda()
+            image_depth = torch.from_numpy(image_depth).float().to(config.device)
             image_depth = image_depth / 127.5 - 1.0
             image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
             depth = self.model(image_depth)[0]
diff --git a/annotator/mlsd/__init__.py b/annotator/mlsd/__init__.py
index c1860702df..34f18cc0a0 100644
--- a/annotator/mlsd/__init__.py
+++ b/annotator/mlsd/__init__.py
@@ -13,6 +13,7 @@
 from .utils import pred_lines
 
 from annotator.util import annotator_ckpts_path
+import config
 
 
 remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/mlsd_large_512_fp32.pth"
@@ -26,7 +27,7 @@ def __init__(self):
             load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
         model = MobileV2_MLSD_Large()
         model.load_state_dict(torch.load(model_path), strict=True)
-        self.model = model.cuda().eval()
+        self.model = model.to(config.device).eval()
 
     def __call__(self, input_image, thr_v, thr_d):
         assert input_image.ndim == 3
diff --git a/annotator/mlsd/utils.py b/annotator/mlsd/utils.py
index ae3cf9420a..3980de82e8 100644
--- a/annotator/mlsd/utils.py
+++ b/annotator/mlsd/utils.py
@@ -14,6 +14,7 @@
 import cv2
 import torch
 from torch.nn import functional as F
+import config
 
 
 def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5):
@@ -58,7 +59,7 @@ def pred_lines(image, model,
     batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
     batch_image = (batch_image / 127.5) - 1.0
 
-    batch_image = torch.from_numpy(batch_image).float().cuda()
+    batch_image = torch.from_numpy(batch_image).float().to(config.device)
     outputs = model(batch_image)
     pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
     start = vmap[:, :, :2]
@@ -109,7 +110,7 @@ def pred_squares(image,
     batch_image = np.expand_dims(resized_image, axis=0).astype('float32')
     batch_image = (batch_image / 127.5) - 1.0
 
-    batch_image = torch.from_numpy(batch_image).float().cuda()
+    batch_image = torch.from_numpy(batch_image).float().to(config.device)
     outputs = model(batch_image)
 
     pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3)
diff --git a/annotator/openpose/body.py b/annotator/openpose/body.py
index 7c3cf7a388..c0686b4028 100644
--- a/annotator/openpose/body.py
+++ b/annotator/openpose/body.py
@@ -10,13 +10,14 @@
 from . import util
 from .model import bodypose_model
+import config
 
 
 class Body(object):
     def __init__(self, model_path):
         self.model = bodypose_model()
         if torch.cuda.is_available():
-            self.model = self.model.cuda()
-            print('cuda')
+            self.model = self.model.to(config.device)
+
         model_dict = util.transfer(self.model, torch.load(model_path))
         self.model.load_state_dict(model_dict)
         self.model.eval()
@@ -42,7 +43,7 @@ def __call__(self, oriImg):
 
             data = torch.from_numpy(im).float()
             if torch.cuda.is_available():
-                data = data.cuda()
+                data = data.to(config.device)
             # data = data.permute([2, 0, 1]).unsqueeze(0).float()
             with torch.no_grad():
                 Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data)
diff --git a/annotator/openpose/hand.py b/annotator/openpose/hand.py
index 3d0bf17165..c8a0b7b041 100644
--- a/annotator/openpose/hand.py
+++ b/annotator/openpose/hand.py
@@ -11,13 +11,13 @@
 from .model import handpose_model
 from . import util
+import config
 
 
 class Hand(object):
     def __init__(self, model_path):
         self.model = handpose_model()
         if torch.cuda.is_available():
-            self.model = self.model.cuda()
-            print('cuda')
+            self.model = self.model.to(config.device)
         model_dict = util.transfer(self.model, torch.load(model_path))
         self.model.load_state_dict(model_dict)
         self.model.eval()
@@ -42,7 +42,7 @@ def __call__(self, oriImg):
 
             data = torch.from_numpy(im).float()
             if torch.cuda.is_available():
-                data = data.cuda()
+                data = data.to(config.device)
             # data = data.permute([2, 0, 1]).unsqueeze(0).float()
             with torch.no_grad():
                 output = self.model(data).cpu().numpy()
diff --git a/annotator/uniformer/__init__.py b/annotator/uniformer/__init__.py
index 3364d40997..9a0856a157 100644
--- a/annotator/uniformer/__init__.py
+++ b/annotator/uniformer/__init__.py
@@ -7,6 +7,7 @@
 from annotator.uniformer.mmseg.apis import init_segmentor, inference_segmentor, show_result_pyplot
 from annotator.uniformer.mmseg.core.evaluation import get_palette
 from annotator.util import annotator_ckpts_path
+import config
 
 
 checkpoint_file = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/upernet_global_small.pth"
@@ -19,7 +20,7 @@ def __init__(self):
             from basicsr.utils.download_util import load_file_from_url
             load_file_from_url(checkpoint_file, model_dir=annotator_ckpts_path)
         config_file = os.path.join(os.path.dirname(annotator_ckpts_path), "uniformer", "exp", "upernet_global_small", "config.py")
-        self.model = init_segmentor(config_file, modelpath).cuda()
+        self.model = init_segmentor(config_file, modelpath).to(config.device)
 
     def __call__(self, img):
         result = inference_segmentor(self.model, img)
diff --git a/annotator/uniformer/mmcv/engine/test.py b/annotator/uniformer/mmcv/engine/test.py
index 8dbeef271d..a6bcdfd7ea 100644
--- a/annotator/uniformer/mmcv/engine/test.py
+++ b/annotator/uniformer/mmcv/engine/test.py
@@ -10,6 +10,7 @@
 
 import annotator.uniformer.mmcv as mmcv
 from annotator.uniformer.mmcv.runner import get_dist_info
+import config
 
 
 def single_gpu_test(model, data_loader):
@@ -114,12 +115,12 @@ def collect_results_cpu(result_part, size, tmpdir=None):
         dir_tensor = torch.full((MAX_LEN, ),
                                 32,
                                 dtype=torch.uint8,
-                                device='cuda')
+                                device=config.device)
         if rank == 0:
             mmcv.mkdir_or_exist('.dist_test')
             tmpdir = tempfile.mkdtemp(dir='.dist_test')
             tmpdir = torch.tensor(
-                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device=config.device)
             dir_tensor[:len(tmpdir)] = tmpdir
         dist.broadcast(dir_tensor, 0)
         tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
@@ -170,14 +171,14 @@ def collect_results_gpu(result_part, size):
     rank, world_size = get_dist_info()
     # dump result part to tensor with pickle
     part_tensor = torch.tensor(
-        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device=config.device)
     # gather all result part tensor shape
-    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_tensor = torch.tensor(part_tensor.shape, device=config.device)
     shape_list = [shape_tensor.clone() for _ in range(world_size)]
     dist.all_gather(shape_list, shape_tensor)
     # padding result part tensor to max length
     shape_max = torch.tensor(shape_list).max()
-    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device=config.device)
     part_send[:shape_tensor[0]] = part_tensor
     part_recv_list = [
         part_tensor.new_zeros(shape_max) for _ in range(world_size)
diff --git a/annotator/uniformer/mmseg/apis/inference.py b/annotator/uniformer/mmseg/apis/inference.py
index 90bc1c0c68..e77437c42e 100644
--- a/annotator/uniformer/mmseg/apis/inference.py
+++ b/annotator/uniformer/mmseg/apis/inference.py
@@ -6,9 +6,10 @@
 
 from annotator.uniformer.mmseg.datasets.pipelines import Compose
 from annotator.uniformer.mmseg.models import build_segmentor
+import config
 
 
-def init_segmentor(config, checkpoint=None, device='cuda:0'):
+def init_segmentor(config, checkpoint=None, device=config.device):
     """Initialize a segmentor from config file.
 
     Args:
diff --git a/annotator/uniformer/mmseg/apis/test.py b/annotator/uniformer/mmseg/apis/test.py
index e574eb7da0..a5ec8dd76a 100644
--- a/annotator/uniformer/mmseg/apis/test.py
+++ b/annotator/uniformer/mmseg/apis/test.py
@@ -9,6 +9,7 @@
 import torch.distributed as dist
 from annotator.uniformer.mmcv.image import tensor2imgs
 from annotator.uniformer.mmcv.runner import get_dist_info
+import config
 
 
 def np2tmp(array, temp_file_name=None):
@@ -171,11 +172,11 @@ def collect_results_cpu(result_part, size, tmpdir=None):
         dir_tensor = torch.full((MAX_LEN, ),
                                 32,
                                 dtype=torch.uint8,
-                                device='cuda')
+                                device=config.device)
        if rank == 0:
            tmpdir = tempfile.mkdtemp()
            tmpdir = torch.tensor(
-                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+                bytearray(tmpdir.encode()), dtype=torch.uint8, device=config.device)
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
@@ -209,14 +210,14 @@ def collect_results_gpu(result_part, size):
     rank, world_size = get_dist_info()
     # dump result part to tensor with pickle
     part_tensor = torch.tensor(
-        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device=config.device)
     # gather all result part tensor shape
-    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+    shape_tensor = torch.tensor(part_tensor.shape, device=config.device)
     shape_list = [shape_tensor.clone() for _ in range(world_size)]
     dist.all_gather(shape_list, shape_tensor)
     # padding result part tensor to max length
     shape_max = torch.tensor(shape_list).max()
-    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+    part_send = torch.zeros(shape_max, dtype=torch.uint8, device=config.device)
     part_send[:shape_tensor[0]] = part_tensor
     part_recv_list = [
         part_tensor.new_zeros(shape_max) for _ in range(world_size)
diff --git a/annotator/uniformer/mmseg/apis/train.py b/annotator/uniformer/mmseg/apis/train.py
index 63f319a919..72d94dc88a 100644
--- a/annotator/uniformer/mmseg/apis/train.py
+++ b/annotator/uniformer/mmseg/apis/train.py
@@ -9,6 +9,7 @@
 from annotator.uniformer.mmseg.core import DistEvalHook, EvalHook
 from annotator.uniformer.mmseg.datasets import build_dataloader, build_dataset
 from annotator.uniformer.mmseg.utils import get_root_logger
+import config
 
 
 def set_random_seed(seed, deterministic=False):
@@ -60,7 +61,7 @@ def train_segmentor(model,
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
         model = MMDistributedDataParallel(
-            model.cuda(),
+            model.to(config.device),
             device_ids=[torch.cuda.current_device()],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
diff --git a/cldm/cldm.py b/cldm/cldm.py
index 0b3ac7a575..9ae6d8ee4d 100644
--- a/cldm/cldm.py
+++ b/cldm/cldm.py
@@ -17,6 +17,7 @@
 from ldm.models.diffusion.ddpm import LatentDiffusion
 from ldm.util import log_txt_as_img, exists, instantiate_from_config
 from ldm.models.diffusion.ddim import DDIMSampler
+import config
 
 
 class ControlledUnetModel(UNetModel):
@@ -424,12 +425,12 @@ def configure_optimizers(self):
 
     def low_vram_shift(self, is_diffusing):
         if is_diffusing:
-            self.model = self.model.cuda()
-            self.control_model = self.control_model.cuda()
+            self.model = self.model.to(config.device)
+            self.control_model = self.control_model.to(config.device)
             self.first_stage_model = self.first_stage_model.cpu()
             self.cond_stage_model = self.cond_stage_model.cpu()
         else:
             self.model = self.model.cpu()
             self.control_model = self.control_model.cpu()
-            self.first_stage_model = self.first_stage_model.cuda()
-            self.cond_stage_model = self.cond_stage_model.cuda()
+            self.first_stage_model = self.first_stage_model.to(config.device)
+            self.cond_stage_model = self.cond_stage_model.to(config.device)
diff --git a/config.py b/config.py
index e0c738d8cb..cfed93ec6e 100644
--- a/config.py
+++ b/config.py
@@ -1 +1,2 @@
 save_memory = False
+device = 'cuda'  # 'cpu'
diff --git a/gradio_canny2image.py b/gradio_canny2image.py
index 9866cac5b3..9e143954b3 100644
--- a/gradio_canny2image.py
+++ b/gradio_canny2image.py
@@ -19,7 +19,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_canny.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -31,7 +31,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
         detected_map = apply_canny(img, low_threshold, high_threshold)
         detected_map = HWC3(detected_map)
 
-        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map.copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/gradio_depth2image.py b/gradio_depth2image.py
index ee678999ae..7288a4f164 100644
--- a/gradio_depth2image.py
+++ b/gradio_depth2image.py
@@ -19,7 +19,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_depth.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -33,7 +33,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
 
         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
 
-        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map.copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/gradio_fake_scribble2image.py b/gradio_fake_scribble2image.py
index a7cd375f75..905b3e0996 100644
--- a/gradio_fake_scribble2image.py
+++ b/gradio_fake_scribble2image.py
@@ -19,7 +19,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -37,7 +37,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
         detected_map[detected_map > 4] = 255
         detected_map[detected_map < 255] = 0
 
-        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map.copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/gradio_hed2image.py b/gradio_hed2image.py
index 1ceff67969..8c7546897d 100644
--- a/gradio_hed2image.py
+++ b/gradio_hed2image.py
@@ -19,7 +19,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_hed.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -33,7 +33,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
 
         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
 
-        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map.copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/gradio_hough2image.py b/gradio_hough2image.py
index 6095eeb676..b931eb560e 100644
--- a/gradio_hough2image.py
+++ b/gradio_hough2image.py
@@ -19,7 +19,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_mlsd.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -33,7 +33,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
 
         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
 
-        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map.copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/gradio_normal2image.py b/gradio_normal2image.py
index 30aea2f8d4..7206b82651 100644
--- a/gradio_normal2image.py
+++ b/gradio_normal2image.py
@@ -19,7 +19,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_normal.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -33,7 +33,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
 
         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LINEAR)
 
-        control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map[:, :, ::-1].copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/gradio_pose2image.py b/gradio_pose2image.py
index 700973bfab..9c9645bed3 100644
--- a/gradio_pose2image.py
+++ b/gradio_pose2image.py
@@ -19,7 +19,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_openpose.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -33,7 +33,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
 
         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
 
-        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map.copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/gradio_scribble2image.py b/gradio_scribble2image.py
index 8abbc25bde..801d33951d 100644
--- a/gradio_scribble2image.py
+++ b/gradio_scribble2image.py
@@ -16,7 +16,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -28,7 +28,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
         detected_map = np.zeros_like(img, dtype=np.uint8)
         detected_map[np.min(img, axis=2) < 127] = 255
 
-        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map.copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/gradio_scribble2image_interactive.py b/gradio_scribble2image_interactive.py
index 7308bcc1bb..e7e82143a1 100644
--- a/gradio_scribble2image_interactive.py
+++ b/gradio_scribble2image_interactive.py
@@ -16,7 +16,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_scribble.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -28,7 +28,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
         detected_map = np.zeros_like(img, dtype=np.uint8)
         detected_map[np.min(img, axis=2) > 127] = 255
 
-        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map.copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/gradio_seg2image.py b/gradio_seg2image.py
index c3854dc762..5209953f9a 100644
--- a/gradio_seg2image.py
+++ b/gradio_seg2image.py
@@ -19,7 +19,7 @@
 
 model = create_model('./models/cldm_v15.yaml').cpu()
 model.load_state_dict(load_state_dict('./models/control_sd15_seg.pth', location='cuda'))
-model = model.cuda()
+model = model.to(config.device)
 ddim_sampler = DDIMSampler(model)
 
 
@@ -32,7 +32,7 @@ def process(input_image, prompt, a_prompt, n_prompt, num_samples, image_resoluti
 
         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
 
-        control = torch.from_numpy(detected_map.copy()).float().cuda() / 255.0
+        control = torch.from_numpy(detected_map.copy()).float().to(config.device) / 255.0
         control = torch.stack([control for _ in range(num_samples)], dim=0)
         control = einops.rearrange(control, 'b h w c -> b c h w').clone()
 
diff --git a/ldm/modules/encoders/modules.py b/ldm/modules/encoders/modules.py
index 4edd5496b9..afd4a555c3 100644
--- a/ldm/modules/encoders/modules.py
+++ b/ldm/modules/encoders/modules.py
@@ -6,6 +6,7 @@
 
 import open_clip
 from ldm.util import default, count_params
+import config
 
 
 class AbstractEncoder(nn.Module):
@@ -42,7 +43,7 @@ def forward(self, batch, key=None, disable_dropout=False):
         c = self.embedding(c)
         return c
 
-    def get_unconditional_conditioning(self, bs, device="cuda"):
+    def get_unconditional_conditioning(self, bs, device=config.device):
         uc_class = self.n_classes - 1  # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
         uc = torch.ones((bs,), device=device) * uc_class
         uc = {self.key: uc}
@@ -57,7 +58,7 @@ def disabled_train(self, mode=True):
 
 class FrozenT5Embedder(AbstractEncoder):
     """Uses the T5 transformer encoder for text"""
-    def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
+    def __init__(self, version="google/t5-v1_1-large", device=config.device, max_length=77, freeze=True):  # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
         super().__init__()
         self.tokenizer = T5Tokenizer.from_pretrained(version)
         self.transformer = T5EncoderModel.from_pretrained(version)
@@ -92,7 +93,7 @@ class FrozenCLIPEmbedder(AbstractEncoder):
         "pooled",
         "hidden"
     ]
-    def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77,
+    def __init__(self, version="openai/clip-vit-large-patch14", device=config.device, max_length=77,
                  freeze=True, layer="last", layer_idx=None):  # clip-vit-base-patch32
         super().__init__()
         assert layer in self.LAYERS
@@ -140,7 +141,7 @@ class FrozenOpenCLIPEmbedder(AbstractEncoder):
         "last",
         "penultimate"
     ]
-    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77,
+    def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device=config.device, max_length=77,
                  freeze=True, layer="last"):
         super().__init__()
         assert layer in self.LAYERS
@@ -194,7 +195,7 @@ def encode(self, text):
 
 
 class FrozenCLIPT5Encoder(AbstractEncoder):
-    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda",
+    def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device=config.device,
                  clip_max_length=77, t5_max_length=77):
         super().__init__()
         self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length)
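
Note: config.py still hard-codes the device string, so a CPU-only machine requires editing the file by hand. A minimal sketch of an auto-detecting variant is shown below; it is an illustrative suggestion, not part of this patch.

    # config.py -- hypothetical variant with automatic fallback (not part of this diff)
    import torch

    save_memory = False

    # Prefer the GPU when one is visible, otherwise fall back to the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

Every call site touched by the patch passes config.device either to .to(...) or to a device= keyword argument, all of which accept a plain device string, so the sketch would drop in without further changes.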