-
Notifications
You must be signed in to change notification settings - Fork 64
/
sentencetransformermodel.py
1303 lines (1152 loc) · 53.5 KB
/
sentencetransformermodel.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# SPDX-License-Identifier: Apache-2.0
# The OpenSearch Contributors require contributions made to
# this file be licensed under the Apache-2.0 license or a
# compatible open source license.
# Any modifications Copyright OpenSearch Contributors. See
# GitHub history for details.
import json
import os
import pickle
import platform
import random
import re
import shutil
import subprocess
import time
from pathlib import Path
from typing import List
from zipfile import ZipFile
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from accelerate import Accelerator, notebook_launcher
from mdutils.fileutils import MarkDownFile
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import Normalize, Pooling, Transformer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import TrainingArguments, get_linear_schedule_with_warmup
from transformers.convert_graph_to_onnx import convert
from opensearch_py_ml.ml_commons.ml_common_utils import (
_generate_model_content_hash_value,
)
from .base_models import BaseUploadModel
class SentenceTransformerModel(BaseUploadModel):
"""
Class for training, exporting and configuring the SentenceTransformers model.
"""
DEFAULT_MODEL_ID = "sentence-transformers/msmarco-distilbert-base-tas-b"
SYNTHETIC_QUERY_FOLDER = "synthetic_queries"
def __init__(
    self,
    model_id: str = DEFAULT_MODEL_ID,
    folder_path: str = None,
    overwrite: bool = False,
) -> None:
    """
    Create a sentence transformer model helper. The model id selects the pretrained model to
    download from hugging-face and doubles as the default base name of generated model files;
    folder_path is where the other methods of this class write their output files.

    :param model_id: Optional, the huggingface model id of the sentence transformer model,
        default model id: 'sentence-transformers/msmarco-distilbert-base-tas-b'
    :type model_id: string
    :param folder_path: Optional, the folder used for output files such as queries, the pre-trained model,
        the custom trained model and configuration files. If None, the folder
        "sentence_transformer_model_files" under the current working directory is used
    :type folder_path: string
    :param overwrite: Optional, allow reusing a folder that already exists at folder_path. Default is
        False, in which case an already existing folder makes the constructor raise
    :type overwrite: bool
    :raises Exception: if the resolved folder already exists and overwrite is False
    :return: no return value expected
    :rtype: None
    """
    super().__init__(model_id, folder_path, overwrite)

    # Resolve the working folder: an explicit folder_path wins over the cwd-based fallback.
    fallback_folder = os.path.join(os.getcwd(), "sentence_transformer_model_files")
    self.folder_path = fallback_folder if folder_path is None else folder_path

    # Refuse to reuse an existing folder unless the caller opted in explicitly.
    if os.path.exists(self.folder_path) and not overwrite:
        print(
            "To prevent overwriting, please enter a different folder path or delete the folder or enable "
            "overwrite = True "
        )
        raise Exception(
            str("The default folder path already exists at : " + self.folder_path)
        )

    self.model_id = model_id
    # Filled in by save_as_pt / save_as_onnx once the respective zip file has been written.
    self.torch_script_zip_file_path = None
    self.onnx_zip_file_path = None
