-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathcalc_acc.py
124 lines (105 loc) · 3.39 KB
/
calc_acc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import argparse
from collections import OrderedDict
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from monai.metrics import compute_generalized_dice, compute_surface_dice
from tqdm import tqdm
def compute_multi_class_dsc(gt, seg):
    """Return the mean per-label Dice similarity coefficient of ``seg`` vs ``gt``.

    Each label ``1..gt.max()`` (label 0 is never scored) is turned into a
    binary mask pair. The pair is scored with MONAI's
    ``compute_generalized_dice`` after prepending a batch and a channel axis.
    The per-label scores are then averaged. Labels that occur in neither mask
    are skipped, and labels in ``seg`` above ``gt.max()`` are not scored.

    Args:
        gt: Ground-truth label map (NumPy array of integer labels).
        seg: Predicted label map with the same shape as ``gt``.

    Returns:
        The mean Dice over the scored labels, or NaN when no label is scored
        (e.g. ``gt`` contains only label 0).
    """
    dsc = []
    for label in range(1, int(gt.max()) + 1):
        gt_mask = gt == label
        seg_mask = seg == label
        # A label missing from both masks carries no information; MONAI scores
        # such an empty/empty pair as a perfect 1.0, which would inflate the
        # mean, so it is skipped (matching compute_multi_class_nsd).
        if not (gt_mask.any() or seg_mask.any()):
            continue
        score = compute_generalized_dice(
            torch.tensor(seg_mask).unsqueeze(0).unsqueeze(0),
            torch.tensor(gt_mask).unsqueeze(0).unsqueeze(0),
            include_background=True,
        )
        dsc.append(float(score[0]))
    # Explicit NaN instead of np.mean([]) (same value, no RuntimeWarning).
    return np.mean(dsc) if dsc else np.nan
def compute_multi_class_nsd(gt, seg, spacing, tolerance=2.0):
    """Return the mean per-label normalized surface Dice (NSD) of ``seg`` vs ``gt``.

    Each label ``1..gt.max()`` (label 0 is never scored) is turned into a
    binary mask pair. The pair is scored with MONAI's ``compute_surface_dice``
    after prepending a batch and a channel axis. The per-label scores are then
    averaged. Labels that occur in neither mask are skipped, and labels in
    ``seg`` above ``gt.max()`` are not scored.

    Args:
        gt: Ground-truth label map; it must have 3 spatial axes, because the
            batch and channel axes added here give MONAI a 5-D tensor.
        seg: Predicted label map with the same shape as ``gt``.
        spacing: Voxel spacing forwarded to MONAI, one entry per spatial axis.
        tolerance: Boundary tolerance used as the single class threshold,
            presumably in the same units as ``spacing`` (confirm with MONAI).

    Returns:
        The mean NSD over the scored labels, or NaN when no label is scored.
    """
    nsd = []
    for label in range(1, int(gt.max()) + 1):
        gt_mask = gt == label
        seg_mask = seg == label
        # MONAI returns NaN for a label absent from both masks; one such NaN
        # would turn the whole case's mean into NaN, so skip it.
        if not (gt_mask.any() or seg_mask.any()):
            continue
        score = compute_surface_dice(
            torch.tensor(seg_mask).unsqueeze(0).unsqueeze(0),
            torch.tensor(gt_mask).unsqueeze(0).unsqueeze(0),
            class_thresholds=[tolerance],
            # The default (False) drops channel 0, which is the only channel
            # here, and then fails the class-count check against
            # class_thresholds.
            include_background=True,
            spacing=spacing,
        )
        nsd.append(float(score))
    return np.mean(nsd) if nsd else np.nan
