-
Notifications
You must be signed in to change notification settings - Fork 16
/
classification.py
113 lines (102 loc) · 3.9 KB
/
classification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import argparse
import pathlib
import time
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.seed import seed_everything
from datasets.solidletters import SolidLetters
from uvnet.models import Classification
# Command-line interface: one positional mode selector ("train" or "test")
# followed by the script-specific options declared in the table below.
parser = argparse.ArgumentParser("UV-Net solid model classification")

# (name-or-flag, add_argument keyword arguments) pairs. They are registered in
# the order listed, so the generated --help output keeps this ordering.
_SCRIPT_ARGUMENTS = (
    (
        "traintest",
        dict(choices=("train", "test"), help="Whether to train or test"),
    ),
    (
        "--dataset",
        dict(choices=("solidletters",), help="Dataset to train on"),
    ),
    (
        "--dataset_path",
        dict(type=str, help="Path to dataset"),
    ),
    (
        "--batch_size",
        dict(type=int, default=64, help="Batch size"),
    ),
    (
        "--num_workers",
        dict(
            type=int,
            default=0,
            help="Number of workers for the dataloader. NOTE: set this to 0 on Windows, any other value leads to poor performance",
        ),
    ),
    (
        "--checkpoint",
        dict(
            type=str,
            default=None,
            help="Checkpoint file to load weights from for testing",
        ),
    ),
    (
        "--experiment_name",
        dict(
            type=str,
            default="classification",
            help="Experiment name (used to create folder inside ./results/ to save logs and checkpoints)",
        ),
    ),
)
for _name_or_flag, _options in _SCRIPT_ARGUMENTS:
    parser.add_argument(_name_or_flag, **_options)

# Let the Trainer extend the parser with its own options, then parse the
# process command line.
parser = Trainer.add_argparse_args(parser)
args = parser.parse_args()
# Root folder for this experiment: <directory of this script>/results/<experiment_name>.
results_path = pathlib.Path(__file__).parent / "results" / args.experiment_name
# exist_ok=True already makes mkdir a no-op when the folder is present, so a
# separate exists() check beforehand is unnecessary (and race-prone).
results_path.mkdir(parents=True, exist_ok=True)

# Define a path to save the results based on date and time. E.g.
# results/args.experiment_name/0430/123103
# Both components are formatted from ONE clock reading: two independent
# strftime() calls could straddle a second (or midnight) boundary and yield a
# mismatched pair such as yesterday's date with today's time.
run_started = time.localtime()
month_day = time.strftime("%m%d", run_started)
hour_min_second = time.strftime("%H%M%S", run_started)

# Checkpoints go to results/<experiment_name>/<month_day>/<hour_min_second>/:
# the monitored metric is "val_loss", the tracked checkpoint is named "best",
# and save_last=True additionally requests the most recent checkpoint.
checkpoint_callback = ModelCheckpoint(
    monitor="val_loss",
    dirpath=str(results_path.joinpath(month_day, hour_min_second)),
    filename="best",
    save_last=True,
)

# Build the Trainer from the parsed command line, attaching the checkpoint
# callback and a TensorBoard logger that uses the same date/time folder layout.
trainer = Trainer.from_argparse_args(
    args,
    callbacks=[checkpoint_callback],
    logger=TensorBoardLogger(
        str(results_path), name=month_day, version=hour_min_second,
    ),
)
# Resolve the --dataset choice to its dataset class through a lookup table.
# Any value without an entry (including an omitted --dataset, which has no
# default and therefore parses as None) is rejected.
_DATASET_CLASSES = {"solidletters": SolidLetters}
Dataset = _DATASET_CLASSES.get(args.dataset)
if Dataset is None:
    raise ValueError("Unsupported dataset")
if args.traintest == "train":
    # Train/val
    seed_everything(workers=True)
    # Tell the user where logs and the best checkpoint of this run will land.
    print(
        f"""
-----------------------------------------------------------------------------------
UV-Net Classification
-----------------------------------------------------------------------------------
Logs written to results/{args.experiment_name}/{month_day}/{hour_min_second}
To monitor the logs, run:
tensorboard --logdir results/{args.experiment_name}/{month_day}/{hour_min_second}
The trained model with the best validation loss will be written to:
results/{args.experiment_name}/{month_day}/{hour_min_second}/best.ckpt
-----------------------------------------------------------------------------------
"""
    )
    # The number of output classes is taken from the selected dataset class.
    model = Classification(num_classes=Dataset.num_classes())
    train_data = Dataset(root_dir=args.dataset_path, split="train")
    val_data = Dataset(root_dir=args.dataset_path, split="val")
    # Only the training loader shuffles; validation keeps a fixed order.
    train_loader = train_data.get_dataloader(
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers
    )
    val_loader = val_data.get_dataloader(
        batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )
    trainer.fit(model, train_loader, val_loader)
else:
    # Test
    # Validate the command line explicitly rather than with `assert`: asserts
    # are stripped when Python runs with -O, which would let a missing
    # checkpoint through to load_from_checkpoint(None). parser.error() prints
    # the usage text plus this message and exits with status 2.
    if args.checkpoint is None:
        parser.error("Expected the --checkpoint argument to be provided")
    test_data = Dataset(root_dir=args.dataset_path, split="test")
    test_loader = test_data.get_dataloader(
        batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers
    )
    model = Classification.load_from_checkpoint(args.checkpoint)
    # NOTE(review): the `test_dataloaders` keyword appears to belong to older
    # PyTorch Lightning releases -- confirm against the pinned version.
    results = trainer.test(model=model, test_dataloaders=[test_loader], verbose=False)
    # The first result entry's "test_acc_epoch" value is scaled by 100 to
    # report a percentage.
    print(
        f"Classification accuracy (%) on test set: {results[0]['test_acc_epoch'] * 100.0}"
    )