"""Reported (reproduced) results of of TabNet model in the original paper
https://arxiv.org/abs/1908.07442.
Forest Cover Type: 96.99 (96.53)
KDD Census Income: 95.5 (95.41)
"""
import argparse
import os.path as osp

import torch
import torch.nn.functional as F
from torch.optim.lr_scheduler import ExponentialLR
from tqdm import tqdm

from torch_frame.data import DataLoader
from torch_frame.datasets import ForestCoverType, KDDCensusIncome
from torch_frame.nn import TabNet
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default="ForestCoverType",
                    choices=["ForestCoverType", "KDDCensusIncome"])
parser.add_argument('--channels', type=int, default=128)
parser.add_argument('--gamma', type=float, default=1.2)
parser.add_argument('--num_layers', type=int, default=6)
parser.add_argument('--batch_size', type=int, default=4096)
parser.add_argument('--lr', type=float, default=0.005)
parser.add_argument('--epochs', type=int, default=50)
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--compile', action='store_true')
args = parser.parse_args()
torch.manual_seed(args.seed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Prepare datasets
path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data',
                args.dataset)
if args.dataset == "ForestCoverType":
    dataset = ForestCoverType(root=path)
elif args.dataset == "KDDCensusIncome":
    dataset = KDDCensusIncome(root=path)
else:
    raise ValueError(f"Unsupported dataset called {args.dataset}")
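# Materialization converts the raw data into a TensorFrame and computes the
# per-column statistics (col_stats) that the model constructor consumes below.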
dataset.materialize()
assert dataset.task_type.is_classification
dataset = dataset.shuffle()
# Split ratio is 80% / 10% / 10% (the original TabNet paper does not clearly
# specify its split).
train_dataset, val_dataset, test_dataset = (dataset[:0.8], dataset[0.8:0.9],
                                            dataset[0.9:])
# Set up data loaders
train_tensor_frame = train_dataset.tensor_frame
val_tensor_frame = val_dataset.tensor_frame
test_tensor_frame = test_dataset.tensor_frame
train_loader = DataLoader(train_tensor_frame, batch_size=args.batch_size,
                          shuffle=True)
val_loader = DataLoader(val_tensor_frame, batch_size=args.batch_size)
test_loader = DataLoader(test_tensor_frame, batch_size=args.batch_size)
# Set up model and optimizer
model = TabNet(
    out_channels=dataset.num_classes,
    num_layers=args.num_layers,
    split_attn_channels=args.channels,
    split_feat_channels=args.channels,
    gamma=args.gamma,
    col_stats=dataset.col_stats,
    col_names_dict=train_tensor_frame.col_names_dict,
).to(device)
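# Optionally speed up the model with torch.compile (PyTorch >= 2.0);
# dynamic=True allows dynamic shapes, e.g. the smaller final batch of an epoch.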
model = torch.compile(model, dynamic=True) if args.compile else model
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
lr_scheduler = ExponentialLR(optimizer, gamma=0.95)
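# Run one training epoch and return the mean cross-entropy loss per sample.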
def train(epoch: int) -> float:
    model.train()
    loss_accum = total_count = 0

    for tf in tqdm(train_loader, desc=f'Epoch: {epoch}'):
        tf = tf.to(device)
        pred = model(tf)
        loss = F.cross_entropy(pred, tf.y)
        optimizer.zero_grad()
        loss.backward()
        loss_accum += float(loss) * len(tf.y)
        total_count += len(tf.y)
        optimizer.step()
    return loss_accum / total_count
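# Compute classification accuracy over a loader; @torch.no_grad() disables
# gradient tracking during evaluation.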
@torch.no_grad()
def test(loader: DataLoader) -> float:
    model.eval()
    accum = total_count = 0

    for tf in loader:
        tf = tf.to(device)
        pred = model(tf)
        pred_class = pred.argmax(dim=-1)
        accum += float((tf.y == pred_class).sum())
        total_count += len(tf.y)
    return accum / total_count
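# Model selection: report the test accuracy from the epoch with the best
# validation accuracy.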
best_val_acc = 0
best_test_acc = 0
for epoch in range(1, args.epochs + 1):
    train_loss = train(epoch)
    train_acc = test(train_loader)
    val_acc = test(val_loader)
    test_acc = test(test_loader)
    if best_val_acc < val_acc:
        best_val_acc = val_acc
        best_test_acc = test_acc
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')
    lr_scheduler.step()

print(f'Best Val Acc: {best_val_acc:.4f}, Best Test Acc: {best_test_acc:.4f}')