-
Notifications
You must be signed in to change notification settings - Fork 466
/
Copy pathtemplate.py
3709 lines (3140 loc) · 160 KB
/
template.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) Alibaba, Inc. and its affiliates.
import inspect
import os
import re
from contextlib import contextmanager
from copy import deepcopy
from datetime import datetime
from functools import partial, wraps
from types import MethodType
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union
import json
import torch
import torch.nn.functional as F
import transformers
from packaging import version
from peft import PeftModel
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedTokenizerBase, StoppingCriteria
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import strtobool
from swift.llm.agent.utils import calculate_loss_scale, get_tools_prompt
from swift.torchacc_utils import pad_and_split_batch
from swift.utils import get_dist_setting, get_logger, upper_bound, use_torchacc
from .vision_utils import (load_audio_qwen, load_batch, load_image, load_video_cogvlm2, load_video_internvl,
load_video_llava, load_video_minicpmv_mplug_owl3, load_video_qwen2, rescale_image,
transform_image)
# Module-level logger obtained from the project's `get_logger` helper.
logger = get_logger()
# Default system prompt text.
DEFAULT_SYSTEM = 'You are a helpful assistant.'
# Earlier chat rounds: each entry is a two-item [query, response] list or (query, response) tuple.
History = List[Union[Tuple[str, str], List[str]]]
# Template fragments: literal text, a list of token ids, or a list of tokenizer
# attribute names (e.g. ['eos_token_id']) that `Template._preprocess_prompt` resolves to ids.
Prompt = List[Union[str, List[int], List[str]]]
# Stop words share the `Prompt` representation.
StopWords = Prompt
# One element of an assembled context list: text or already-tokenized ids.
Context = Union[str, List[int]]
# Per-template info, indexed by `template_type` (`Template.add_default_tags` reads the
# 'infer_media_type' entry). Empty here; presumably filled by a registration helper
# that is not shown in this part of the file.
TEMPLATE_MAPPING: Dict[str, Dict[str, Any]] = {}
class TemplateType:
    """Namespace of template type identifiers.

    Every public class attribute maps a Python-friendly name to the string
    identifier of a template; `get_template_name_list` enumerates those strings.
    """
    # text-generation
    default_generation = 'default-generation'
    chatglm_generation = 'chatglm-generation'
    qwen_vl_generation = 'qwen-vl-generation'
    qwen_audio_generation = 'qwen-audio-generation'
    # chat
    default = 'default'
    qwen = 'qwen'
    qwen2_5 = 'qwen2_5'
    qwen_vl = 'qwen-vl'
    qwen_audio = 'qwen-audio'
    qwen2_audio = 'qwen2-audio'
    qwen2_audio_generation = 'qwen2-audio-generation'
    qwen2_vl = 'qwen2-vl'
    qwen2_vl_generation = 'qwen2-vl-generation'
    modelscope_agent = 'modelscope-agent'
    baichuan = 'baichuan'
    chatglm2 = 'chatglm2'
    chatglm3 = 'chatglm3'
    chatglm4 = 'chatglm4'
    codegeex4 = 'codegeex4'
    llama = 'llama'  # llama2
    llama3 = 'llama3'
    llama3_1_omni = 'llama3_1-omni'
    llama3_2 = 'llama3_2'
    llama3_2_vision = 'llama3_2-vision'
    llama3_2_vision_generation = 'llama3_2-vision-generation'
    reflection = 'reflection'
    longwriter_llama3 = 'longwriter-llama3'
    # llava-hf
    llava1_5 = 'llava1_5'
    llava_mistral = 'llava-mistral'
    llava_vicuna = 'llava-vicuna'
    llava_yi = 'llava-yi'
    llama3_llava_next_hf = 'llama-llava-next-hf'
    llava_next_llama3 = 'llava-next-llama3'
    llava_qwen_hf = 'llama-qwen-hf'
    llava_onevision_qwen = 'llava-onevision-qwen'
    # llava-video
    llava_next_video = 'llava-next-video'
    llava_next_video_yi = 'llava-next-video-yi'
    # lmms-lab:llava
    llama3_llava_next = 'llama3-llava-next'
    llava_qwen = 'llava-qwen'
    # xtuner:llava
    llava_llama_instruct = 'llava-llama-instruct'
    idefics3 = 'idefics3'
    mistral_nemo = 'mistral-nemo'
    pixtral = 'pixtral'
    openbuddy = 'openbuddy'
    openbuddy2 = 'openbuddy2'
    internlm = 'internlm'
    internlm2 = 'internlm2'
    internlm_xcomposer2 = 'internlm-xcomposer2'
    internlm_xcomposer2_4khd = 'internlm-xcomposer2-4khd'
    internlm_xcomposer2_5 = 'internlm-xcomposer2_5'
    internvl = 'internvl'
    internvl2 = 'internvl2'
    internvl_phi3 = 'internvl-phi3'
    internvl2_phi3 = 'internvl2-phi3'
    florence = 'florence'
    yi_coder = 'yi-coder'
    yi_vl = 'yi-vl'
    yuan = 'yuan'
    xverse = 'xverse'
    ziya = 'ziya'
    skywork = 'skywork'
    bluelm = 'bluelm'
    zephyr = 'zephyr'
    sus = 'sus'
    deepseek = 'deepseek'
    numina_math = 'numina-math'
    deepseek_coder = 'deepseek-coder'
    deepseek_vl = 'deepseek-vl'
    deepseek2 = 'deepseek2'
    deepseek2_5 = 'deepseek2_5'
    codefuse_codellama = 'codefuse-codellama'
    codefuse = 'codefuse'
    cogvlm = 'cogvlm'
    cogvlm2_video = 'cogvlm2-video'
    glm4v = 'glm4v'
    cogagent_chat = 'cogagent-chat'
    cogagent_instruct = 'cogagent-instruct'
    orion = 'orion'
    minicpm = 'minicpm'
    minicpm_v = 'minicpm-v'
    minicpm_v_v2_5 = 'minicpm-v-v2_5'
    minicpm_v_v2_6 = 'minicpm-v-v2_6'
    gemma = 'gemma'
    paligemma = 'paligemma'
    mplug_owl2 = 'mplug-owl2'
    mplug_owl3 = 'mplug_owl3'
    wizardlm2_awq = 'wizardlm2-awq'
    wizardlm2 = 'wizardlm2'
    atom = 'atom'
    phi3 = 'phi3'
    phi3_vl = 'phi3-vl'
    telechat = 'telechat'
    telechat_v2 = 'telechat-v2'
    dbrx = 'dbrx'
    mengzi = 'mengzi'
    c4ai = 'c4ai'
    chatml = 'chatml'
    got_ocr2 = 'got_ocr2'
    # compatibility. (Deprecated)
    default_generation_bos = 'default-generation-bos'
    yi = 'yi'
    yi1_5 = 'yi1_5'

    @classmethod
    def get_template_name_list(cls) -> List[str]:
        """Return the identifier strings declared directly on `cls`, in declaration order.

        Dunder entries and this method itself are excluded. Only `cls.__dict__`
        is inspected, so attributes inherited from a base class are not listed.
        """
        return [
            value for name, value in cls.__dict__.items()
            if not (name.startswith('__') or name == 'get_template_name_list')
        ]
