Skip to content

Commit

Permalink
Merge branch 'dev-1.x' of https://github.com/open-mmlab/mmpose into d…
Browse files Browse the repository at this point in the history
…ev-1.x
  • Loading branch information
Tau-J committed Oct 16, 2023
2 parents 96d5a3b + d7463ca commit 9e5aa6b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 16 deletions.
12 changes: 5 additions & 7 deletions mmpose/models/heads/regression_heads/temporal_regression_head.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn

from mmpose.evaluation.functional import keypoint_pck_accuracy
from mmpose.evaluation.functional import keypoint_mpjpe
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
Expand Down Expand Up @@ -133,14 +132,13 @@ def loss(self,
losses.update(loss_pose3d=loss)

# calculate accuracy
_, avg_acc, _ = keypoint_pck_accuracy(
mpjpe_err = keypoint_mpjpe(
pred=to_numpy(pred_outputs),
gt=to_numpy(lifting_target_label),
mask=to_numpy(lifting_target_weight) > 0,
thr=0.05,
norm_factor=np.ones((pred_outputs.size(0), 3), dtype=np.float32))
mask=to_numpy(lifting_target_weight) > 0)

mpjpe_pose = torch.tensor(avg_acc, device=lifting_target_label.device)
mpjpe_pose = torch.tensor(
mpjpe_err, device=lifting_target_label.device)
losses.update(mpjpe=mpjpe_pose)

return losses
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import Tensor, nn

from mmpose.evaluation.functional import keypoint_pck_accuracy
from mmpose.evaluation.functional import keypoint_mpjpe
from mmpose.registry import KEYPOINT_CODECS, MODELS
from mmpose.utils.tensor_utils import to_numpy
from mmpose.utils.typing import (ConfigType, OptConfigType, OptSampleList,
Expand Down Expand Up @@ -132,14 +131,13 @@ def loss(self,
losses.update(loss_traj=loss)

# calculate accuracy
_, avg_acc, _ = keypoint_pck_accuracy(
mpjpe_err = keypoint_mpjpe(
pred=to_numpy(pred_outputs),
gt=to_numpy(lifting_target_label),
mask=to_numpy(trajectory_weights) > 0,
thr=0.05,
norm_factor=np.ones((pred_outputs.size(0), 3), dtype=np.float32))
mask=to_numpy(trajectory_weights) > 0)

mpjpe_traj = torch.tensor(avg_acc, device=lifting_target_label.device)
mpjpe_traj = torch.tensor(
mpjpe_err, device=lifting_target_label.device)
losses.update(mpjpe_traj=mpjpe_traj)

return losses
Expand Down
6 changes: 4 additions & 2 deletions mmpose/models/utils/rtmcc_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,13 +198,15 @@ def __init__(self,

nn.init.xavier_uniform_(self.uv.weight)

if act_fn == 'SiLU':
if act_fn == 'SiLU' or act_fn == nn.SiLU:
assert digit_version(TORCH_VERSION) >= digit_version('1.7.0'), \
'SiLU activation requires PyTorch version >= 1.7'

self.act_fn = nn.SiLU(True)
else:
elif act_fn == 'ReLU' or act_fn == nn.ReLU:
self.act_fn = nn.ReLU(True)
else:
raise NotImplementedError

if in_token_dims == out_token_dims:
self.shortcut = True
Expand Down

0 comments on commit 9e5aa6b

Please sign in to comment.