You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
Training proceeds smoothly for the first ten epochs, but an error occurs when the first validation run starts after epoch 10. I suspect it might be related to the validation dataset. I tried setting bbox_file in the val_dataloader to None, but the error still persists.
11/25 15:46:21 - mmengine - INFO - Saving checkpoint at 10 epochs
Traceback (most recent call last):
File "F:/program/mmpose/tools/train.py", line 162, in <module>
main()
File "F:/program/mmpose/tools/train.py", line 158, in main
runner.train()
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\runner\runner.py", line 1777, in train
model = self.train_loop.run() # type: ignore
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\runner\loops.py", line 105, in run
self.runner.val_loop.run()
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\runner\loops.py", line 379, in run
self.run_iter(idx, data_batch)
File "F:\anaconda\envs\openmmlab\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\runner\loops.py", line 404, in run_iter
outputs = self.runner.model.val_step(data_batch)
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\model\base_model\base_model.py", line 133, in val_step
return self._run_forward(data, mode='predict') # type: ignore
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\model\base_model\base_model.py", line 361, in _run_forward
results = self(**data, mode=mode)
File "F:\anaconda\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "F:\anaconda\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "f:\program\mmpose\mmpose\models\pose_estimators\base.py", line 161, in forward
return self.predict(inputs, data_samples)
File "f:\program\mmpose\mmpose\models\pose_estimators\topdown.py", line 109, in predict
preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)
File "f:\program\mmpose\mmpose\models\heads\heatmap_heads\heatmap_head.py", line 259, in predict
_batch_heatmaps_flip = flip_heatmaps(
File "f:\program\mmpose\mmpose\models\utils\tta.py", line 39, in flip_heatmaps
assert len(flip_indices) == heatmaps.shape[1]
AssertionError
Additional information
No response
The text was updated successfully, but these errors were encountered:
Prerequisite
Environment
Package Version Source
mmcv 2.1.0 https://github.com/open-mmlab/mmcv
mmdet 3.3.0 https://github.com/open-mmlab/mmdetection
mmengine 0.10.5 https://github.com/open-mmlab/mmengine
mmpose 1.3.2 f:\program\mmpose
System environment:
sys.platform: win32
Python: 3.8.20 (default, Oct 3 2024, 15:19:54) [MSC v.1929 64 bit (AMD64)]
CUDA available: True
MUSA available: False
numpy_random_seed: 596127647
GPU 0: NVIDIA GeForce RTX 3080 Ti
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8
NVCC: Cuda compilation tools, release 11.8, V11.8.89
GCC: n/a
PyTorch: 2.1.0+cu118
PyTorch compiling details: PyTorch built with:
C++ Version: 199711
MSVC 192930151
Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
Intel(R) MKL-DNN v3.1.1 (Git Hash 64f6bcbcbab628e96f33a62c3e975f8535a7bde4)
OpenMP 2019
LAPACK is enabled (usually provided by MKL)
CPU capability usage: AVX2
CUDA Runtime 11.8
NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_90,code=sm_90;-gencode;arch=compute_37,code=compute_37
CuDNN 8.7
Magma 2.5.4
Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.8, CUDNN_VERSION=8.7.0, CXX_COMPILER=C:/actions-runner/_work/pytorch/pytorch/builder/windows/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /bigobj /FS -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOCUPTI -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE /utf-8 /wd4624 /wd4068 /wd4067 /wd4267 /wd4661 /wd4717 /wd4244 /wd4804 /wd4273, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=OFF, TORCH_VERSION=2.1.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON, USE_ROCM=OFF,
TorchVision: 0.16.0+cu118
OpenCV: 4.10.0
MMEngine: 0.10.5
Runtime environment:
cudnn_benchmark: False
mp_cfg: {'mp_start_method': 'fork', 'opencv_num_threads': 0}
dist_cfg: {'backend': 'nccl'}
seed: 596127647
Distributed launcher: none
Distributed training: False
GPU number: 1
Reproduces the problem - code sample
custom dataset file
# Metainfo for the custom wheelchair-tennis pose dataset: 10 keypoints
# (ids 0-9) and 9 skeleton links. `joint_weights` and `sigmas` each hold
# one value per keypoint, in keypoint-id order.
dataset_info = dict(
    dataset_name='wheelchair tennis',
    paper_info=dict(
        author='zhang heng',
        title='Microsoft coco: Common objects in context',
        container='',
        year='2024',
    ),
    # `swap` names the left/right counterpart of a keypoint; it is empty
    # for keypoints that have no counterpart.
    keypoint_info={
        0: dict(
            name='head', id=0, color=[0, 0, 255],
            type='upper', swap=''),
        1: dict(
            name='neck', id=1, color=[255, 0, 0],
            type='upper', swap=''),
        2: dict(
            name='right.shoulder', id=2, color=[0, 255, 0],
            type='upper', swap='left.shoulder'),
        3: dict(
            name='right.elbow', id=3, color=[0, 255, 255],
            type='upper', swap='left.elbow'),
        4: dict(
            name='right.wrist', id=4, color=[255, 0, 255],
            type='upper', swap='left.wrist'),
        5: dict(
            name='left.shoulder', id=5, color=[255, 255, 0],
            type='upper', swap='right.shoulder'),
        6: dict(
            name='left.elbow', id=6, color=[255, 128, 0],
            type='upper', swap='right.elbow'),
        7: dict(
            name='left.wrist', id=7, color=[0, 255, 128],
            type='upper', swap='right.wrist'),
        8: dict(
            name='mid.hip', id=8, color=[255, 128, 128],
            type='upper', swap=''),
        # The racket is not a body part, so it carries no body-half type.
        9: dict(
            name='racket', id=9, color=[128, 128, 128],
            type='', swap=''),
    },
    # Each link joins two keypoints by name.
    skeleton_info={
        0: dict(
            link=('head', 'neck'), id=0, color=[0, 255, 0]),
        1: dict(
            link=('neck', 'right.shoulder'), id=1, color=[0, 0, 255]),
        2: dict(
            link=('neck', 'left.shoulder'), id=2, color=[255, 128, 0]),
        3: dict(
            link=('neck', 'mid.hip'), id=3, color=[255, 128, 128]),
        4: dict(
            link=('right.shoulder', 'right.elbow'), id=4,
            color=[51, 51, 51]),
        5: dict(
            link=('right.elbow', 'right.wrist'), id=5,
            color=[51, 153, 51]),
        6: dict(
            link=('left.shoulder', 'left.elbow'), id=6,
            color=[153, 153, 255]),
        7: dict(
            link=('left.elbow', 'left.wrist'), id=7,
            color=[51, 153, 255]),
        8: dict(
            link=('right.wrist', 'racket'), id=8, color=[51, 255, 51]),
    },
    joint_weights=[1., 1., 1.2, 1.2, 1.2, 1.2, 1.2, 1.2, 1., 1.5],
    sigmas=[
        0.026, 0.035, 0.035, 0.079, 0.079,
        0.072, 0.089, 0.089, 0.062, 0.107,
    ],
)
config file
# Top-down HRNet-W32 heatmap config for the custom 10-keypoint
# wheelchair-tennis dataset (COCO-style annotations).
_base_ = ['../../../_base_/default_runtime.py']

# runtime
# Validation runs every `val_interval` epochs, i.e. first after epoch 10.
train_cfg = dict(max_epochs=20, val_interval=10)

# optimizer
optim_wrapper = dict(optimizer=dict(
    type='Adam',
    lr=5e-4,
))

# learning policy
# NOTE(review): the MultiStepLR milestones (170, 200) and end=210 look
# copied from a 210-epoch schedule; with max_epochs=20 the step decay is
# never reached. Values are left unchanged -- confirm the intended schedule.
param_scheduler = [
    dict(
        type='LinearLR', begin=0, end=500, start_factor=0.001,
        by_epoch=False),  # warm-up
    dict(
        type='MultiStepLR',
        begin=0,
        end=210,
        milestones=[170, 200],
        gamma=0.1,
        by_epoch=True)
]

# automatically scaling LR based on the actual training batch size
auto_scale_lr = dict(base_batch_size=512)

# hooks
default_hooks = dict(checkpoint=dict(save_best='coco/AP', rule='greater'))

# codec settings
codec = dict(
    type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2)

# model settings
model = dict(
    type='TopdownPoseEstimator',
    data_preprocessor=dict(
        type='PoseDataPreprocessor',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        bgr_to_rgb=True),
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(32, 64, 128, 256))),
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmpose/'
            'pretrain_models/hrnet_w32-36af842e.pth'),
    ),
    head=dict(
        type='HeatmapHead',
        in_channels=32,
        # One heatmap per keypoint; must equal the number of keypoints
        # declared in the dataset metainfo file.
        out_channels=10,
        deconv_out_channels=None,
        loss=dict(type='KeypointMSELoss', use_target_weight=True),
        decoder=codec),
    # Flip-test mirrors the heatmaps using the dataset's flip indices, so
    # the validation dataset metainfo must describe the same keypoints as
    # the head outputs.
    test_cfg=dict(
        flip_test=True,
        flip_mode='heatmap',
        shift_heatmap=True,
    ))