def train(
    self,
    read_path: str,
    overwrite: bool = False,
    output_model_name: str = None,
    zip_file_name: str = None,
    compute_environment: str = None,
    num_machines: int = 1,
    num_gpu: int = 0,
    learning_rate: float = 2e-5,
    num_epochs: int = 10,
    batch_size: int = 32,
    verbose: bool = False,
    percentile: float = 95,
) -> None:
    """
    Read the synthetic queries and use them to fine tune/train (and save) a sentence transformer model,
    then zip the trained model together with its tokenizer file.

    Parameters
    ----------
    :param read_path:
        required, path to the zipped file that contains generated queries.
        The zipped file should contain pickled files, each a list of dictionaries with the keys 'query',
        'passage' and optionally 'probability'.
        For example: [{'query': q1, 'probability': p1, 'passage': pa1}, ...].
        'probability' is not required for training purpose
    :type read_path: string
    :param overwrite:
        optional, synthetic_queries/ folder inside the model folder is used to store unzipped query files.
        Default is False: if the folder is not empty an exception recommends to either clean up the
        folder or enable overwriting
    :type overwrite: bool
    :param output_model_name:
        the name of the trained custom model. If None, default as model_id + '.pt'
    :type output_model_name: string
    :param zip_file_name:
        Optional, file name for zip file. if None, default as model_id + '.zip'
    :type zip_file_name: string
    :param compute_environment:
        optional, compute environment type to run model, if None, default using `LOCAL_MACHINE`
    :type compute_environment: string
    :param num_machines:
        optional, number of machine to run model, default is 1
    :type num_machines: int
    :param num_gpu:
        optional, number of gpus to run model, default is 0. If number of gpus > 1, the HuggingFace
        accelerate code path is used for training
    :type num_gpu: int
    :param learning_rate:
        optional, learning rate to train model, default is 2e-5
    :type learning_rate: float
    :param num_epochs:
        optional, number of epochs to train model, default is 10
    :type num_epochs: int
    :param batch_size:
        optional, batch size for training, default is 32
    :type batch_size: int
    :param verbose:
        optional, use plotting to plot the training progress. Default as false
    :type verbose: bool
    :param percentile:
        we find the max length of {percentile}% of the documents. Default is 95%
        Since this length is measured in terms of words and not tokens we multiply it by 1.4 to approximate the fact
        that 1 word in the english vocabulary roughly translates to 1.3 to 1.5 tokens
    :type percentile: float

    Returns
    -------
    :return: no return value expected
    :rtype: None
    """
    query_df = self.read_queries(read_path, overwrite)
    train_examples = self.load_training_data(query_df)

    # Positional arguments of train_model, shared by every way it gets invoked below.
    train_model_args = (
        train_examples,
        self.model_id,
        output_model_name,
        learning_rate,
        num_epochs,
        batch_size,
        verbose,
        num_gpu,
        percentile,
    )

    if num_gpu > 1:
        self.set_up_accelerate_config(
            compute_environment=compute_environment,
            num_machines=num_machines,
            num_processes=num_gpu,
            verbose=verbose,
        )

        if self.__is_notebook():
            # MPS needs to be only enabled for MACOS: https://pytorch.org/docs/master/notes/mps.html
            if platform.system() == "Darwin":
                if not torch.backends.mps.is_available():
                    if not torch.backends.mps.is_built():
                        print(
                            "MPS not available because the current PyTorch install was not "
                            "built with MPS enabled."
                        )
                    else:
                        print(
                            "MPS not available because the current MacOS version is not 12.3+ "
                            "and/or you do not have an MPS-enabled device on this machine."
                        )
                    exit(1)  # Exiting the script as the script will break anyway
            notebook_launcher(
                self.train_model,
                args=train_model_args,
                num_processes=num_gpu,
            )
        else:
            # Outside of a notebook the training runs in the current process. This used to be written as
            # subprocess.run(["accelerate", "launch", self.train_model(...)]): the training call was
            # evaluated first as an argument, and subprocess.run then always failed with a TypeError
            # because the returned traced model is not a valid command line argument. The surrounding
            # `except TypeError` therefore also hid genuine TypeErrors raised during training and went
            # on to zip a model that was never written. Calling train_model directly keeps the
            # successful behaviour and lets real training errors reach the caller.
            self.train_model(*train_model_args)
    else:  # Do not use accelerate when num_gpu is 1 or 0
        self.train_model(*train_model_args)

    self.zip_model(self.folder_path, output_model_name, zip_file_name)
    return None
