# model.py: FastKAN layers and a two-branch embedding model.
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional


class SplineLinear(nn.Linear):
    """Bias-free linear layer whose weights are drawn from a truncated normal."""

    def __init__(self, in_features: int, out_features: int, init_scale: float = 0.1, **kw) -> None:
        # init_scale must be set before super().__init__(), which calls reset_parameters().
        self.init_scale = init_scale
        super().__init__(in_features, out_features, bias=False, **kw)

    def reset_parameters(self) -> None:
        nn.init.trunc_normal_(self.weight, mean=0, std=self.init_scale)

class RadialBasisFunction(nn.Module):
    """Expands each input scalar into Gaussian responses centered on a fixed grid."""

    def __init__(
        self,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        denominator: Optional[float] = None,  # larger denominators lead to smoother bases
    ) -> None:
        super().__init__()
        grid = torch.linspace(grid_min, grid_max, num_grids)
        # Frozen Parameter: moves with .to()/.cuda() but is never trained.
        self.grid = nn.Parameter(grid, requires_grad=False)
        # Default bandwidth is the grid spacing.
        self.denominator = denominator or (grid_max - grid_min) / (num_grids - 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.exp(-((x[..., None] - self.grid) / self.denominator) ** 2)

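
# A minimal shape sketch (illustrative, not part of the original module):
# each input scalar is expanded into num_grids Gaussian responses, so
#   rbf = RadialBasisFunction(num_grids=8)
#   rbf(torch.randn(4, 16)).shape  # -> torch.Size([4, 16, 8])
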
class FastKANLayer(nn.Module):
    """One FastKAN layer: LayerNorm -> RBF expansion -> linear spline mixing,
    optionally plus a residual base branch (activation followed by a linear map)."""

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        base_activation=F.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        super().__init__()
        self.layernorm = nn.LayerNorm(input_dim)
        self.rbf = RadialBasisFunction(grid_min, grid_max, num_grids)
        self.spline_linear = SplineLinear(input_dim * num_grids, output_dim, spline_weight_init_scale)
        self.use_base_update = use_base_update
        if use_base_update:
            self.base_activation = base_activation
            self.base_linear = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor, time_benchmark: bool = False) -> torch.Tensor:
        # time_benchmark skips the LayerNorm so timing measures only the RBF + linear cost.
        if not time_benchmark:
            spline_basis = self.rbf(self.layernorm(x))
        else:
            spline_basis = self.rbf(x)
        # Flatten the trailing (input_dim, num_grids) expansion into one feature axis.
        ret = self.spline_linear(spline_basis.view(*spline_basis.shape[:-2], -1))
        if self.use_base_update:
            base = self.base_linear(self.base_activation(x))
            ret = ret + base
        return ret

class FastKAN(nn.Module):
    """Stack of FastKANLayers; layers_hidden lists the width of every layer,
    e.g. [64, 32, 10] builds two layers: 64 -> 32 and 32 -> 10."""

    def __init__(
        self,
        layers_hidden: List[int],
        grid_min: float = -2.,
        grid_max: float = 2.,
        num_grids: int = 8,
        use_base_update: bool = True,
        base_activation=F.silu,
        spline_weight_init_scale: float = 0.1,
    ) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            FastKANLayer(
                in_dim, out_dim,
                grid_min=grid_min,
                grid_max=grid_max,
                num_grids=num_grids,
                use_base_update=use_base_update,
                base_activation=base_activation,
                spline_weight_init_scale=spline_weight_init_scale,
            ) for in_dim, out_dim in zip(layers_hidden[:-1], layers_hidden[1:])
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

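
# Usage sketch (illustrative widths, not from the original repo):
#   net = FastKAN([64, 32, 10])
#   net(torch.randn(8, 64)).shape  # -> torch.Size([8, 10])
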
class my_model(nn.Module):
    """Two parallel single-layer FastKAN encoders producing L2-normalized
    embeddings; during training the second view is perturbed with Gaussian noise."""

    def __init__(self, dims: List[int]) -> None:
        super().__init__()
        self.kan_1 = FastKAN([dims[0], dims[1]])
        self.kan_2 = FastKAN([dims[0], dims[1]])

    def forward(self, x: torch.Tensor, is_train: bool = True, sigma: float = 0.01):
        out1 = self.kan_1(x)
        out2 = self.kan_2(x)
        out1 = F.normalize(out1, dim=1, p=2)
        out2 = F.normalize(out2, dim=1, p=2)
        if is_train:
            # torch.normal with a tensor std already samples on out2's device,
            # so the original hard-coded .cuda() (which broke CPU runs) is dropped.
            out2 = out2 + torch.normal(0., torch.ones_like(out2) * sigma)
        return out1, out2
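

if __name__ == "__main__":
    # Smoke test (illustrative; the dims values are assumptions, not from the
    # original repo): project 100-dim inputs into two 64-dim embeddings.
    model = my_model(dims=[100, 64])
    x = torch.randn(8, 100)
    out1, out2 = model(x, is_train=True)
    print(out1.shape, out2.shape)  # torch.Size([8, 64]) torch.Size([8, 64])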