class StopWordsCriteria(StoppingCriteria):
    """Stopping criterion that fires once any stop word shows up in the generated text.

    The returned sentence includes the stop word. Only the first sequence of the
    batch (`input_ids[0]`) is inspected.

    Args:
        tokenizer: Tokenizer used to decode the generated ids.
        stop_words: Stop words, given either as strings or as lists of token ids.
        **tokenizer_kwargs: Extra keyword arguments forwarded to `tokenizer.decode`.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: StopWords, **tokenizer_kwargs) -> None:
        self.tokenizer = tokenizer
        self.stop_words = stop_words
        self.tokenizer_kwargs = tokenizer_kwargs
        # Index from which decoding starts; fixed on the first call.
        self.start_idx = -1

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> bool:
        """Return True when a stop word is found at the end of the first sequence."""
        if self.start_idx == -1:
            self.start_idx = len(input_ids[0]) - 1
        # [-20:]: Assuming the end tokens do not exceed 20 tokens,
        # to avoid input_ids being too long and affecting efficiency.
        recent_ids = input_ids[0, self.start_idx:][-20:]
        decoded = self.tokenizer.decode(recent_ids, **self.tokenizer_kwargs)
        for word in self.stop_words:
            if isinstance(word, str):
                matched = word in decoded
            else:  # list of token ids: compare against the tail of the whole sequence
                matched = len(word) > 0 and input_ids[0].tolist()[-len(word):] == word
            if matched:
                return True
        return False
def is_deepspeed_enabled():
    """Return the truth value of the `ACCELERATE_USE_DEEPSPEED` environment variable.

    The variable defaults to 'False' when unset and is parsed with `strtobool`.
    """
    flag = os.environ.get('ACCELERATE_USE_DEEPSPEED', 'False')
    return strtobool(flag)
class Template:
"""A template class for all supported models.
Args:
prefix: Prefix tokens before the first turn's prompt
prompt: A list of elements whose types are str and list of integers. The input query part of every turn.
chat_sep: The chat separators between every turn.
suffix: The end tokens after the chat finished.
default_system: A default system instruction.
system_prefix: The prefix if the `system` is not empty.
auto_add_bos: By default, the bos_token is not added. The auto_add_bos option will determine
whether to add it based on `tokenizer.encode('')`.
Examples:
<start>system\nYou are a helpful assistant!<end>\n<bos><start>Who are you?<end>\n<start>assistant:I am a robot<end>\n<start>Who are you?<end>\n<start>assistant:I am a robot<end> # noqa
--------------- -------------------------- --- ----- ------------ ----------------------- ----------- ---- -----
system_prefix system prefix prompt query prompt response chat_sep suffix
"""
special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>']
special_keys = ['images', 'videos', 'audios', 'objects']
grounding_type = 'norm_1000'
image_placeholder = ['<image>']
load_medias = True
compute_per_round_loss = True # for rlhf
output_prompt_answer = False # for encoder-decoder & kto
def __init__(self,
prefix: Prompt,
prompt: Prompt,
chat_sep: Optional[Prompt],
suffix: Prompt,
default_system: Optional[str] = None,
system_prefix: Optional[Prompt] = None,
auto_add_bos: bool = False,
tools_prompt: str = 'react_en',
tool_prompt: Optional[Prompt] = None,
padding_side: Literal['left', 'right'] = 'right') -> None:
# check
for x in [prefix, prompt, chat_sep, suffix, system_prefix]:
assert x is None or isinstance(x, list)
if default_system == '':
default_system = None
if self._has_system(prefix):
assert system_prefix is None, 'The prefix already contains {{SYSTEM}}.'
system_prefix = prefix
prefix = self._replace_system(prefix)
self.prefix = prefix
self.system_prefix = system_prefix
if self.system_prefix is None and not any(['{{SYSTEM}}' in context for context in prompt]):
assert default_system is None, 'The template does not support `system`.'
self.prompt = prompt
self.chat_sep = chat_sep
self.support_multi_round = self.chat_sep is not None
self.suffix = suffix
self.default_system = default_system
self.use_default_system = True
self.auto_add_bos = auto_add_bos
self._is_init = False
self.tools_prompt = tools_prompt
self.tool_prompt = tool_prompt if tool_prompt is not None else self.prompt # default as user
self._is_vllm = False
self._is_lmdeploy = False
self._is_training = False
self.padding_side = padding_side
@staticmethod
def _replace_system(prefix: Prompt) -> Prompt:
return [p.replace('{{SYSTEM}}', '') for p in prefix if '{{SYSTEM}}' in p]
@staticmethod
def _has_system(prefix: Prompt) -> bool:
return any(['{{SYSTEM}}' in p for p in prefix])
@staticmethod
def _preprocess_prompt(tokenizer: PreTrainedTokenizerBase, value: Optional[Prompt]) -> Optional[Prompt]:
"""Turn `eos_token_id` to token id
e.g. [['eos_token_id']] -> [[2]]
"""
if value is None:
return None
res_value = []
for v in value:
if isinstance(v, list):
res_v = []
for sub_v in v:
if isinstance(sub_v, str):
sub_v = getattr(tokenizer, sub_v)
res_v.append(sub_v)
v = res_v
res_value.append(v)
return res_value
def _init_template(self,
tokenizer: PreTrainedTokenizerBase,
default_system: Optional[str] = None,
max_length: Optional[int] = None,
truncation_strategy: Literal['delete', 'truncation_left'] = 'delete',
model: torch.nn.Module = None,
**kwargs) -> None:
assert self._is_init is False, 'The template has been initialized.'
self.is_multimodal = getattr(tokenizer, 'is_multimodal', None)
self._is_init = True
self.tokenizer = tokenizer
# if default_system is None. not change self.default_system
if default_system == '':
self.default_system = None
elif default_system is not None:
assert self.system_prefix is not None, (
f'The template does not support `system`, template_type: {getattr(self, "template_type", None)}')
self.default_system = default_system
self.max_length = max_length
self.truncation_strategy = truncation_strategy
self.model = model
self.ref_model = kwargs.get('ref_model', None)
self.use_loss_scale = kwargs.get('use_loss_scale', False)
self.response_loss_scale_map = kwargs.get('loss_scale_map', None)
self.query_loss_scale_map = None
if self.response_loss_scale_map is not None:
if 'query' in self.response_loss_scale_map and isinstance(self.response_loss_scale_map['query'], dict):
self.query_loss_scale_map = self.response_loss_scale_map['query']
if 'response' in self.response_loss_scale_map and isinstance(self.response_loss_scale_map['response'],
dict):
self.response_loss_scale_map = self.response_loss_scale_map['response']
self.sequence_parallel_size = kwargs.get('sequence_parallel_size', 1)
self.rescale_image = kwargs.get('rescale_image', -1)
for key in ['prefix', 'prompt', 'chat_sep', 'suffix', 'system_prefix']:
value = getattr(self, key)
value = self._preprocess_prompt(tokenizer, value)
setattr(self, key, value)
@contextmanager
def training_context(self):
if self.model is None:
self._is_training = True
yield
self._is_training = False
return
self._is_training = True
def _pre_forward_hook(module, args, kwargs):
from .utils import to_device
if '_data' in kwargs:
res_extra = []
data = kwargs.pop('_data')
for d in data:
res_extra.append(self._post_encode(module, d))
kwargs.update(to_device(self.data_collator(res_extra), module.device))
if 'inputs_embeds' in kwargs:
kwargs.pop('input_ids', None)
if isinstance(module, PeftModel):
parameters = inspect.signature(module.base_model.model.forward).parameters
else:
parameters = inspect.signature(module.forward).parameters
if 'position_ids' not in parameters:
kwargs.pop('position_ids', None)
return args, kwargs
parameters = inspect.signature(self.model.register_forward_pre_hook).parameters
handle, handle2 = None, None
deepspeed = None
if 'with_kwargs' in parameters:
handle = self.model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
if self.ref_model:
handle2 = self.ref_model.register_forward_pre_hook(_pre_forward_hook, with_kwargs=True)
if is_deepspeed_zero3_enabled():
import deepspeed
_old_initialize = deepspeed.initialize
@wraps(_old_initialize)
def _initialize(*args, **kwargs):
res = _old_initialize(*args, **kwargs)
self.model._forward_pre_hooks.move_to_end(handle.id)
if self.ref_model:
self.ref_model._forward_pre_hooks.move_to_end(handle2.id)
return res
deepspeed.initialize = _initialize
yield
self._is_training = False
if handle:
handle.remove()
if handle2:
handle2.remove()
if deepspeed:
deepspeed.initialize = _old_initialize
@contextmanager
def vllm_context(self):
self._is_vllm = True
yield
self._is_vllm = False
@contextmanager
def lmdeploy_context(self):
self._is_lmdeploy = True
yield
self._is_lmdeploy = False
def _post_encode(self, model, data: Any) -> Dict[str, Any]:
return {}
def check_example(self, example: Dict[str, Any]) -> None:
pass
def add_default_tags(self, example: Dict[str, Any]) -> None:
history: History = deepcopy(example.get('history') or [])
query: str = example.get('query') or ''
response: str = example.get('response') or ''
history.append([query, response])
for media_key, media_tag in [('videos', '<video>'), ('images', '<image>'), ('audios', '<audio>')]:
if example.get(media_key):
infer_media_type = TEMPLATE_MAPPING[self.template_type].get('infer_media_type')
if infer_media_type == 'round':
n_round = len(example[media_key])
assert n_round == len(history)
for i, h, m in zip(range(n_round), history, example[media_key]):
content = f'{h[0]}\n{h[1]}'
num_media_tags = len(re.findall(media_tag, content))
if m:
assert num_media_tags <= 1, (
'The model includes at most one media per round. However, '
f'this round contains {num_media_tags} media_tags. query: {h[0]}, response: {h[1]}')
if num_media_tags == 0:
h[0] = media_tag + h[0]
else:
assert num_media_tags == 0, f'Missing media. query: {h[0]}'
history[i][0] = h[0]
example[media_key] = [m for m in example[media_key] if m]
else:
num_media_tags = len(re.findall(media_tag, '\n'.join([f'{h[0]}\n{h[1]}' for h in history])))
example[media_key] = [m for m in example[media_key] if m]
num_media = len(example[media_key])
num_new_tags = num_media - num_media_tags
assert num_new_tags >= 0, f'Number of media: {num_media}, number of media_tags: {num_media_tags}'
history[0][0] = media_tag * num_new_tags + history[0][0]
example['query'] = history[-1][0]
if example.get('response') is not None:
example['response'] = history[-1][1]
example['history'] = history[:-1]
def replace_media_tags(self, example) -> None:
    """Rewrite inline `<img>…</img>`, `<audio>…</audio>` and `<video>…</video>` markup.

    Runs only when `self.is_multimodal` is True or None. For each media kind,
    `replace_img_tag` rewrites query, response and history and returns the media
    paths it extracted; those paths become the example's media list. `example`
    is updated in place.

    Raises:
        ValueError: If the example already lists media of a kind and inline
            markup of the same kind was also found.
    """
    if self.is_multimodal not in {True, None}:
        return
    media_specs = [
        ('images', '<image>', r'<img>(.+?)</img>'),
        ('audios', '<audio>', r'<audio>(.+?)</audio>'),
        ('videos', '<video>', r'<video>(.+?)</video>'),
    ]
    for k, tag, pattern in media_specs:
        query, response, history, medias_path = replace_img_tag(
            example.get('query'), example.get('response'),
            example.get('history') or [], tag, pattern)
        example['query'] = query
        example['response'] = response
        example['history'] = history
        if example.get(k) and medias_path:
            raise ValueError(f'Do not mix use the {pattern} tag and {tag} tag.')
        # `+` binds tighter than `or`: keep an existing non-empty list as is,
        # otherwise store a fresh list of the extracted paths.
        existing = example.get(k)
        example[k] = existing if existing else [] + medias_path
def _preprocess_media(self, example):
    """Normalise the media-related fields of `example` in place.

    Steps, in order: wrap scalar media fields into lists, convert inline media
    markup to tags, add default media tags, parse string-encoded `objects`,
    load/rescale/convert the images, and finally run `check_example`.
    """
    from .media import MediaTag
    from .client_utils import decode_base64
    # Format media_keys to list
    for media_key in MediaTag.media_keys.values():
        if example.get(media_key) and not isinstance(example[media_key], (tuple, list)):
            # change images field to list
            example[media_key] = [example[media_key]]
    self.replace_media_tags(example)
    # Add default tags to examples to note where to put the medias into the sequence
    self.add_default_tags(example)

    # Format objects(groundings/refs) to json
    if example.get('objects') and isinstance(example['objects'], str):
        # reload grounding from str
        example['objects'] = json.loads(example['objects'])
        objects = []
        for object in example['objects']:
            # Compatible with list format: [caption, bbox] becomes a dict that
            # refers to the first image and has no declared bbox type.
            if isinstance(object, list):
                object = {
                    'caption': object[0],
                    'bbox': object[1],
                    'bbox_type': None,
                    'image': 0,
                }
            objects.append(object)
        example['objects'] = objects

    # Load image into PIL format
    images = example.get('images') or []
    if images:
        # Images are loaded when bboxes need their size, when the template loads
        # medias itself, or when running under lmdeploy / vllm.
        if example.get('objects') or self.load_medias or self._is_lmdeploy or self._is_vllm:
            images = load_batch(images, load_image)  # base64/local_path -> PIL.Image
        if example.get('objects'):
            # Normalize grounding bboxes
            self.normalize_bbox(example['objects'], images, to_type=self.grounding_type)
        # Rescaling is skipped for 'real' grounding.
        # NOTE(review): presumably because pixel coordinates would no longer match — confirm.
        if self.load_medias and self.grounding_type != 'real':
            images = [rescale_image(img, self.rescale_image) for img in images]
        if not self.load_medias and not self._is_lmdeploy and not self._is_vllm:  # fix pt & qwen-vl
            images = decode_base64(images=images)['images']  # PIL.Image/base64 -> local_path
        example['images'] = images

    # Check the example that whether matching the very template's rules
    self.check_example(example)
def preprocess(self, example):
    """Return a preprocessed shallow copy of `example`, ready for encoding.

    Resolves the system text (explicit value, default, or none), appends the
    tools prompt when tools are given, fills in default `history_roles` and
    runs the media preprocessing.

    Raises:
        ValueError: If the template has not been initialised.
        AssertionError: If a system text or multi-round history is supplied to
            a template that does not support it.
    """
    # Duplicate example and create a new one to prepare in-place changes
    example = example.copy()
    template_type: Optional[str] = getattr(self, 'template_type', None)
    tools: Union[List[Any], str] = example.get('tools') or []

    # Template needs to be initialized
    if not self._is_init:
        raise ValueError(
            'Template is not initialized, please use the `get_template` function to obtain the template.')

    # Reset system (by default value and agent tools)
    system: Optional[str] = example.get('system', None)
    if system == '':
        # An explicit empty string switches the system text off.
        system = None
    elif system is not None:
        assert self.system_prefix is not None, (
            f'The template does not support `system`, template_type: {template_type}')
    elif self.use_default_system:
        system = self.default_system

    if tools:
        if isinstance(tools, str):
            tools = json.loads(tools)
        system = (system if system is not None else '') + get_tools_prompt(tools, self.tools_prompt)
    example['system'] = system

    # Check whether this template supports multi-round
    history: History = example.get('history') or []
    if len(history) > 0:
        assert self.support_multi_round, (
            f'The template does not support multi-round chat, template_type: {template_type}')

    # Set history_roles
    if example.get('history_roles') is None:
        example['history_roles'] = [['user', 'assistant'] for _ in history]

    self._preprocess_media(example)
    return example
def encode(self, example: Dict[str, Any], streaming: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Preprocess and encode one example.

    Under lmdeploy / vllm the base `Template._encode` is used instead of a
    subclass override. Outside training, a `_data` entry in the inputs is moved
    to the model's device and expanded through `_post_encode`.

    Returns:
        `(inputs, tokenizer_kwargs)`; when `streaming` is true only `inputs`
        is returned.
    """
    from .utils import to_device
    example = self.preprocess(example)
    if self._is_lmdeploy or self._is_vllm:
        assert self.is_multimodal is not None, 'Please use the get_model_tokenizer function.'
        encode_fn = MethodType(Template._encode, self)
    else:
        encode_fn = self._encode
    res = encode_fn(example)
    inputs = res[0]
    if not self._is_training and '_data' in inputs:
        extra = to_device(inputs.pop('_data'), self.model.device)
        inputs.update(self._post_encode(self.model, extra))
    return inputs if streaming else res