# public step by step functions:
def read_queries(self, read_path: str, overwrite: bool = False) -> pd.DataFrame:
    """
    Read the queries generated from the Synthetic Query Generator (SQG) model, unzip the files into the
    synthetic_queries/ folder inside the model folder, and return their content as a dataframe

    :param read_path:
        required, path to the zipped file that contains generated queries. Every file inside the zip is
        expected to be a pickled list of dictionaries; only dictionaries that have both a 'query' and a
        'passage' key are used, 'probability' is optional
    :type read_path: string
    :param overwrite:
        optional, synthetic_queries/ folder is used to store the unzipped query files.
        Default is False: if the folder is not empty an exception recommends to either clean up the
        folder or enable overwriting
    :type overwrite: bool
    :raises Exception: if the unzip folder is not empty and overwrite is False, or if no file is found
        after extracting the zip file
    :return: The shuffled dataframe of queries with the columns 'prob', 'query' and 'passages',
        de-duplicated on 'query'.
    :rtype: panda dataframe
    """
    # assign a local folder 'synthetic_queries/' to store the unzip file,
    # check if the folder contains sub-folders and files, remove and clean up the folder before unzip.
    unzip_path = os.path.join(self.folder_path, self.SYNTHETIC_QUERY_FOLDER)

    if os.path.exists(unzip_path):
        if len(os.listdir(unzip_path)) > 0:
            if overwrite:
                for entry in os.listdir(unzip_path):
                    sub_path = os.path.join(unzip_path, entry)
                    if os.path.isfile(sub_path):
                        os.remove(sub_path)
                    else:
                        # best effort: a failure is reported but does not abort the run
                        try:
                            shutil.rmtree(sub_path)
                        except OSError as err:
                            print(
                                "Failed to delete files, please delete all files in "
                                + str(unzip_path)
                                + " "
                                + str(err)
                            )
            else:
                raise Exception(
                    "'synthetic_queries' folder is not empty, please clean up folder, or enable overwrite = "
                    + "True. Try again. Please check "
                    + unzip_path
                )

    with ZipFile(read_path, "r") as zip_ref:
        zip_ref.extractall(unzip_path)

    # collect the paths of all synthetic query files found below the unzip folder
    file_list = []
    for root, _dirnames, filenames in os.walk(unzip_path):
        for filename in filenames:
            file_list.append(os.path.join(root, filename))

    # check empty zip file
    if len(file_list) == 0:
        raise Exception(
            "Zipped file is empty. Please provide a zip file with synthetic queries."
        )

    # NOTE(security): pickle.load executes arbitrary code embedded in the file. Only pass zip files
    # that come from a trusted source to this method.
    loaded_batches = []
    for file_path in file_list:
        try:
            with open(file_path, "rb") as f:
                print("Reading synthetic query file: " + file_path + "\n")
                loaded_batches.append(pickle.load(f))
        except IOError:
            print("Failed to open synthetic query file: " + file_path + "\n")

    # reading the files to get the probability, queries and passages.
    # Iterate over what was actually loaded: a file that could not be opened above contributes no
    # entry, so indexing the loaded data by the number of discovered files would run out of range.
    prob = []
    query = []
    passages = []
    for batch in loaded_batches:
        for record in batch:
            if "query" in record.keys() and "passage" in record.keys():
                query.append(record["query"])
                passages.append(record["passage"])
                if "probability" in record.keys():
                    prob.append(record["probability"])  # language modeling score
                else:
                    # "-1" will serve as a label saying that the probability does not exist.
                    prob.append("-1")

    df = pd.DataFrame(
        list(zip(prob, query, passages)), columns=["prob", "query", "passages"]
    )

    # dropping duplicate queries
    df = df.drop_duplicates(subset=["query"])

    # for removing the "QRY:" token if they exist in passages
    df["passages"] = df.apply(lambda x: self.__qryrem(x), axis=1)

    # shuffle data within dataframe
    df = df.sample(frac=1)
    return df
def load_training_data(self, query_df) -> List[List[str]]:
    """
    Turn the query dataframe into the list of training pairs consumed by the training loop

    :param query_df:
        required, dataframe providing a 'query' and a 'passages' column
    :type query_df: pd.DataFrame
    :return: the list of train examples, one [query, passage] pair per dataframe row.
    :rtype: list
    """
    print("Loading training examples... \n")

    query_column = list(query_df["query"])
    passage_column = list(query_df["passages"])
    row_count = len(query_df)

    # pair the two columns up row by row while showing a progress bar
    return [
        [query_column[row], passage_column[row]]
        for row in tqdm(range(row_count), total=row_count)
    ]
