Skip to content

Commit

Permalink
Added tqdm to track progress for multiple files
Browse files Browse the repository at this point in the history
  • Loading branch information
ancestor-mithril committed Jul 29, 2024
1 parent 8799d18 commit b472b4f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 10 deletions.
28 changes: 18 additions & 10 deletions dice_score_3d/metrics.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import json
import os.path
from concurrent.futures import ProcessPoolExecutor
from typing import List, Sequence, Tuple, Union

import numpy as np
from dice_score_3d.reader import read_mask
from numpy import ndarray
from tqdm import tqdm
from tqdm.contrib.concurrent import process_map

from dice_score_3d.reader import read_mask


def dice_metrics(ground_truths: str, predictions: str, output_path: Union[str, None], indices: dict,
Expand Down Expand Up @@ -118,22 +120,28 @@ def evaluate_prediction(gt: str, pred: str, reorient: bool, dtype: np.dtype, ind
return multi_class_dice(gt, pred, indices)


def execute_evaluate_predictions(gt_files: List[str], pred_files: List[str], reorient: bool, dtype: np.dtype,
indices: Sequence[int], num_workers: int) \
-> Sequence[Tuple[ndarray, ndarray, ndarray, ndarray]]:
if num_workers == 0:
ret = [evaluate_prediction(gt, pred, reorient, dtype, indices) for gt, pred in tqdm(zip(gt_files, pred_files))]
else:
ret = process_map(evaluate_prediction,
[(gt, pred, reorient, dtype, indices) for gt, pred in zip(gt_files, pred_files)],
max_workers=num_workers)
return ret


def evaluate_predictions(gt_files: List[str], pred_files: List[str], reorient: bool, dtype: np.dtype,
indices: Sequence[int], num_workers: int) -> Tuple[ndarray, ndarray, ndarray, ndarray]:
""" Evaluates each pair of prediction and GT sequentially or in parallel and collects metrics.
"""
if num_workers == 0:
ret = [evaluate_prediction(gt, pred, reorient, dtype, indices) for gt, pred in zip(gt_files, pred_files)]
else:
with ProcessPoolExecutor(max_workers=num_workers) as executor:
ret = executor.map(evaluate_prediction,
[(gt, pred, reorient, dtype, indices) for gt, pred in zip(gt_files, pred_files)])

scores = execute_evaluate_predictions(gt_files, pred_files, reorient, dtype, indices, num_workers)
common_voxels = []
all_voxels = []
gt_voxels = []
dice_scores = []
for a, b, c, d in ret:
for a, b, c, d in scores:
common_voxels.append(a)
all_voxels.append(b)
gt_voxels.append(c)
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ keywords = [
dependencies = [
"numpy",
"SimpleITK",
"tqdm",
]

[project.urls]
Expand Down

0 comments on commit b472b4f

Please sign in to comment.