async def prepare_lmdeploy_inputs(self, inputs: Dict[str, Any]) -> None:
    """Expand image placeholders in `inputs['input_ids']` for lmdeploy, in place.

    Every `-100` placeholder is replaced by `images[i].shape[0]` copies of
    `IMAGE_DUMMY_TOKEN_INDEX`. The images are stored as `input_embeddings` and
    the `[start, end)` span of each expansion as `input_embedding_ranges`.
    Nothing but the removal of the `images` key happens when there are no images.
    """
    images = inputs.pop('images', None) or []
    if len(images) == 0:
        return
    from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX
    input_ids = inputs['input_ids']
    idx_list = _findall(input_ids, -100)
    assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'

    new_input_ids = []
    ranges = []
    previous = -1  # position of the previous placeholder
    for position, image in zip(idx_list, images):
        # Copy the ordinary tokens between two placeholders.
        new_input_ids += input_ids[previous + 1:position]
        begin = len(new_input_ids)
        new_input_ids += [IMAGE_DUMMY_TOKEN_INDEX] * image.shape[0]
        ranges.append([begin, len(new_input_ids)])
        previous = position
    new_input_ids += input_ids[previous + 1:]

    inputs['input_embeddings'] = images
    inputs['input_embedding_ranges'] = ranges
    inputs['input_ids'] = new_input_ids
