Skip to content

Commit

Permalink
add inference_custom.py
Browse files Browse the repository at this point in the history
  • Loading branch information
rostyslavhereha committed Jan 7, 2025
1 parent f065ffa commit 9707df2
Showing 1 changed file with 27 additions and 29 deletions.
56 changes: 27 additions & 29 deletions tools/inference_custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def parse_args():
parser.add_argument('--img-dir', type=str, required=True, help='Directory with input images')
parser.add_argument('--bbox-json', type=str, required=True, help='Path to COCO format bounding box JSON')
parser.add_argument('--out-dir', type=str, help='Directory to save visualized results (optional)')
parser.add_argument('--output-file', type=str, help='File to save keypoint results in JSON')
parser.add_argument('--predictions-dir', type=str, required=True, help='Directory to save individual prediction files')
parser.add_argument('--device', default='cuda:0', help='Device to run inference on (e.g., "cuda:0" or "cpu")')
parser.add_argument('--score-thr', type=float, default=0.3, help='Keypoint score threshold')
parser.add_argument(
Expand All @@ -26,6 +26,13 @@ def parse_args():
return parser.parse_args()


def draw_bboxes(image, bboxes):
"""Draw bounding boxes on the image using OpenCV."""
for bbox in bboxes:
x, y, x2, y2 = map(int, bbox)
cv2.rectangle(image, (x, y), (x2, y2), (255, 0, 0), 2) # Draw rectangle


def draw_keypoints(image, keypoints, scores, score_thr):
"""Draw keypoints on the image using OpenCV."""
for el1, el2 in zip(keypoints, scores):
Expand All @@ -39,24 +46,18 @@ def draw_keypoints(image, keypoints, scores, score_thr):
def main():
args = parse_args()

# Load the COCO bounding boxes
coco = COCO(args.bbox_json)
img_ids = list(coco.imgs.keys())

# Initialize the model
cfg = Config.fromfile(args.config)
if args.cfg_options:
cfg.merge_from_dict(args.cfg_options)
model = init_model(cfg, args.model, device=args.device)

# Ensure output directories exist if `out-dir` is provided
if args.out_dir:
os.makedirs(args.out_dir, exist_ok=True)
os.makedirs(args.predictions_dir, exist_ok=True)

# Results to be saved
results = []

# Progress bar
progress_bar = ProgressBar(len(img_ids))

for img_id in img_ids:
Expand All @@ -68,64 +69,61 @@ def main():
progress_bar.update()
continue

# Load the image
image = cv2.imread(img_path)
if image is None:
print(f"Failed to read image: {img_path}")
progress_bar.update()
continue

# Convert image to RGB for inference
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Load bounding boxes for the image
ann_ids = coco.getAnnIds(imgIds=[img_id])
annotations = coco.loadAnns(ann_ids)
person_bboxes = np.array([ann['bbox'] for ann in annotations])

# Run pose inference
pose_results = inference_topdown(
model,
image_rgb,
person_bboxes,
bbox_format='xywh' # COCO annotations typically use 'xywh' format
bbox_format='xyxy' # COCO annotations typically use 'xywh' format
)

# Extract keypoints and bounding boxes from PoseDataSample
keypoints_results = []
for pose in pose_results:
pred_instances = pose.pred_instances
if pred_instances is not None:
keypoints = pred_instances.keypoints
scores = pred_instances.keypoint_scores
bbox = person_bboxes[0] # Taking first bbox, adjust for multiple detections

# Save the bbox along with the keypoints data
keypoints_results.append({
'keypoints': keypoints.tolist(),
'scores': scores.tolist()
'scores': scores.tolist(),
'bbox': bbox.tolist() # Add bbox to the result
})

# Draw keypoints on the image
draw_keypoints(image, keypoints, scores, args.score_thr)
if args.out_dir:
draw_keypoints(image, keypoints, scores, args.score_thr)
draw_bboxes(image, person_bboxes)

# Save the visualized image if `out-dir` is provided
if args.out_dir:
out_file = os.path.join(args.out_dir, img_info['file_name'])
cv2.imwrite(out_file, image)

# Save keypoints to results
results.append({
'image_id': img_id,
'file_name': img_info['file_name'],
'keypoints': keypoints_results
})
# Save individual prediction file
prediction_file = os.path.join(args.predictions_dir, f"{os.path.splitext(img_info['file_name'])[0]}.json")
with open(prediction_file, 'w') as f:
json.dump({
"result": keypoints_results,
"score": 0 # Placeholder for score; replace with actual logic if needed
}, f)

progress_bar.update()

# Save results to output file
if args.output_file:
with open(args.output_file, 'w') as f:
json.dump(results, f, indent=4)
print(f"Inference completed. Results {'saved to ' + args.out_dir if args.out_dir else ''}.")
print("Inference completed.")


if __name__ == '__main__':
main()
main()

0 comments on commit 9707df2

Please sign in to comment.