def train_model(
    self,
    train_examples: List[List[str]],
    model_id: str = None,
    output_model_name: str = None,
    learning_rate: float = 2e-5,
    num_epochs: int = 10,
    batch_size: int = 32,
    verbose: bool = False,
    num_gpu: int = 0,
    percentile: float = 95,
):
    """
    Takes in training data and a sentence transformer url to train a custom semantic search model

    :param train_examples:
        required, input for the sentence transformer model training, a list of [query, passage] pairs.
        Without accelerate (num_gpu <= 1) the list is shuffled in place before every epoch and only
        full batches are used, so at least batch_size examples are required
    :type train_examples: List of strings in another list
    :param model_id:
        [optional] the url to download sentence transformer model, if None,
        default as 'sentence-transformers/msmarco-distilbert-base-tas-b'
    :type model_id: string
    :param output_model_name:
        optional,the name of the trained custom model. If None, default as model_id + '.pt'
    :type output_model_name: string
    :param learning_rate:
        optional, learning rate to train model, default is 2e-5
    :type learning_rate: float
    :param num_epochs:
        optional, number of epochs to train model, default is 10
    :type num_epochs: int
    :param batch_size:
        optional, batch size for training, default is 32
    :type batch_size: int
    :param verbose:
        optional, use plotting to plot the training progress and printing more logs. Default as false
    :type verbose: bool
    :param num_gpu:
        Number of gpu will be used for training. Default 0
    :type num_gpu: int
    :param percentile:
        To save memory while training we truncate all passages beyond a certain max_length.
        Most middle-sized transformers have a max length limit of 512 tokens. However, certain corpora can
        have shorter documents. We find the word length of all documents, sort them in increasing order and
        take the max length of {percentile}% of the documents. Default is 95%
    :type percentile: float
    :raises ValueError: if train_examples is empty, or if it holds fewer than batch_size examples while
        training without accelerate (no training step could run in that case)
    :return: the torch script format trained model; it is also saved as a .pt file in the model folder.
    :rtype: .pt file
    """
    # Fail fast on input that cannot produce a single training step. Without this check the run only
    # breaks after an untrained model has already been saved, because the batch that is reused as
    # example input for tracing never gets created.
    if not train_examples:
        raise ValueError(
            "train_examples is empty, at least one [query, passage] pair is required for training."
        )
    if num_gpu <= 1 and len(train_examples) < batch_size:
        raise ValueError(
            "Not enough training examples: got "
            + str(len(train_examples))
            + " but batch_size is "
            + str(batch_size)
            + ". Provide more examples or use a smaller batch_size."
        )

    if model_id is None:
        model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
    if output_model_name is None:
        output_model_name = str(self.model_id.split("/")[-1] + ".pt")

    # Load a model from HuggingFace
    model = SentenceTransformer(model_id)

    # Calculate the length of passages (in whitespace separated words)
    corp_len = [len(example[1].split(" ")) for example in train_examples]

    # In the following, we find the max length of 95% of the documents (when sorted by increasing word length).
    # Since this length is measured in terms of words and not tokens we multiply it by 1.4 to approximate the
    # fact that 1 word in the english vocabulary roughly translates to 1.3 to 1.5 tokens. For instance the word
    # butterfly will be split by most tokenizers into butter and fly, but the word sun will be probably kept as it
    # is. Note that this ratio will be higher if the corpus is jargon heavy and/or domain specific.
    corp_max_tok_len = int(np.percentile(corp_len, percentile) * 1.4)
    model.tokenizer.model_max_length = corp_max_tok_len
    model.max_seq_length = corp_max_tok_len

    # use accelerator for training
    if num_gpu > 1:
        # the default_args are required for initializing train_dataloader,
        # but output_dir is not used in this function.
        default_args = {
            "output_dir": "~/",
            "evaluation_strategy": "steps",
            "num_train_epochs": num_epochs,
            "log_level": "error",
            "report_to": "none",
        }

        training_args = TrainingArguments(
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=1,
            fp16=False,
            **default_args,
        )

        train_dataloader = DataLoader(
            train_examples,
            shuffle=True,
            batch_size=training_args.per_device_train_batch_size,  # Trains with this batch size.
        )

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=min(10000, 0.05 * len(train_dataloader)),
            num_training_steps=num_epochs * len(train_dataloader),
        )

        accelerator = Accelerator()
        model, optimizer, train_dataloader, scheduler = accelerator.prepare(
            model, optimizer, train_dataloader, scheduler
        )
        print("Device using for training: ", accelerator.device)
        model.to(accelerator.device)

        init_time = time.time()
        total_loss = []

        if accelerator.process_index == 0:
            print("Start training with accelerator...\n")
            print(f"The number of training epochs per process are {num_epochs}\n")
            print(
                f"The total number of steps per training epoch are {len(train_dataloader)}\n"
            )

        # The following training loop trains the sentence transformer model using the standard contrastive loss
        # with in-batch negatives. The particular contrastive loss that we use is the Multi-class N-pair Loss (
        # eq. 6, Sohn, NeurIPS 2016). In addition, we symmetrize the loss with respect to queries and passages (
        # as also used in OpenAI's CLIP model). The performance improves with the number of in-batch negatives
        # but larger batch sizes can lead to out of memory issues, so please use the batch-size judiciously.
        for epoch in range(num_epochs):
            print(
                "Training epoch "
                + str(epoch)
                + " in process "
                + str(accelerator.process_index)
                + "...\n"
            )
            for step, batch in tqdm(
                enumerate(train_dataloader), total=len(train_dataloader)
            ):
                batch_q = batch[0]
                batch_p = batch[1]

                model = accelerator.unwrap_model(model)

                out_q = accelerator.unwrap_model(model).tokenize(batch_q)
                for key in out_q.keys():
                    out_q[key] = out_q[key].to(model.device)

                out_p = accelerator.unwrap_model(model).tokenize(batch_p)
                for key in out_p.keys():
                    out_p[key] = out_p[key].to(model.device)

                Y = model(out_q)["sentence_embedding"]
                X = model(out_p)["sentence_embedding"]

                # scaled similarity of every passage with every query of the batch; the diagonal
                # holds the matching (positive) pairs
                XY = torch.exp(torch.matmul(X, Y.T) / (X.shape[1]) ** 0.5)
                num = torch.diagonal(XY)

                den0 = torch.sum(XY, dim=0)
                den1 = torch.sum(XY, dim=1)

                batch_loss = -torch.sum(torch.log(num / den0)) - torch.sum(
                    torch.log(num / den1)
                )

                accelerator.backward(batch_loss)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                total_loss.append(batch_loss.item())

                if verbose is True and not step % 500 and step != 0:
                    plt.plot(total_loss[::100])
                    plt.show()

        accelerator.wait_for_everyone()

    # When number of GPU is less than 2, we don't need to accelerate
    else:
        # identify if running on gpu or cpu
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)

        # only full batches are trained on, a trailing partial batch is skipped in every epoch
        total_steps = (len(train_examples) // batch_size) * num_epochs
        steps_size = len(train_examples) // batch_size

        optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

        # NOTE(review): this branch warms up for at least 10000 steps (max), while the accelerate
        # branch above caps the warm-up at 10000 steps (min). With fewer than 10000 total steps the
        # learning rate never reaches its configured value here - confirm which one is intended.
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=max(10000, total_steps * 0.05),
            num_training_steps=total_steps,
        )

        loss = []
        init_time = time.time()

        print("Start training without accelerator...\n")
        print(f"The number of training epoch are {num_epochs}\n")
        print(
            f"The total number of steps per training epoch are {len(train_examples) // batch_size}\n"
        )

        # The following training loop trains the sentence transformer model using the standard contrastive loss
        # with in-batch negatives. The particular contrastive loss that we use is the Multi-class N-pair Loss (
        # eq. 6, Sohn, NeurIPS 2016). In addition, we symmetrize the loss with respect to queries and passages (
        # as also used in OpenAI's CLIP model). The performance improves with the number of in-batch negatives
        # but larger batch sizes can lead to out of memory issues, so please use the batch-size judiciously.
        for epoch in range(num_epochs):
            random.shuffle(train_examples)
            print("Training epoch " + str(epoch) + "...\n")
            for j in tqdm(range(steps_size), total=steps_size):
                batch_q = []
                batch_p = []
                for example in train_examples[
                    j * batch_size : (j + 1) * batch_size
                ]:
                    batch_q.append(example[0])
                    batch_p.append(example[1])

                out_q = model.tokenize(batch_q)
                for key in out_q.keys():
                    out_q[key] = out_q[key].to(device)
                Y = model(out_q)["sentence_embedding"]

                out = model.tokenize(batch_p)
                for key in out.keys():
                    out[key] = out[key].to(device)
                X = model(out)["sentence_embedding"]

                # scaled similarity of every passage with every query of the batch; the diagonal
                # holds the matching (positive) pairs
                XY = torch.exp(torch.matmul(X, Y.T) / (X.shape[1]) ** 0.5)
                num = torch.diagonal(XY)

                den0 = torch.sum(XY, dim=0)
                den1 = torch.sum(XY, dim=1)

                train_loss = -torch.sum(torch.log(num / den0)) - torch.sum(
                    torch.log(num / den1)
                )
                loss.append(train_loss.item())
                train_loss.backward()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                if verbose is True and not j % 500 and j != 0:
                    plt.plot(loss[::100])
                    plt.show()

    # saving the pytorch model and the tokenizers.json file is saving at this step
    model.save(self.folder_path)
    device = "cpu"
    cpu_model = model.to(device)
    print(f"Total training time: {time.time() - init_time}\n")

    # the tokenized queries of the last training batch serve as example input for tracing
    for key in out_q.keys():
        out_q[key] = out_q[key].to(device)

    traced_cpu = torch.jit.trace(
        cpu_model,
        (
            {
                "input_ids": out_q["input_ids"],
                "attention_mask": out_q["attention_mask"],
            }
        ),
        strict=False,
    )

    if verbose:
        print("Preparing model to save...\n")

    torch.jit.save(traced_cpu, os.path.join(self.folder_path, output_model_name))
    print("Model saved to path: " + self.folder_path + "\n")
    return traced_cpu
