Skip to content

Commit

Permalink
load_spandrel_model: always return a model descriptor
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Dec 30, 2023
1 parent 3be9074 commit c0ca634
Showing 1 changed file with 13 additions and 10 deletions.
23 changes: 13 additions & 10 deletions modules/modelloader.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
from __future__ import annotations

import importlib
import logging
import os
import importlib
from typing import TYPE_CHECKING
from urllib.parse import urlparse

import torch

from modules import shared
from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, UpscalerNone

if TYPE_CHECKING:
import spandrel

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -142,17 +145,17 @@ def load_spandrel_model(
half: bool = False,
dtype: str | None = None,
expected_architecture: str | None = None,
):
) -> spandrel.ModelDescriptor:
import spandrel
model = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model.architecture != expected_architecture:
model_descriptor = spandrel.ModelLoader(device=device).load_from_file(path)
if expected_architecture and model_descriptor.architecture != expected_architecture:
logger.warning(
f"Model {path!r} is not a {expected_architecture!r} model (got {model.architecture!r})",
f"Model {path!r} is not a {expected_architecture!r} model (got {model_descriptor.architecture!r})",
)
if half:
model = model.model.half()
model_descriptor.model.half()
if dtype:
model = model.model.to(dtype=dtype)
model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model, path, device, half, dtype)
return model
model_descriptor.model.to(dtype=dtype)
model_descriptor.model.eval()
logger.debug("Loaded %s from %s (device=%s, half=%s, dtype=%s)", model_descriptor, path, device, half, dtype)
return model_descriptor

0 comments on commit c0ca634

Please sign in to comment.