def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Encode a preprocessed example.

    return: inputs, tokenizer_kwargs

    Under lmdeploy / vllm the raw `images`, `audios` and `videos` are copied
    into the inputs. `loss_scale` is dropped when there are no labels.
    """
    # Any non-empty media/object field marks the example as multimodal.
    is_multi_modal: bool = any(example.get(key) for key in Template.special_keys)
    inputs, tokenizer_kwargs = self._concat_and_tokenize(
        example.get('query') or '',
        example.get('query_role') or 'user',
        example.get('response'),
        example.get('history') or [],
        example.get('history_roles'),
        example.get('system', None),
        self.truncation_strategy,
        auto_add_bos=self.auto_add_bos,
        example=example,
        is_multi_modal=is_multi_modal)
    if self._is_lmdeploy or self._is_vllm:
        for media_key in ('images', 'audios', 'videos'):
            inputs[media_key] = example.get(media_key)
    if inputs.get('labels') is None:
        inputs.pop('loss_scale', None)
    return inputs, tokenizer_kwargs
def _concat_context_list(
self,
context_list: List[Context],
res_context_list: List[Context], # inplace
loss_scale_list: List[float], # inplace
system: Optional[str] = None,
query: Optional[str] = None,
response: Optional[str] = None,
round0: Optional[int] = None,
compute_loss: bool = True) -> None:
# concat context list and replace placeholder
round1 = None
if round0 is not None:
round1 = str(round0 + 1)
round0 = str(round0)
for context in context_list:
if isinstance(context, str):
if '{{RESPONSE}}' == context:
assert response is not None
if compute_loss:
content_part, weight_part = calculate_loss_scale(query, response, self.use_loss_scale,
self.response_loss_scale_map,
self.query_loss_scale_map)
else:
content_part, weight_part = [response], [0.]
res_context_list.extend(content_part)
loss_scale_list.extend(weight_part)
continue
old_str_list = ['{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}']
new_str_list = [system, query, round0, round1]
for (old_str, new_str) in zip(old_str_list, new_str_list):
if new_str is not None and old_str in context:
assert isinstance(new_str, str), f'new_str: {new_str}'
context = context.replace(old_str, new_str)
if len(context) == 0:
continue
res_context_list.append(context)
loss_scale_list.append(0.)
def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
**kwargs) -> Tuple[List[Context], List[float]]:
is_multi_modal: bool = kwargs.pop('is_multi_modal', False)
if is_multi_modal:
context_list, loss_scale_list = self.split_special_tokens(context_list, loss_scale_list)
context_list, loss_scale_list = self.pre_tokenize(context_list, loss_scale_list, **kwargs)
res: List[Context] = [] # result of context_list
res_loss_scale: List[float] = [] # result of loss_scale_list
temp: List[str] = []
temp_loss_scale = 0.
for i, (context, loss_scale) in enumerate(zip(context_list, loss_scale_list)):
if isinstance(context, str) and (loss_scale == temp_loss_scale):
temp.append(context)
else:
if len(temp) > 0:
res.append(''.join(temp))
res_loss_scale.append(temp_loss_scale)
temp.clear()
if isinstance(context, str): # loss_scale diff
temp.append(context)
else:
res.append(context)
res_loss_scale.append(loss_scale)
temp_loss_scale = loss_scale
if len(temp) > 0:
res.append(''.join(temp))
res_loss_scale.append(temp_loss_scale)
return res, res_loss_scale
@staticmethod
def split_special_tokens(context_list: List[Context],
                         loss_scale_list: List[float]) -> Tuple[List[Context], List[float]]:
    """Split string contexts at `Template.special_tokens`.

    Each string context is cut by `split_str_parts_by`; the non-empty keys and
    contents become separate contexts that inherit the original loss scale.
    Contexts for which `fetch_one` does not yield a string are passed through.

    Returns:
        The split context list and the matching loss-scale list.
    """
    from swift.utils.utils import split_str_parts_by
    from .utils import fetch_one
    contexts_out: List[Context] = []
    scales_out: List[float] = []
    for context, loss_scale in zip(context_list, loss_scale_list):
        if not isinstance(fetch_one(context), str):
            contexts_out.append(context)
            scales_out.append(loss_scale)
            continue
        pieces = [
            piece for part in split_str_parts_by(context, Template.special_tokens)
            for piece in (part['key'], part['content']) if piece
        ]
        contexts_out.extend(pieces)
        scales_out.extend([loss_scale] * len(pieces))
    return contexts_out, scales_out
def _tokenize(self, context, **tokenizer_kwargs):
return self.tokenizer(
context, return_attention_mask=False, add_special_tokens=False, **tokenizer_kwargs)['input_ids']
def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
example: Dict[str, Any]) -> List[Context]:
if media_type == 'image':
if self._is_lmdeploy:
return [[-100]]
else:
return self.image_placeholder
elif media_type == 'video':
return ['<video>']
elif media_type == 'audio':
return ['<audio>']
def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
objects = example.get('objects')
if objects:
object_ = objects[index]
return [object_['caption']]
else:
return ['<ref-object>']
def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
objects = example.get('objects')
if objects:
object_ = objects[index]
if isinstance(object_['bbox'][0], list):
all_objects = ''
for sub_object in object_['bbox']:
all_objects += (f'[({sub_object[0]},{sub_object[1]}),' f'({sub_object[2]},{sub_object[3]})],')
all_objects = all_objects[:-1]
return [all_objects]
else:
return [f'[({object_["bbox"][0]},{object_["bbox"][1]}),({object_["bbox"][2]},{object_["bbox"][3]})]']
else:
return ['<bbox>']
@classmethod
def normalize_bbox(cls, objects, images, to_type: Literal['real', 'norm_1000', 'norm_1']):
if not objects or not images:
return
for object in objects:
bbox = object['bbox']
bbox_type = object['bbox_type']
idx = object['image']
image = images[idx]
if bbox_type == 'real':
if to_type == 'real':
continue
width, height = image.width, image.height
if isinstance(bbox[0], list):
bboxes = []
for _box in bbox:
bboxes.append([
int(coord / dim * 999) if to_type == 'norm_1000' else coord / dim
for coord, dim in zip(_box, [width, height, width, height])
])
object['bbox'] = bboxes
else:
object['bbox'] = [
int(coord / dim * 999) if to_type == 'norm_1000' else coord / dim
for coord, dim in zip(bbox, [width, height, width, height])
]
object['bbox_type'] = to_type
elif bbox_type == 'norm_1000':
if to_type == 'norm_1000':
continue
if to_type == 'norm_1':
object['bbox'] = [coord / 999. for coord in bbox]
elif to_type == 'real':
width, height = image.width, image.height
object['bbox'] = [
int(coord / 999. * dim) for coord, dim in zip(bbox, [width, height, width, height])
]
object['bbox_type'] = to_type
elif bbox_type == 'norm_1':
if to_type == 'norm_1':
continue
if to_type == 'norm_1000':
object['bbox'] = [int(coord * 999) for coord in bbox]
elif to_type == 'real':
width, height = image.width, image.height
object['bbox'] = [int(coord * dim) for coord, dim in zip(bbox, [width, height, width, height])]
object['bbox_type'] = to_type
def pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float],
**kwargs) -> Tuple[List[Context], List[float]]:
# replace tag/object/box
example = kwargs.get('example') # get x_index
res: List[Context] = [] # result of context_list
res_loss_scale: List[float] = [] # result of loss_scale_list
for k in ['image', 'video', 'audio']:
example[f'{k}_index'] = 0
for context, loss_scale in zip(context_list, loss_scale_list):
for k in ['image', 'video', 'audio']:
if context == f'<{k}>':
c_list = self.replace_tag(k, example[f'{k}_index'], example)
example[f'{k}_index'] += 1
loss_scale = 0.
break
else:
if context == '<ref-object>':
c_list = self.replace_object(example.get('object_index', 0), example)
example['object_index'] = example.get('object_index', 0) + 1
elif context == '<bbox>':
c_list = self.replace_box(example.get('box_index', 0), example)
example['box_index'] = example.get('box_index', 0) + 1
else:
c_list = [context]
res += c_list
res_loss_scale += [loss_scale] * len(c_list)
return res, res_loss_scale
def _encode_context_list(
self,
context_list: List[Context],
loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]:
"""return: input_ids, labels, tokenizer_kwargs"""
input_ids: List[int] = []
labels: List[int] = []
loss_scale: List[float] = []
tokenizer_kwargs = {}
if loss_scale_list is None:
loss_scale_list = [0.] * len(context_list)
for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
if isinstance(context, str):
# tokenizer_kwargs is the returned tokenizer_kwargs,
# while curr_tokenizer_kwargs is the tokenizer_kwargs for the current context.
curr_tokenizer_kwargs = self._get_tokenizer_kwargs(context)
self._concat_tokenizer_kwargs(tokenizer_kwargs, curr_tokenizer_kwargs)
token_list = self._tokenize(context, **curr_tokenizer_kwargs)
else:
token_list = context
input_ids += token_list
if loss_scale_list[i] > 0.0:
labels += token_list
else:
labels += [-100] * len(token_list)
loss_scale.extend([loss_weight] * len(token_list))
return input_ids, labels, loss_scale, tokenizer_kwargs
@staticmethod
def use_dynamic_eos(labels: List[int], suffix_tokens_id: List[int]) -> None:
suffix_len = len(suffix_tokens_id)
start = 0
for i in range(1, len(labels)):
if labels[i - 1] >= 0 and labels[i] == -100:
start = i
if start > 0 and labels[i - 1] == -100 and labels[i] >= 0:
# [0, 1, 2, -100(start), -100, 3(i), 4]
length = i - start
if length >= suffix_len:
labels[start:start + suffix_len] = suffix_tokens_id
def _concat_and_tokenize(self,
                         query: str,
                         query_role: str,
                         response: Optional[str],
                         history: History,
                         history_roles: History,
                         system: Optional[str],
                         truncation_strategy: str,
                         auto_add_bos: bool = False,
                         **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Assemble the context list for a whole conversation and tokenize it.

    Args:
        query(`str`): Query of the current (last) round.
        query_role(`str`): Role of the current query; `'tool'` selects `self.tool_prompt`
            instead of `self.prompt` for that round.
        response(`str`, optional): Response of the current round. If none, no
            `{{RESPONSE}}`/suffix is appended for the last round and the returned
            `labels` is None.
        history(`History`): Previous `[query, response]` rounds. Not modified.
        history_roles(`History`): Role pair for each round of `history`. Not modified.
        system(`str`, optional): System text. If none, prompt contexts containing
            `{{SYSTEM}}` are dropped and `self.prefix` is used as the prefix.
        truncation_strategy(`str`): With `'delete'`, a row longer than `self.max_length`
            is discarded; otherwise the last `self.max_length` tokens are kept.
        auto_add_bos(`bool`): Prepend the tokenizer's bos token id when the tokenizer
            itself emits it while encoding the empty string.
        **kwargs: Forwarded to `self._simplify_context_list`.

    Returns:
        inputs, tokenizer_kwargs. Both are empty dicts when the row was deleted.
    """
    # The current round is appended to both lists below; copy them so the caller's
    # lists do not grow on every call.
    history = history.copy()
    history_roles = history_roles.copy()
    res_context_list: List[Context] = []
    loss_scale_list: List[float] = []
    if auto_add_bos:
        bos_token_id = self.tokenizer.bos_token_id
        if isinstance(bos_token_id, int) and bos_token_id in self.tokenizer.encode(''):
            res_context_list.append([bos_token_id])
            loss_scale_list.append(0.)
    prompt = self.prompt.copy()
    if system is None:
        prompt = [context for context in prompt if '{{SYSTEM}}' not in context]
    # `self.system_prefix` is only used when a system text is given and the prompt
    # itself has no `{{SYSTEM}}` slot.
    if system is None or any('{{SYSTEM}}' in context for context in prompt):
        prefix = self.prefix
    else:
        prefix = self.system_prefix
    self._concat_context_list(prefix, res_context_list, loss_scale_list, system=system)
    history.append([query, response])
    history_roles.append([query_role, 'assistant'])
    for i, ((q, r), (qr, _)) in enumerate(zip(history, history_roles)):
        context_list = self.tool_prompt.copy() if qr == 'tool' else prompt.copy()
        extra_context_list = []
        is_suffix = False
        if i < len(history) - 1:
            # not the last round
            context_list = [context for context in context_list if '{{SYSTEM}}' not in context]
            context_list.append('{{RESPONSE}}')
            if history[i + 1][0]:
                extra_context_list = self.chat_sep
        elif r is not None:
            # last response
            context_list.append('{{RESPONSE}}')
            extra_context_list = self.suffix
            is_suffix = True
        if q or r:
            self._concat_context_list(
                context_list,
                res_context_list,
                loss_scale_list,
                query=q,
                response=r,
                system=system,
                round0=i,
                compute_loss=self.compute_per_round_loss or is_suffix)
            res_context_list += extra_context_list
            # Only the final suffix contributes to the loss; separators do not.
            loss_scale_list += ([1.] if is_suffix else [0.]) * len(extra_context_list)
    inputs = {}
    if self.output_prompt_answer:
        # 'answer' is encoded first and 'prompt' last, so the `tokenizer_kwargs`
        # left over after the loop are the prompt's.
        answer_len = len(extra_context_list) + bool(response is not None)
        total_len = len(res_context_list)
        part_loss_scale: Dict[str, List[float]] = {}
        for key, _slice in zip(['answer', 'prompt'],
                               [slice(total_len - answer_len, total_len),
                                slice(0, total_len - answer_len)]):
            _res_context_list, _loss_scale_list = self._simplify_context_list(res_context_list[_slice],
                                                                              loss_scale_list[_slice], **kwargs)
            input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
                _res_context_list, _loss_scale_list)
            inputs[f'{key}_input_ids'], inputs[f'{key}_labels'] = input_ids, labels
            part_loss_scale[key] = loss_scale
            if self.use_loss_scale:
                inputs[f'{key}_loss_scale'] = loss_scale
        input_ids = inputs['prompt_input_ids'] + inputs['answer_input_ids']
        labels = inputs['prompt_labels'] + inputs['answer_labels']
        # Keep `loss_scale` aligned token-for-token with `input_ids`/`labels`
        # (prompt followed by answer), not just the part encoded last.
        loss_scale = part_loss_scale['prompt'] + part_loss_scale['answer']
        if response is None:
            assert len(inputs['answer_labels']) == 0
            inputs['answer_labels'] = None
    else:
        res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, **kwargs)
        input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
            res_context_list, loss_scale_list)
        if labels is not None:
            self.use_dynamic_eos(labels, self._encode_context_list(self.suffix)[0])
    if response is None:
        labels = None
    if self.max_length is not None:
        if truncation_strategy == 'delete' and len(input_ids) > self.max_length:
            logger.warning(f'Current length of row({len(input_ids)}) is larger'
                           f' than the max_length({self.max_length}), deleted.')
            return {}, {}
        # Any other strategy keeps the tail of the sequence.
        input_ids = input_ids[-self.max_length:]
        if labels is not None:
            labels = labels[-self.max_length:]
        if loss_scale is not None:
            loss_scale = loss_scale[-self.max_length:]
    inputs['input_ids'] = input_ids
    inputs['labels'] = labels
    if self.use_loss_scale:
        inputs['loss_scale'] = loss_scale
    return inputs, tokenizer_kwargs
def _get_tokenizer_kwargs(self, context: str) -> Dict[str, Any]:
"""return: curr_tokenizer_kwargs"""
return {}
def _concat_tokenizer_kwargs(self, tokenizer_kwargs: Dict[str, Any], curr_tokenizer_kwargs: Dict[str, Any]) -> None:
assert len(tokenizer_kwargs) == 0
@staticmethod
def pad_sequence(sequences: List[torch.Tensor],
padding_value: float = 0.,
padding_side: Literal['right', 'left'] = 'right'):
padding_right = padding_side == 'right'
if padding_right:
return pad_sequence(sequences, batch_first=True, padding_value=padding_value)
max_len = max([s.size(0) for s in sequences])
padded_sequences = []
for seq in sequences:
pad_length = max_len - seq.size(0)
pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0]
padded_seq = F.pad(seq, tuple(pad_tuple), 'constant', padding_value)
padded_sequences.append(padded_seq)
return torch.stack(padded_sequences)
def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
"""
Args:
batch(`List[Dict[str, Any]]`): The input data in batch
padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
will be padded to the `longest`
"""