def zip_model(
    self,
    model_path: str = None,
    model_name: str = None,
    zip_file_name: str = None,
    add_apache_license: bool = False,
    verbose: bool = False,
) -> None:
    """
    Bundle the model file and the tokenizer.json file into one zip file that can be uploaded to the
    OpenSearch cluster. The zip file and the tokenizer.json file are always located in the model folder
    of this object.

    :param model_path:
        Optional, folder that holds the model file. If None, the model folder of this object is used
    :type model_path: string
    :param model_name:
        the file name of the trained custom model. If None, default as the last part of model_id
        followed by '.pt'
    :type model_name: string
    :param zip_file_name:
        Optional, file name for zip file. If None, default as the last part of model_id followed by '.zip'
    :type zip_file_name: string
    :param add_apache_license:
        Optional, whether to add a Apache-2.0 license file to model zip file
    :type add_apache_license: bool
    :param verbose:
        optional, use to print more logs. Default as false
    :type verbose: bool
    :raises Exception: if the tokenizer.json file or the model file cannot be found
    :return: no return value expected
    :rtype: None
    """
    if model_name is None:
        model_name = str(self.model_id.split("/")[-1] + ".pt")

    # the model file lives either in the given folder or in the model folder of this object
    model_dir = self.folder_path if model_path is None else model_path
    model_path = os.path.join(model_dir, str(model_name))
    if verbose:
        print("model path is: ", model_path)

    if zip_file_name is None:
        zip_file_name = str(self.model_id.split("/")[-1] + ".zip")
    zip_file_path = os.path.join(self.folder_path, zip_file_name)

    zip_stem = zip_file_name.partition(".")[0]
    if verbose:
        print("Zip file name without extension: ", zip_stem)

    tokenizer_json_path = os.path.join(self.folder_path, "tokenizer.json")
    print("tokenizer_json_path: ", tokenizer_json_path)

    # both inputs have to be present before the archive is created
    if not os.path.exists(tokenizer_json_path):
        raise Exception(
            f"Cannot find tokenizer.json file, please check at {tokenizer_json_path}"
        )
    if not os.path.exists(model_path):
        raise Exception(
            f"Cannot find model in the model path , please check at {model_path}"
        )

    with ZipFile(str(zip_file_path), "w") as archive:
        archive.write(model_path, arcname=str(model_name))
        archive.write(tokenizer_json_path, arcname="tokenizer.json")

    if add_apache_license:
        super()._add_apache_license_to_model_zip_file(zip_file_path)

    print(f"zip file is saved to {zip_file_path}\n")