def _parse_args():
    """Parse the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--segs",
        required=True,
        type=str,
        help="directory of segmentation output",
    )
    parser.add_argument(
        "--gts",
        required=True,
        type=str,
        help="directory of ground truth",
    )
    parser.add_argument(
        "--imgs",
        required=True,
        type=str,
        help="directory of original images",
    )
    parser.add_argument(
        "--output_csv",
        type=str,
        default="seg_metrics.csv",
        help="output csv file",
    )
    return parser.parse_args()


def _evaluate_case(seg_file, gt_file, img_file):
    """Load one case and return its unrounded ``(dsc, nsd)`` pair."""
    # NpzFile keeps the archive open until closed; the context managers release
    # each handle right after the arrays are read. (np.load ignores mmap_mode
    # for .npz archives, so it is not passed.)
    with np.load(seg_file) as npz_seg:
        seg = npz_seg["segs"]
    with np.load(gt_file) as npz_gt:
        gt = npz_gt["gts"] if "gts" in npz_gt else npz_gt["segs"]

    dsc = compute_multi_class_dsc(gt, seg)
    if seg_file.name.startswith("3D"):
        with np.load(img_file) as npz_img:
            spacing = npz_img["spacing"]
        nsd = compute_multi_class_nsd(gt, seg, spacing)
    else:
        # 2D cases are lifted to a single-slice volume with unit spacing.
        nsd = compute_multi_class_nsd(
            np.expand_dims(gt, -1), np.expand_dims(seg, -1), [1.0, 1.0, 1.0]
        )
    return dsc, nsd


def _nan_free_mean(values):
    """Mean of ``values`` ignoring NaNs; NaN (without a warning) if none remain."""
    arr = np.asarray(values, dtype=float)
    arr = arr[~np.isnan(arr)]
    return np.mean(arr) if arr.size else np.nan


def main():
    """Evaluate every ``*.npz`` segmentation in ``--segs`` and write a metrics CSV.

    For each ``<segs>/<name>.npz``:

    * the prediction is read from the ``"segs"`` key;
    * the ground truth is read from ``<gts>/<name>.npz`` (key ``"gts"``,
      falling back to ``"segs"``);
    * cases whose name starts with ``"3D"`` take their spacing from
      ``<imgs>/<name>.npz`` (key ``"spacing"``), and all other cases are
      treated as 2D with unit spacing.

    A case is skipped, with a message, when its ground-truth or image file is
    missing. The CSV has one row per evaluated case, with DSC/NSD rounded to
    4 decimals. A final ``"average"`` row holds the mean of those rounded
    values, ignoring NaN entries.
    """
    args = _parse_args()
    gts_dir = Path(args.gts)
    imgs_dir = Path(args.imgs)

    seg_metrics = OrderedDict()
    seg_metrics["case"] = []
    seg_metrics["dsc"] = []
    seg_metrics["nsd"] = []

    for seg_file in tqdm(sorted(Path(args.segs).glob("*.npz"))):
        gt_file = gts_dir / seg_file.name
        img_file = imgs_dir / seg_file.name
        if not gt_file.exists() or not img_file.exists():
            # Report instead of silently shrinking the evaluated set.
            tqdm.write(
                f"Skipping {seg_file.name}: missing ground truth or image file"
            )
            continue
        dsc, nsd = _evaluate_case(seg_file, gt_file, img_file)
        seg_metrics["case"].append(seg_file.name)
        seg_metrics["dsc"].append(np.round(dsc, 4))
        seg_metrics["nsd"].append(np.round(nsd, 4))

    avg_dsc = _nan_free_mean(seg_metrics["dsc"])
    avg_nsd = _nan_free_mean(seg_metrics["nsd"])
    seg_metrics["case"].append("average")
    seg_metrics["dsc"].append(avg_dsc)
    seg_metrics["nsd"].append(avg_nsd)

    df = pd.DataFrame(seg_metrics)
    df.to_csv(args.output_csv, index=False, na_rep="NaN")
    print("Average DSC: {:.4f}".format(avg_dsc))
    print("Average NSD: {:.4f}".format(avg_nsd))


if __name__ == "__main__":
    main()