From d9b5a218c4794fb9b2c0495150ea87e9989659e3 Mon Sep 17 00:00:00 2001 From: hanrui1sensetime <83800577+hanrui1sensetime@users.noreply.github.com> Date: Wed, 18 Jan 2023 16:31:11 +0800 Subject: [PATCH] [Fix] Fix preprocess_model_config for CIFAR dataset (#1659) * fix cifar10 for mmcls * remove unnecessary code --- mmdeploy/codebase/mmcls/deploy/classification.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/mmdeploy/codebase/mmcls/deploy/classification.py b/mmdeploy/codebase/mmcls/deploy/classification.py index 6cb17b298a..dad8ba3ff4 100644 --- a/mmdeploy/codebase/mmcls/deploy/classification.py +++ b/mmdeploy/codebase/mmcls/deploy/classification.py @@ -63,13 +63,15 @@ def process_model_config(model_cfg: Config, cfg.test_pipeline.pop(0) # check whether input_shape is valid if input_shape is not None: - if 'crop_size' in cfg.test_pipeline[2]: - crop_size = cfg.test_pipeline[2]['crop_size'] - if tuple(input_shape) != (crop_size, crop_size): - logger = get_root_logger() - logger.warning( - f'`input shape` should be equal to `crop_size`: {crop_size},\ - but given: {input_shape}') + for pipeline_component in cfg.test_pipeline: + if 'Crop' in pipeline_component['type']: + if 'crop_size' in pipeline_component: + crop_size = pipeline_component['crop_size'] + if tuple(input_shape) != (crop_size, crop_size): + logger = get_root_logger() + logger.warning( + f'`input shape` should be equal to `crop_size`: {crop_size},\ + but given: {input_shape}') return cfg