def save_as_pt(
    self,
    sentences: [str],
    model_id="sentence-transformers/msmarco-distilbert-base-tas-b",
    model_name: str = None,
    save_json_folder_path: str = None,
    model_output_path: str = None,
    zip_file_name: str = None,
    add_apache_license: bool = False,
) -> str:
    """
    Download a sentence transformer model from huggingface, trace it into torch script format and zip
    the model file together with its tokenizer.json file, ready to be uploaded to the OpenSearch cluster

    :param sentences:
        Required, example input used for tracing the model, for example sentences = ['today is sunny']
    :type sentences: List of string [str]
    :param model_id:
        sentence transformer model id to download model from sentence transformers.
        default model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
    :type model_id: string
    :param model_name:
        Optional, file name of the model file, e.g, "sample_model.pt". If None, the last part of
        model_id followed by ".pt" is used
    :type model_name: string
    :param save_json_folder_path:
        Optional, folder the model and its tokenizer.json are saved to. If None, the model folder of
        this object is used
    :type save_json_folder_path: string
    :param model_output_path:
        Optional, folder the zip file is written to. If None, the model folder of this object is used
    :type model_output_path: string
    :param zip_file_name:
        Optional, file name for zip file. e.g, "sample_model.zip". If None, the last part of model_id
        followed by ".zip" is used
    :type zip_file_name: string
    :param add_apache_license:
        Optional, whether to add a Apache-2.0 license file to model zip file
    :type add_apache_license: bool
    :return: model zip file path. The file path where the zip file is being saved
    :rtype: string
    """
    model = SentenceTransformer(model_id)

    # fill in every name and location the caller left open
    if model_name is None:
        model_name = str(model_id.rsplit("/", 1)[-1] + ".pt")
    model_path = os.path.join(self.folder_path, model_name)

    if save_json_folder_path is None:
        save_json_folder_path = self.folder_path
    if model_output_path is None:
        model_output_path = self.folder_path
    if zip_file_name is None:
        zip_file_name = str(model_id.rsplit("/", 1)[-1] + ".zip")
    zip_file_path = os.path.join(model_output_path, zip_file_name)

    # handle when model_max_length is unproperly defined in model's tokenizer (e.g. "intfloat/e5-small-v2")
    # (See PR #219 and https://github.com/huggingface/transformers/issues/14561 for more context)
    max_seq_length = model.get_max_seq_length()
    if model.tokenizer.model_max_length > max_seq_length:
        model.tokenizer.model_max_length = max_seq_length
        print(
            f"The model_max_length is not properly defined in tokenizer_config.json. Setting it to be {model.tokenizer.model_max_length}"
        )

    # saving the model also writes tokenizer.json into save_json_folder_path
    model.save(save_json_folder_path)
    super()._fill_null_truncation_field(
        save_json_folder_path, model.tokenizer.model_max_length
    )

    # tracing has to happen on cpu: move the model and the tokenized example input there
    device = torch.device("cpu")
    cpu_model = model.to(device)
    features = cpu_model.tokenizer(
        sentences, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    trace_input = {
        name: features[name] for name in ("input_ids", "attention_mask")
    }

    compiled_model = torch.jit.trace(cpu_model, (trace_input), strict=False)
    torch.jit.save(compiled_model, model_path)
    print("model file is saved to ", model_path)

    # zip model file along with tokenizer.json (and license file) as output
    tokenizer_json_path = os.path.join(save_json_folder_path, "tokenizer.json")
    with ZipFile(str(zip_file_path), "w") as archive:
        archive.write(model_path, arcname=str(model_name))
        archive.write(tokenizer_json_path, arcname="tokenizer.json")
    if add_apache_license:
        super()._add_apache_license_to_model_zip_file(zip_file_path)

    self.torch_script_zip_file_path = zip_file_path
    print("zip file is saved to ", zip_file_path, "\n")
    return zip_file_path
def save_as_onnx(
    self,
    model_id="sentence-transformers/msmarco-distilbert-base-tas-b",
    model_name: str = None,
    save_json_folder_path: str = None,
    model_output_path: str = None,
    zip_file_name: str = None,
    add_apache_license: bool = False,
) -> str:
    """
    Download a sentence transformer model from huggingface, convert it to onnx format and zip the
    model file together with its tokenizer.json file, ready to be uploaded to the OpenSearch cluster

    :param model_id:
        sentence transformer model id to download model from sentence transformers.
        default model_id = "sentence-transformers/msmarco-distilbert-base-tas-b"
    :type model_id: string
    :param model_name:
        Optional, file name of the model file, e.g, "sample_model.onnx". If None, the last part of
        model_id followed by ".onnx" is used
    :type model_name: string
    :param save_json_folder_path:
        Optional, folder the model and its tokenizer.json are saved to. If None, the model folder of
        this object is used
    :type save_json_folder_path: string
    :param model_output_path:
        Optional, folder the zip file is written to. If None, the model folder of this object is used
    :type model_output_path: string
    :param zip_file_name:
        Optional, file name for zip file. e.g, "sample_model.zip". If None, the last part of model_id
        followed by ".zip" is used
    :type zip_file_name: string
    :param add_apache_license:
        Optional, whether to add a Apache-2.0 license file to model zip file
    :type add_apache_license: bool
    :return: model zip file path. The file path where the zip file is being saved
    :rtype: string
    """
    model = SentenceTransformer(model_id)

    # fill in every name and location the caller left open;
    # the onnx file goes into the "onnx" sub-folder of the model folder
    if model_name is None:
        model_name = str(model_id.rsplit("/", 1)[-1] + ".onnx")
    model_path = os.path.join(self.folder_path, "onnx", model_name)

    if save_json_folder_path is None:
        save_json_folder_path = self.folder_path
    if model_output_path is None:
        model_output_path = self.folder_path
    if zip_file_name is None:
        zip_file_name = str(model_id.rsplit("/", 1)[-1] + ".zip")
    zip_file_path = os.path.join(model_output_path, zip_file_name)

    # handle when model_max_length is unproperly defined in model's tokenizer (e.g. "intfloat/e5-small-v2")
    # (See PR #219 and https://github.com/huggingface/transformers/issues/14561 for more context)
    max_seq_length = model.get_max_seq_length()
    if model.tokenizer.model_max_length > max_seq_length:
        model.tokenizer.model_max_length = max_seq_length
        print(
            f"The model_max_length is not properly defined in tokenizer_config.json. Setting it to be {model.tokenizer.model_max_length}"
        )

    # saving the model also writes tokenizer.json into save_json_folder_path
    model.save(save_json_folder_path)
    super()._fill_null_truncation_field(
        save_json_folder_path, model.tokenizer.model_max_length
    )

    # the conversion works from the model id, not from the model object loaded above
    convert(
        framework="pt",
        model=model_id,
        output=Path(model_path),
        opset=15,
    )
    print("model file is saved to ", model_path)

    # zip model file along with tokenizer.json (and license file) as output
    tokenizer_json_path = os.path.join(save_json_folder_path, "tokenizer.json")
    with ZipFile(str(zip_file_path), "w") as archive:
        archive.write(model_path, arcname=str(model_name))
        archive.write(tokenizer_json_path, arcname="tokenizer.json")
    if add_apache_license:
        super()._add_apache_license_to_model_zip_file(zip_file_path)

    self.onnx_zip_file_path = zip_file_path
    print("zip file is saved to ", zip_file_path, "\n")
    return zip_file_path
def set_up_accelerate_config(
self,
compute_environment: str = None,
num_machines: int = 1,
num_processes: int = None,
verbose: bool = False,
) -> None:
"""
Get default config setting based on the number of GPU on the machine
if users require other configs, users can run !acclerate config for more options
:param compute_environment:
optional, compute environment type to run model, if None, default using 'LOCAL_MACHINE'
:type compute_environment: string
:param num_machines:
optional, number of machine to run model , if None, default using 1
:type num_machines: int
:param num_processes:
optional, number of processes to run model, if None, default to check how many gpus are available and
use all. if no gpu is available, use cpu
:type num_processes: int
:param verbose:
optional, use printing more logs. Default as false
:type verbose: bool
:return: no return value expected
:rtype: None
"""
if compute_environment is None or compute_environment == 0:
compute_environment = "LOCAL_MACHINE"
else:
subprocess.run("accelerate config")
return
hf_cache_home = os.path.expanduser(
os.getenv(
"HF_HOME",
os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "huggingface"),
)
)
cache_dir = os.path.join(hf_cache_home, "accelerate")
file_path = os.path.join(cache_dir, "default_config.yaml")
use_cpu = False
if verbose:
print("generated config file: at " + file_path + "\n")
os.makedirs(os.path.dirname(file_path), exist_ok=True)
if num_processes is None:
if torch.cuda.is_available():
num_processes = torch.cuda.device_count()
else:
num_processes = 1
use_cpu = True
model_config_content = [
{
"compute_environment": compute_environment,
"deepspeed_config": {
"gradient_accumulation_steps": 1,
"offload_optimizer_device": "none",
"offload_param_device": "none",
"zero3_init_flag": False,
"zero_stage": 2,
},
"distributed_type": "DEEPSPEED",
"downcast_bf16": "no",
"fsdp_config": {},
"machine_rank": 0,
"main_process_ip": None,
"main_process_port": None,