from __future__ import annotations

import torch
import torch.nn as nn

from sig_networks.feature_concatenation import FeatureConcatenation
from sig_networks.ffn_baseline import FeedforwardNeuralNetModel
from sig_networks.swnu import SWNU


class SWNUNetwork(nn.Module):
    """
    Stacked Deep Signature Neural Network for classification.
    """

    def __init__(
        self,
        input_channels: int,
        num_features: int,
        embedding_dim: int,
        log_signature: bool,
        sig_depth: int,
        pooling: str,
        hidden_dim_swnu: list[int] | int,
        hidden_dim_ffn: list[int] | int,
        output_dim: int,
        dropout_rate: float,
        reverse_path: bool = False,
        output_channels: int | None = None,
        augmentation_type: str = "Conv1d",
        hidden_dim_aug: list[int] | int | None = None,
        BiLSTM: bool = False,
        comb_method: str = "concatenation",
    ):
"""
SWNU network for classification.
Parameters
----------
input_channels : int
Dimension of the embeddings in the path that will be passed in.
num_features : int
Number of time features to add to FFN input. If none, set to zero.
embedding_dim : int
Dimension of embedding to add to FFN input. If none, set to zero.
log_signature : bool
Whether or not to use the log signature or standard signature.
sig_depth : int
The depth to truncate the path signature at.
pooling: str
Pooling operation to apply in SWNU to obtain history representation.
Options are:
- "signature": apply signature on the LSTM units at the end
to obtain the final history representation
- "lstm": take the final (non-padded) LSTM unit as the final
history representation
hidden_dim_swnu : list[int] | int
Dimensions of the hidden layers in the SNWU blocks.
hidden_dim_ffn : list[int] | int
Dimension of the hidden layers in the FFN.
output_dim : int
Dimension of the output layer in the FFN.
dropout_rate : float
Dropout rate in the FFN.
reverse_path : bool, optional
Whether or not to reverse the path before passing it through the
signature layers, by default False.
output_channels : int | None, optional
Requested dimension of the embeddings after convolution layer.
If None, will be set to the last item in `hidden_dim`, by default None.
augmentation_type : str, optional
Method of augmenting the path, by default "Conv1d".
Options are:
- "Conv1d": passes path through 1D convolution layer.
- "signatory": passes path through `Augment` layer from `signatory` package.
hidden_dim_aug : list[int] | int | None
Dimensions of the hidden layers in the augmentation layer.
Passed into `Augment` class from `signatory` package if
`augmentation_type='signatory'`, by default None.
BiLSTM : bool, optional
Whether or not a birectional LSTM is used,
by default False (unidirectional LSTM is used in this case).
comb_method : str, optional
Determines how to combine the path signature and embeddings,
by default "gated_addition".
Options are:
- concatenation: concatenation of path signature and embedding vector
- gated_addition: element-wise addition of path signature
and embedding vector
- gated_concatenation: concatenation of linearly gated path signature
and embedding vector
- scaled_concatenation: concatenation of single value scaled path
signature and embedding vector
"""
        super().__init__()

        if pooling not in ["signature", "lstm"]:
            raise ValueError(
                "`pooling` must be 'signature' or 'lstm'. " f"Got {pooling} instead."
            )
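
        # SWNU: augments the path and passes it through signature and LSTM
        # layers (pooled according to `pooling`) to obtain a history representation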
        self.swnu = SWNU(
            input_channels=input_channels,
            output_channels=output_channels,
            log_signature=log_signature,
            sig_depth=sig_depth,
            hidden_dim=hidden_dim_swnu,
            pooling=pooling,
            reverse_path=reverse_path,
            BiLSTM=BiLSTM,
            augmentation_type=augmentation_type,
            hidden_dim_aug=hidden_dim_aug,
        )

        # determining how to concatenate features to the SWNU features
        self.embedding_dim = embedding_dim
        self.num_features = num_features
        self.comb_method = comb_method
        self.feature_concat = FeatureConcatenation(
            input_dim=self.swnu.swlstm.output_dim,
            num_features=self.num_features,
            embedding_dim=self.embedding_dim,
            comb_method=self.comb_method,
        )

        # FFN for classification
        # make sure hidden_dim_ffn is a list of integers
        if isinstance(hidden_dim_ffn, int):
            hidden_dim_ffn = [hidden_dim_ffn]
        self.hidden_dim_ffn = hidden_dim_ffn

        self.ffn = FeedforwardNeuralNetModel(
            input_dim=self.feature_concat.output_dim,
            hidden_dim=self.hidden_dim_ffn,
            output_dim=output_dim,
            dropout_rate=dropout_rate,
        )

    def forward(self, path: torch.Tensor, features: torch.Tensor | None = None):
        # path has dimensions [batch, length of signal, channels]
        # features has dimensions [batch, num_features + embedding_dim]
        # use SWNU to obtain feature set
        out = self.swnu(path)
        # combine with features provided
        out = self.feature_concat(out, features)
        # FFN
        out = self.ffn(out.float())
        return out
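

if __name__ == "__main__":
    # Minimal usage sketch, illustrative only: the tensor shapes follow the
    # comments in `forward`, while all hyperparameter values below are
    # arbitrary assumptions rather than recommended settings.
    batch_size, seq_len, channels = 4, 20, 32
    num_features, embedding_dim = 3, 64

    model = SWNUNetwork(
        input_channels=channels,
        num_features=num_features,
        embedding_dim=embedding_dim,
        log_signature=False,
        sig_depth=3,
        pooling="signature",
        hidden_dim_swnu=[12],
        hidden_dim_ffn=[32, 32],
        output_dim=2,
        dropout_rate=0.1,
    )

    # path: [batch, length of signal, channels]
    path = torch.rand(batch_size, seq_len, channels)
    # features: [batch, num_features + embedding_dim]
    features = torch.rand(batch_size, num_features + embedding_dim)

    logits = model(path, features)
    print(logits.shape)  # expected: [batch_size, output_dim]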