-
Notifications
You must be signed in to change notification settings - Fork 1
/
image_sample_diff_brain_tumor_1.py
96 lines (76 loc) · 2.94 KB
/
image_sample_diff_brain_tumor_1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""
Run major-vote sampling with a trained diffusion model on the validation split
of the brain-tumor dataset. The checkpoint is taken from the `model_path`
argument, and the output directory is created inside that checkpoint's log
directory.
"""
import argparse
import datetime
import json
import warnings
from pathlib import Path
import torch.distributed as dist
from mpi4py import MPI
from datasets.brain_tumor import BrainTumor1Dataset
from improved_diffusion import dist_util, logger
from improved_diffusion.sampling_util import sampling_major_vote_func
from improved_diffusion.script_util import (
model_and_diffusion_defaults,
create_model_and_diffusion,
add_dict_to_argparser,
args_to_dict,
)
from improved_diffusion.utils import set_random_seed
# Silence every warning for the rest of the process: with no category or
# module filter given, the 'ignore' action applies to all warnings.
warnings.filterwarnings('ignore')
def main():
    """Run major-vote sampling on the brain-tumor validation split.

    Steps, in order:
      1. Parse command-line arguments, then overwrite them with the entries of
         the ``args.json`` file located beside the checkpoint.
      2. Set up the distributed backend and configure the logger to write into
         ``<checkpoint dir>/<checkpoint stem>_major_vote_<n_gen>``.
      3. Build the validation dataset (sharded by MPI rank), create the model
         and diffusion objects, load the checkpoint weights on CPU, move the
         model to the device and put it in eval mode.
      4. Seed the RNGs (1234 unless the arguments carry a ``seed``) and call
         ``sampling_major_vote_func`` with a ``major_vote`` sub-directory as
         its output location.
      5. Wait on a distributed barrier before logging completion.
    """
    args = create_argparser().parse_args()

    # Values stored in args.json replace same-named command-line values.
    checkpoint_dir = Path(args.model_path).parent
    saved_args = json.loads((checkpoint_dir / 'args.json').read_text())
    vars(args).update(saved_args)
    logger.info(vars(args))

    dist_util.setup_dist()

    n_gen = args.n_gen
    checkpoint_stem = Path(args.model_path).stem
    run_dir = checkpoint_dir / f"{checkpoint_stem}_major_vote_{n_gen}"
    timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')
    logger.configure(dir=str(run_dir), log_suffix=f"val_{timestamp}")

    # Each MPI rank receives its own shard index out of the world size.
    comm = MPI.COMM_WORLD
    val_dataset = BrainTumor1Dataset(
        mode='val',
        image_size=args.image_size,
        shard=comm.Get_rank(),
        num_shards=comm.Get_size(),
        soft_label_gt=args.soft_label_training,
        consensus_gt=args.consensus_training,
    )

    logger.log("creating model and diffusion...")
    # Must be set before the argument dict is extracted for model creation.
    args.condition_input_channel = 4
    model_kwargs = args_to_dict(args, model_and_diffusion_defaults().keys())
    model, diffusion = create_model_and_diffusion(**model_kwargs)

    weights = dist_util.load_state_dict(args.model_path, map_location="cpu")
    model.load_state_dict(weights)
    model.to(dist_util.dev())
    model.eval()

    # Fall back to a fixed seed when the arguments do not provide one.
    seed_arg = vars(args).get("seed")
    seed = 1234 if seed_arg is None else int(seed_arg)
    set_random_seed(seed, deterministic=True)

    logger.log("sampling major vote val")
    vote_dir = run_dir / "major_vote"
    vote_dir.mkdir(exist_ok=True)

    # The text after the last underscore of the checkpoint stem is the step.
    step = int(checkpoint_stem.rsplit("_", 1)[-1])
    sampling_major_vote_func(
        diffusion,
        model,
        str(vote_dir),
        val_dataset,
        logger,
        args.clip_denoised,
        step=step,
        number_of_generated_instances=n_gen,
    )

    dist.barrier()
    logger.log("sampling complete")
def create_argparser():
    """Build the argument parser used by this script.

    The script-specific defaults below are merged with the mapping returned by
    ``model_and_diffusion_defaults()``; when both define the same key, the
    model/diffusion value wins. The merged mapping is handed to
    ``add_dict_to_argparser`` to populate a fresh ``argparse.ArgumentParser``.

    Returns:
        The populated ``argparse.ArgumentParser``.
    """
    script_defaults = {
        "clip_denoised": True,
        "num_samples": 10000,
        "batch_size": 16,
        "use_ddim": False,
        "model_path": "/media/media1/shmuelsh/TomrCode/logs/2023-01-10-22-16-18-756357_brain_tumor_single_annotator_256_10_5e-05_1_100_0.0_0/ema_val_0.8067150_0.9999_035000.pt",
        "n_gen": 25,
    }
    # Later entries override earlier ones, matching dict.update semantics.
    merged_defaults = {**script_defaults, **model_and_diffusion_defaults()}

    arg_parser = argparse.ArgumentParser()
    add_dict_to_argparser(arg_parser, merged_defaults)
    return arg_parser
# Run the sampling job only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    main()