Skip to content

Commit

Permalink
Add missing components
Browse files Browse the repository at this point in the history
  • Loading branch information
jurgendn committed Oct 13, 2024
1 parent 7b108c8 commit 3ba819a
Show file tree
Hide file tree
Showing 8 changed files with 772 additions and 8 deletions.
2 changes: 1 addition & 1 deletion configs/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class HRNetStageConfig(BaseModel):
class HRNetModelExtraConfig(BaseModel):
pretrained_layers: List[str]
final_conv_kernel: int
stem_inplane: int

stage2: HRNetStageConfig
stage3: HRNetStageConfig
Expand All @@ -68,6 +67,7 @@ class HRNetModelConfig(BaseModel):
image_size: Tuple[int, int]
heatmap_size: Tuple[int, int]
sigma: int
return_hoe: bool
extra: HRNetModelExtraConfig


Expand Down
4 changes: 2 additions & 2 deletions configs/hrnet.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ model:
tag_per_joint: True
target_type: gaussian
use_featuremap: True
return_hoe: True

extra:
stem_inplane: 64
final_conv_kernel: 1
pretrained_layers:
[
Expand Down Expand Up @@ -72,4 +72,4 @@ debug_debug: True
debug_save_batch_images_gt: True
debug_save_batch_images_pred: True
debug_save_heatmaps_gt: True
debug_save_heatmaps_pred: True
debug_save_heatmaps_pred: True
59 changes: 59 additions & 0 deletions configs/pose_hrnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
auto_resume: False

cudnn_benchmark: True
cudnn_deterministic: False
cudnn_enabled: True

data_dir: ""
gpus: [0, 1, 2, 3]
output_dir: "output"
log_dir: "log"
workers: "8x"
print_freq: 30

model:
image_size: [256, 192]
heatmap_size: [64, 48]
init_weights: True
name: pose_hrnet
num_joints: 17
pretrained: ./pretrained/hrnet_w32_256x192.pth
sigma: 2
type: simple
tag_per_joint: True
target_type: gaussian
use_featuremap: True
return_hoe: False

extra:
final_conv_kernel: 1
pretrained_layers: ["*"]
stage2:
block: BASIC
fuse_method: SUM
num_blocks: [4, 4]
num_channels: [32, 64]
num_branches: 2
num_modules: 1
stage3:
block: BASIC
fuse_method: SUM
num_blocks: [4, 4, 4]
num_channels: [32, 64, 128]
num_branches: 3
num_modules: 4
stage4:
block: BASIC
fuse_method: SUM
num_blocks: [4, 4, 4, 4]
num_channels: [32, 64, 128, 256]
num_branches: 4
num_modules: 3
loss_use_different_joints_weight: False
loss_use_target_weight: True

debug_debug: True
debug_save_batch_images_gt: True
debug_save_batch_images_pred: True
debug_save_heatmaps_gt: True
debug_save_heatmaps_pred: True
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"more-itertools>=10.5.0",
"numpy>=1.26.4",
"pydantic>=2.9.2",
"pytest>=8.3.3",
"pytorch-lightning>=2.4.0",
"timm>=1.0.9",
"torch-geometric>=2.6.1",
Expand Down
89 changes: 84 additions & 5 deletions scripts/get_pose.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,6 +1,85 @@
#!/bin/bash
export PYTHONPATH=.
python tools/get_pose.py --config configs/hrnet.yaml \
--pretrained pretrained/model_hboe.pth \
--dataset /Users/jurgendn/Documents/projects/personal/dataset/LTCC_ReID/ \
--target-set query

# Default values for optional parameters
config_file="configs/pose_hrnet.yaml"
pretrained="pretrained/hrnet_w32_256x192.pth"

# Parse flags using getopts
while [[ "$#" -gt 0 ]]; do
case $1 in
--metadata)
metadata="$2"
shift
;;
--dataset-name)
dataset_name="$2"
shift
;;
--target-set)
target_set="$2"
shift
;;
--batch-size)
batch_size="$2"
shift
;;
--device)
device="$2"
shift
;;
--config-file)
config_file="$2"
shift
;; # Optional
--pretrained)
pretrained="$2"
shift
;; # Optional
--help)
echo "Usage: $0 --metadata <metadata> --dataset-name <dataset_name> --target-set <target_set> --batch-size <batch_size> --device <device> [--config-file <config_file>] [--pretrained <pretrained>]"
exit 0
;;
*)
echo "Unknown parameter: $1"
echo "Use --help for usage."
exit 1
;;
esac
shift
done

# Check for required parameters
if [[ -z "$metadata" ]]; then
echo "Error: --metadata is required"
exit 1
fi

if [[ -z "$dataset_name" ]]; then
echo "Error: --dataset-name is required"
exit 1
fi

if [[ -z "$target_set" ]]; then
echo "Error: --target-set is required"
exit 1
fi

if [[ -z "$batch_size" ]]; then
echo "Error: --batch-size is required"
exit 1
fi

if [[ -z "$device" ]]; then
echo "Error: --device is required"
exit 1
fi

# Run the Python script
PYTHONPATH=. uv run tools/get_pose.py \
--metadata "$metadata" \
--dataset-name "$dataset_name" \
--target-set "$target_set" \
--batch-size "$batch_size" \
--device "$device" \
--config-file "$config_file" \
--pretrained "$pretrained"
Loading

0 comments on commit 3ba819a

Please sign in to comment.