Skip to content

Commit

Permalink
fix mmseg config (#281)
Browse files Browse the repository at this point in the history
* fix mmseg config

* fix mmpose evaluate outputs

* fix lint

* update pre-commit config

* fix lint

* Revert "update pre-commit config"

This reverts commit c3fd716.
  • Loading branch information
RunningLeon authored and lvhan028 committed Apr 1, 2022
1 parent a08acee commit 4a081a0
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 5 deletions.
4 changes: 4 additions & 0 deletions configs/mmseg/segmentation_openvino_static-1024x2048.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
_base_ = ['./segmentation_static.py', '../_base_/backends/openvino.py']
onnx_config = dict(input_shape=[2048, 1024])
backend_config = dict(
model_inputs=[dict(opt_shapes=dict(input=[1, 3, 1024, 2048]))])
2 changes: 1 addition & 1 deletion configs/mmseg/segmentation_openvino_static-512x512.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = ['./segmentation_static.py', '../_base_/backends/openvino.py']

onnx_config = dict(input_shape=[512, 512])
backend_config = dict(
model_inputs=[dict(opt_shapes=dict(input=[1, 3, 512, 512]))])
4 changes: 2 additions & 2 deletions mmdeploy/apis/pytorch2torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def torch2torchscript_impl(model: torch.nn.Module,
ir=IR.TORCHSCRIPT), torch.no_grad(), torch.jit.optimized_execution(
True):
# for exporting models with weight that depends on inputs
patched_model(
*inputs) if isinstance(inputs, Sequence) else patched_model(inputs)
patched_model(*inputs) if isinstance(inputs, Sequence) \
else patched_model(inputs)
ts_model = torch.jit.trace(patched_model, inputs)

# perform optimize, note that optimizing models may trigger errors when
Expand Down
10 changes: 8 additions & 2 deletions mmdeploy/codebase/mmpose/deploy/pose_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def evaluate_outputs(model_cfg: mmcv.Config,
out: Optional[str] = None,
metric_options: Optional[dict] = None,
format_only: bool = False,
log_file: Optional[str] = None,
**kwargs):
"""Perform post-processing to predictions of model.
Expand All @@ -215,10 +216,15 @@ def evaluate_outputs(model_cfg: mmcv.Config,
evaluation. It is useful when you want to format the result
to a specific format and submit it to the test server. Defaults
to `False`.
log_file (str | None): The file to write the evaluation results.
Defaults to `None` and the results will only print on stdout.
"""
from mmcv.utils import get_logger
logger = get_logger('test', log_file=log_file, log_level=logging.INFO)

res_folder = '.'
if out:
logging.info(f'\nwriting results to {out}')
logger.info(f'\nwriting results to {out}')
mmcv.dump(outputs, out)
res_folder, _ = os.path.split(out)
os.makedirs(res_folder, exist_ok=True)
Expand All @@ -229,7 +235,7 @@ def evaluate_outputs(model_cfg: mmcv.Config,

results = dataset.evaluate(outputs, res_folder, **eval_config)
for k, v in sorted(results.items()):
print(f'{k}: {v}')
logger.info(f'{k}: {v:.4f}')

def get_model_name(self) -> str:
"""Get the model name.
Expand Down

0 comments on commit 4a081a0

Please sign in to comment.