save_torchscript.py (forked from karpathy/llama2.c)
#!/usr/bin/env python
"""Saves the model as a TorchScript.
Usage examples:
./save_torchscript.py
./save_torchscript.py --dim=300
./save_torchscript.py --gzip_output=True --zero_params=True
The resulting file can be loaded in C++ code and then used for training or
inference with:
#include <torch/script.h>
torch::jit::Module module = torch::jit::load("model.pt")
Note that the serialized model includes the initial parameters and with the default
ModelArgs the file is 59M and gzips down to 55M. If you want to serialize/distribute
the model parameters separately you can zero out the parameters before saving it and
it will gzip down to 780K.
"""
import gzip
import os
import shutil
from inspect import signature

import torch

from model import ModelArgs, Transformer

# Model args config
dim = 288
n_layers = 6
n_heads = 6
n_kv_heads = n_heads
multiple_of = 32
max_seq_len = 256
dropout = 0.0
vocab_size = 32000
norm_eps = 1e-5
# Save config
model_path = "model.pt"
zero_params = False
gzip_output = False
# Allow config overrides
exec(open("configurator.py").read())
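# configurator.py (from the llama2.c repo) is expected to read `--key=value`
# command-line flags and override the module-level globals above; for example
# (a usage sketch, assuming the standard configurator.py behavior):
#
#   ./save_torchscript.py --dim=512 --n_layers=8 --gzip_output=True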


def main() -> None:
    model_args = {k: globals()[k] for k in signature(ModelArgs).parameters}
    model = Transformer(ModelArgs(**model_args))

    # If requested, zero the parameters before saving the model. This is useful
    # in conjunction with gzip_output.
    if zero_params:
        for p in model.parameters():
            p.detach().zero_()
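    # When the parameters are distributed separately, they can be restored after
    # loading the zeroed module (a minimal sketch; "params.pt" is a hypothetical
    # state_dict saved elsewhere with torch.save(model.state_dict(), "params.pt");
    # load_state_dict is inherited from nn.Module by the scripted module):
    #
    #   m = torch.jit.load(model_path)
    #   m.load_state_dict(torch.load("params.pt"))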

    torch.jit.save(torch.jit.script(model), model_path)
    if gzip_output:
        with open(model_path, "rb") as f_in:
            with gzip.open(f"{model_path}.gz", "wb") as f_out:
                shutil.copyfileobj(f_in, f_out)
        os.unlink(model_path)
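        # The gzipped file must be decompressed before it can be loaded again,
        # since torch.jit.load expects the plain TorchScript archive (a minimal
        # sketch, reversing the compression above):
        #
        #   with gzip.open(f"{model_path}.gz", "rb") as f_in:
        #       with open(model_path, "wb") as f_out:
        #           shutil.copyfileobj(f_in, f_out)
        #   m = torch.jit.load(model_path)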


if __name__ == "__main__":
    main()