-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathhept_to_hls4ml.py
62 lines (47 loc) · 1.67 KB
/
hept_to_hls4ml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from pathlib import Path
import torch
from transformer_simple import Transformer
from preprocessor import Preprocessor
import hls4ml
def data(num_elements, in_dim, coords_dim, num_batches, num_regions, num_heads, n_hashes, block_size):
    """Build a deterministic synthetic input and push it through the preprocessor.

    Draws random features, coordinates and batch assignments (fixed seed, so
    every run produces the same tensors) and returns whatever
    ``Preprocessor.prepare_input`` yields for them.
    """
    torch.manual_seed(42)  # fixed seed keeps the toy input reproducible
    features = torch.rand(num_elements, in_dim)
    positions = torch.randn(num_elements, coords_dim)
    batch_assignment = torch.randint(0, num_batches, (num_elements,))
    prep = Preprocessor(num_regions, num_heads, n_hashes, block_size, device="cpu")
    return prep.prepare_input(features, positions, batch_assignment)
def main():
    """Run a tiny Transformer on synthetic data, then FX-trace it.

    Builds a small model with toy dimensions, runs one forward pass (printed
    for inspection), symbolically traces the model with ``torch.fx`` and
    writes the traced graph to ``tracings/simple_transformer_traced_graph.txt``
    next to this script.
    """
    # Toy problem sizes — deliberately tiny so the trace is quick to inspect.
    num_elements = 8
    in_dim = 6
    coords_dim = 2
    num_batches = 2
    num_regions = 2
    h_dim = 4
    num_heads = 2
    out_dim = 8
    block_size = 4
    n_hashes = 2
    num_w_per_dist = 2
    mlp_out_hdim = 8

    x, _, coords, unpad_seq, _ = data(num_elements, in_dim, coords_dim, num_batches, num_regions, num_heads, n_hashes, block_size)
    kwargs = {
        "h_dim": h_dim,
        "num_heads": num_heads,
        "out_dim": out_dim,
        "block_size": block_size,
        "n_hashes": n_hashes,
        "num_w_per_dist": num_w_per_dist,
        "mlp_out_hdim": mlp_out_hdim,
    }
    model = Transformer(in_dim=in_dim, coords_dim=coords_dim, **kwargs)
    model.eval()  # inference mode (disables dropout / batch-norm updates)
    pytorch_prediction = model(x, coords, unpad_seq)
    print(pytorch_prediction)

    # Symbolically trace the model so its graph can be inspected / exported.
    symbolically_traced_model = torch.fx.symbolic_trace(model)
    print(symbolically_traced_model.graph)

    filepath = Path(__file__).parent / "tracings" / "simple_transformer_traced_graph.txt"
    # Fix: make sure the output directory exists — without this, opening the
    # file for writing raises FileNotFoundError on a fresh checkout.
    filepath.parent.mkdir(parents=True, exist_ok=True)
    filepath.write_text(str(symbolically_traced_model.graph))
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()