# base dataset settings
dataset_type = 'CocoDataset'
data_mode = 'topdown'
data_root = 'F:/program/mmpose/myprogram/data/wheelchair-tennis-game-12/'

# Custom keypoint definition shared by ALL dataset splits.
metainfo_file = ('F:/program/mmpose/configs/_base_/datasets/'
                 'wheelchair_tennis_custom.py')

# pipelines
train_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='RandomFlip', direction='horizontal'),
    dict(type='RandomHalfBody'),
    dict(type='RandomBBoxTransform'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='GenerateTarget', encoder=codec),
    dict(type='PackPoseInputs')
]
val_pipeline = [
    dict(type='LoadImage'),
    dict(type='GetBBoxCenterScale'),
    dict(type='TopdownAffine', input_size=codec['input_size']),
    dict(type='PackPoseInputs')
]

# data loaders
train_dataloader = dict(
    batch_size=64,
    num_workers=2,
    persistent_workers=True,
    sampler=dict(type='DefaultSampler', shuffle=True),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='train/_annotations.coco.json',
        data_prefix=dict(img='train/'),
        pipeline=train_pipeline,
        metainfo=dict(from_file=metainfo_file),
    ))
val_dataloader = dict(
    batch_size=32,
    num_workers=2,
    persistent_workers=True,
    drop_last=False,
    sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
    dataset=dict(
        type=dataset_type,
        data_root=data_root,
        data_mode=data_mode,
        ann_file='valid/_annotations.coco.json',
        bbox_file=None,
        data_prefix=dict(img='valid/'),
        test_mode=True,
        pipeline=val_pipeline,
        # FIX: without this the validation dataset does not load the custom
        # 10-keypoint definition, and flip-test then fails with
        # `assert len(flip_indices) == heatmaps.shape[1]` in flip_heatmaps.
        # (Presumably CocoDataset falls back to its built-in COCO metainfo
        # when none is given -- see the MMPose custom-dataset docs.)
        metainfo=dict(from_file=metainfo_file),
    ))
