From 953d33cc09f9254308afe62a5597aeded9e09190 Mon Sep 17 00:00:00 2001 From: "q.yao" Date: Tue, 17 Jan 2023 11:00:24 +0800 Subject: [PATCH] [Fix] fix headless device visualize (#1630) * fix headless device visualize * remove pretrain * sync with 1641 * better show --- mmdeploy/apis/visualize.py | 45 +++++++++++++++++----------------- mmdeploy/codebase/base/task.py | 1 + 2 files changed, 23 insertions(+), 23 deletions(-) diff --git a/mmdeploy/apis/visualize.py b/mmdeploy/apis/visualize.py index a201131c5c..a8524cc25d 100644 --- a/mmdeploy/apis/visualize.py +++ b/mmdeploy/apis/visualize.py @@ -71,27 +71,26 @@ def visualize_model(model_cfg: Union[str, mmengine.Config], with torch.no_grad(): result = model.test_step(model_inputs)[0] - visualize = True - try: - # check headless - import tkinter - tkinter.Tk() - except Exception as e: - from mmdeploy.utils import get_root_logger - logger = get_root_logger() - logger.warning( - f'render and display result skipped for headless device, exception {e}' # noqa: E501 - ) - visualize = False + if show_result: + try: + # check headless + import tkinter + tkinter.Tk() + except Exception as e: + from mmdeploy.utils import get_root_logger + logger = get_root_logger() + logger.warning( + f'render and display result skipped for headless device, exception {e}' # noqa: E501 + ) + show_result = False - if visualize is True: - if not isinstance(img, list): - img = [img] - for single_img in img: - task_processor.visualize( - image=single_img, - model=model, - result=result, - output_file=output_file, - window_name=backend.value, - show_result=show_result) + if isinstance(img, str) or not isinstance(img, Sequence): + img = [img] + for single_img in img: + task_processor.visualize( + image=single_img, + model=model, + result=result, + output_file=output_file, + window_name=backend.value, + show_result=show_result) diff --git a/mmdeploy/codebase/base/task.py b/mmdeploy/codebase/base/task.py index 2e97b96933..df876a83d7 100644 --- a/mmdeploy/codebase/base/task.py +++ b/mmdeploy/codebase/base/task.py @@ -112,6 +112,7 @@ def build_pytorch_model(self, from mmengine.registry import MODELS model = deepcopy(self.model_cfg.model) + model.pop('pretrained', None) preprocess_cfg = deepcopy(self.model_cfg.get('preprocess_cfg', {})) preprocess_cfg.update( deepcopy(self.model_cfg.get('data_preprocessor', {})))