evaluate.py
from pathlib import Path
from warnings import filterwarnings

import geopandas as gpd
import numpy as np
import pandas as pd
import rasterio
from joblib import Parallel, delayed
from tqdm import tqdm

import utils.data as dt
import utils.metrics as mt
import utils.model_utils as mu
from options.infer_options import EvalOptions

filterwarnings("ignore", category=UserWarning)

opts = EvalOptions().parse()
save_dir = Path(opts.save_dir)
save_dir.mkdir(parents=True, exist_ok=True)
# probability thresholds at which predictions are polygonized and scored
probs = [0.6] if opts.grid == 1 else np.arange(0, 1, 1 / opts.grid)
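# for example (illustrative only): opts.grid == 5 gives thresholds
# 0.0, 0.2, 0.4, 0.6, 0.8 (up to floating-point rounding), while
# opts.grid == 1 evaluates only the single threshold 0.6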

# read the true labels, index them by GL_ID, and buffer them
# (the buffer distance is interpreted in the units of the layer's CRS)
vector_label = gpd.read_file(opts.vector_label)
vector_label = vector_label.set_index("GL_ID")
buffer = vector_label.buffer(distance=opts.buffer)

# metrics comparing each polygonized prediction to its vector label
metrics = {
    "IoU": mt.IoU,
    "precision": mt.precision,
    "recall": mt.recall,
    "frechet": mt.frechet_distance,
}


def process_input(path):
    """Polygonize one prediction raster at every threshold and score it."""
    gl_id = path.stem.split("_")[0]

    # skip rasters whose ID has no corresponding vector label
    if gl_id not in vector_label.index:
        return pd.DataFrame([])

    y_reader = rasterio.open(path)
    y_hat = y_reader.read().astype(np.float32)

    # polygonize the predictions and compute metrics at each probability threshold
    m = []
    for p in probs:
        y_hat_poly = mu.polygonize_preds(
            y_hat,
            y_reader,
            buffer.loc[gl_id],
            threshold=p,
            tol=opts.tol,
        )
        # save the polygons produced at the chosen geo-output threshold
        if np.isclose(p, opts.geo_prob):
            y_hat_poly.to_file(save_dir / f"{path.stem}.geojson", driver="GeoJSON")

        # get metrics for these predictions
        results = mu.polygon_metrics(
            y_hat_poly,
            vector_label.loc[gl_id:gl_id],
            y_reader,
            metrics=metrics,
        )
        results["prob"] = p
        results["GL_ID"] = gl_id
        results["sample_id"] = path.stem
        m.append(results)

    y_reader.close()
    return pd.DataFrame(m)
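
# illustrative serial usage (a sketch; "GL001_scene.tif" is a hypothetical file
# whose stem begins with a GL_ID present in vector_label):
#   single_df = process_input(Path("preds/GL001_scene.tif"))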


# evaluate every prediction raster in parallel and write one combined CSV
eval_paths = list(Path(opts.eval_dir).glob("*.tif"))
m = Parallel(n_jobs=opts.n_jobs)(delayed(process_input)(fn) for fn in tqdm(eval_paths))
pd.concat(m).to_csv(save_dir / opts.fname, index=False)
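
# example invocation (a sketch only: the actual CLI flag names are defined by
# EvalOptions in options/infer_options.py and are assumed here to mirror the
# opts attributes used above):
#   python evaluate.py --eval_dir preds/ --vector_label labels.geojson \
#       --save_dir results/ --grid 10 --buffer 500 --tol 1.0 --geo_prob 0.6 \
#       --n_jobs 4 --fname metrics.csv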