test_dataloader = val_dataloader

# evaluators
val_evaluator = dict(
    type='CocoMetric',
    ann_file=data_root + 'valid/_annotations.coco.json')
test_evaluator = val_evaluator
Reproduces the problem - command or script
python F:/program/mmpose/tools/train.py F:/program/mmpose/configs/body_2d_keypoint/topdown_heatmap/coco/wheelchair_tennis_cfg.py
Reproduces the problem - error message
Training proceeds smoothly for the first ten epochs, but an error occurs when the first validation run starts after epoch 10. I suspect it might be related to the validation dataset. I tried setting bbox_file in the val_dataloader to None, but the error still persists.
11/25 15:46:21 - mmengine - INFO - Saving checkpoint at 10 epochs
Traceback (most recent call last):
File "F:/program/mmpose/tools/train.py", line 162, in <module>
main()
File "F:/program/mmpose/tools/train.py", line 158, in main
runner.train()
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\runner\runner.py", line 1777, in train
model = self.train_loop.run() # type: ignore
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\runner\loops.py", line 105, in run
self.runner.val_loop.run()
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\runner\loops.py", line 379, in run
self.run_iter(idx, data_batch)
File "F:\anaconda\envs\openmmlab\lib\site-packages\torch\utils\_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\runner\loops.py", line 404, in run_iter
outputs = self.runner.model.val_step(data_batch)
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\model\base_model\base_model.py", line 133, in val_step
return self._run_forward(data, mode='predict') # type: ignore
File "F:\anaconda\envs\openmmlab\lib\site-packages\mmengine\model\base_model\base_model.py", line 361, in _run_forward
results = self(**data, mode=mode)
File "F:\anaconda\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "F:\anaconda\envs\openmmlab\lib\site-packages\torch\nn\modules\module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "f:\program\mmpose\mmpose\models\pose_estimators\base.py", line 161, in forward
return self.predict(inputs, data_samples)
File "f:\program\mmpose\mmpose\models\pose_estimators\topdown.py", line 109, in predict
preds = self.head.predict(feats, data_samples, test_cfg=self.test_cfg)
File "f:\program\mmpose\mmpose\models\heads\heatmap_heads\heatmap_head.py", line 259, in predict
_batch_heatmaps_flip = flip_heatmaps(
File "f:\program\mmpose\mmpose\models\utils\tta.py", line 39, in flip_heatmaps
assert len(flip_indices) == heatmaps.shape[1]
AssertionError
Additional information
No response
The text was updated successfully, but these errors were encountered: