diff --git a/mmdeploy/backend/torchscript/wrapper.py b/mmdeploy/backend/torchscript/wrapper.py index 5d8c791772..4220db4d4f 100644 --- a/mmdeploy/backend/torchscript/wrapper.py +++ b/mmdeploy/backend/torchscript/wrapper.py @@ -1,10 +1,11 @@ # Copyright (c) OpenMMLab. All rights reserved. +import importlib import os.path as osp from typing import Dict, Optional, Sequence, Union import torch -from mmdeploy.utils import Backend +from mmdeploy.utils import Backend, get_root_logger from mmdeploy.utils.timer import TimeCounter from ..base import BACKEND_WRAPPER, BaseWrapper from .init_plugins import get_ops_path @@ -39,10 +40,20 @@ def __init__(self, model: Union[str, torch.jit.RecursiveScriptModule], input_names: Optional[Sequence[str]] = None, output_names: Optional[Sequence[str]] = None): + logger = get_root_logger() + # load custom ops if exist custom_ops_path = get_ops_path() if osp.exists(custom_ops_path): torch.ops.load_library(custom_ops_path) + + # import torchvision for ops + try: + importlib.import_module('torchvision') + except Exception: + logger.warning( + 'Can not import torchvision. ' + 'Models require ops in torchvision might not available.') super().__init__(output_names) self.ts_model = model if isinstance(self.ts_model, str):