Support glm 4v (#1947)
* support glm-4-9b-chat

* support glm-4v

* update docs

* fix

* resolve comment
RunningLeon authored Jul 12, 2024
1 parent d8bd412 commit 49208aa
Showing 17 changed files with 308 additions and 60 deletions.
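A minimal usage sketch of what this commit enables, not part of the diff: querying GLM-4V through the LMDeploy pipeline with the PyTorch backend. The `THUDM/glm-4v-9b` checkpoint name and the sample image URL are assumptions taken from typical LMDeploy VLM examples.

```python
# Illustrative sketch, assuming the PyTorch engine backend and the
# THUDM/glm-4v-9b checkpoint; adjust the model path and image as needed.
from lmdeploy import pipeline, PytorchEngineConfig
from lmdeploy.vl import load_image

pipe = pipeline('THUDM/glm-4v-9b', backend_config=PytorchEngineConfig(tp=1))
image = load_image(
    'https://raw.githubusercontent.com/open-mmlab/mmdeploy/main/tests/data/tiger.jpeg')
# A VLM pipeline accepts a (text, image) tuple as a single prompt.
response = pipe(('describe this image', image))
print(response.text)
```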
1 change: 1 addition & 0 deletions README.md
@@ -145,6 +145,7 @@ For detailed inference benchmarks in more devices and more settings, please refe
<li>CogVLM2-Chat (19B)</li>
<li>MiniCPM-Llama3-V-2_5</li>
<li>Phi-3-vision (4.2B)</li>
+<li>GLM-4V (9B)</li>
</ul>
</td>
</tr>
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -146,6 +146,7 @@ LMDeploy TurboMind 引擎拥有卓越的推理能力,在各种规模的模型
<li>CogVLM2-Chat (19B)</li>
<li>MiniCPM-Llama3-V-2_5</li>
<li>Phi-3-vision (4.2B)</li>
+<li>GLM-4V (9B)</li>
</ul>
</td>
</tr>
1 change: 1 addition & 0 deletions docs/en/supported_models/supported_models.md
@@ -66,3 +66,4 @@ The TurboMind engine doesn't support window attention. Therefore, for models tha
| LLaVA(1.5,1.6) | 7B-34B | Yes | No | No |
| InternVL-Chat(v1.5) | 2B-26B | Yes | No | No |
| Gemma2 | 9B-27B | Yes | No | No |
+| GLM4 | 9B | Yes | No | No |
1 change: 1 addition & 0 deletions docs/zh_cn/supported_models/supported_models.md
@@ -66,3 +66,4 @@ turbomind 引擎不支持 window attention。所以,对于应用了 window att
| LLaVA(1.5,1.6) | 7B-34B | Yes | No | No |
| InternVL-Chat(v1.5) | 2B-26B | Yes | No | No |
| Gemma2 | 9B-27B | Yes | No | No |
+| GLM4 | 9B | Yes | No | No |
2 changes: 2 additions & 0 deletions lmdeploy/archs.py
@@ -128,6 +128,8 @@ def check_vl_llm(config: dict) -> bool:
        return True
    elif arch == 'MultiModalityCausalLM' and 'language_config' in config:
        return True
+    elif arch == 'ChatGLMModel' and 'vision_config' in config:
+        return True
    elif arch in supported_archs:
        return True
    return False
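To make the new branch concrete, here is a hedged sketch of how `check_vl_llm` now tells GLM-4V apart from the text-only GLM-4: both report the `ChatGLMModel` architecture, but only glm-4v carries a `vision_config` section in its HF config. The config dicts are trimmed to the relevant keys and assume the function reads the architecture from `config['architectures']`.

```python
# Hedged sketch: minimal config dicts; real config.json files carry many more keys.
from lmdeploy.archs import check_vl_llm

glm4_cfg = {'architectures': ['ChatGLMModel']}                        # text-only glm-4
glm4v_cfg = {'architectures': ['ChatGLMModel'], 'vision_config': {}}  # glm-4v

print(check_vl_llm(glm4_cfg))   # expected: False, handled as a plain LLM
print(check_vl_llm(glm4v_cfg))  # expected: True, routed to the VLM path
```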
52 changes: 39 additions & 13 deletions lmdeploy/model.py
@@ -992,7 +992,7 @@ def match(cls, model_path: str) -> Optional[str]:
            model_path (str): the model path used for matching.
        """
        path = model_path.lower()
-        if 'chatglm' in path and 'chatglm3' not in path:
+        if 'chatglm2' in path:
            return 'chatglm'


@@ -1523,21 +1523,20 @@ def match(cls, model_path: str) -> Optional[str]:
        return 'internvl2-phi3'


-@MODELS.register_module(name='glm4')
@MODELS.register_module(name='chatglm3')
-class Glm4Chat(BaseChatTemplate):
-    """Chat template of InternLM model."""
+class ChatGLM3(BaseChatTemplate):
+    """Chat template of chatglm3 model."""

    def __init__(self,
-                 system='<|system|>\n',
+                 system='<|system|>\n ',
                 meta_instruction=None,
                 eosys='',
-                 user='<|user|>\n',
+                 user='<|user|>\n ',
                 eoh='',
-                 assistant='<|assistant|>\n',
+                 assistant='<|assistant|>\n ',
                 eoa='',
                 separator='',
-                 stop_words=['<|user|>', '<|endoftext|>', '<|observation|>'],
+                 stop_words=['<eos>'],
                 **kwargs):
        super().__init__(system=system,
                         meta_instruction=meta_instruction,
@@ -1549,7 +1548,7 @@ def __init__(self,
                         separator=separator,
                         stop_words=stop_words,
                         **kwargs)
-        self.start = '[gMASK]<sop>'
+        self.start = '[gMASK]sop'

    def get_prompt(self, prompt, sequence_start=True):
        """Return the prompt that is concatenated with other elements in the
@@ -1562,7 +1561,7 @@ def get_prompt(self, prompt, sequence_start=True):
        Returns:
            str: the concatenated prompt
        """
-        prompt = super(Glm4Chat, self).get_prompt(prompt, sequence_start)
+        prompt = super().get_prompt(prompt, sequence_start)
        if sequence_start:
            prompt = self.start + prompt
        return prompt
@@ -1578,8 +1577,35 @@ def messages2prompt(self, messages, sequence_start=True, **kwargs):
        """
        if isinstance(messages, str):
            return self.get_prompt(messages, sequence_start)
-        return self.start + super(Glm4Chat, self).messages2prompt(
-            messages, sequence_start, **kwargs)
+        return self.start + super().messages2prompt(messages, sequence_start,
+                                                    **kwargs)
+
+    @classmethod
+    def match(cls, model_path: str) -> Optional[str]:
+        """Return the model_name that was registered to MODELS.
+
+        Args:
+            model_path (str): the model path used for matching.
+        """
+        path = model_path.lower()
+        if 'chatglm3' in path:
+            return 'chatglm3'
+
+
+@MODELS.register_module(name='glm4')
+class Glm4Chat(ChatGLM3):
+    """Chat template of glm-4 model."""
+
+    def __init__(self,
+                 system='<|system|>\n',
+                 user='<|user|>\n',
+                 assistant='<|assistant|>\n',
+                 **kwargs):
+        super().__init__(system=system,
+                         user=user,
+                         assistant=assistant,
+                         **kwargs)
+        self.start = '[gMASK]<sop>'

    @classmethod
    def match(cls, model_path: str) -> Optional[str]:
@@ -1589,7 +1615,7 @@ def match(cls, model_path: str) -> Optional[str]:
            model_path (str): the model path used for matching.
        """
        path = model_path.lower()
-        if 'glm-4' in path or 'chatglm3' in path:
+        if 'glm-4' in path:
            return 'glm4'


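To see what the template split changes in practice, here is a hedged sketch of rendering a single-turn prompt with the two registered templates. It assumes the `MODELS` registry returns the template class by name; the commented outputs are indicative of shape only.

```python
# Illustrative only: exact whitespace follows the constructor defaults above.
from lmdeploy.model import MODELS

glm4 = MODELS.get('glm4')()          # Glm4Chat
chatglm3 = MODELS.get('chatglm3')()  # ChatGLM3

print(repr(glm4.get_prompt('Hi there')))
# roughly: '[gMASK]<sop><|user|>\nHi there<|assistant|>\n'
print(repr(chatglm3.get_prompt('Hi there')))
# roughly: '[gMASK]sop<|user|>\n Hi there<|assistant|>\n '
```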
1 change: 1 addition & 0 deletions lmdeploy/pytorch/config.py
@@ -84,6 +84,7 @@ class ModelConfig:
    model_arch: str = None
    unused_modules: List[str] = None
    auto_model_cls: Any = AutoModelForCausalLM
+    cogvlm_style: bool = False

    def get_head_size(self):
        """get head size."""
23 changes: 14 additions & 9 deletions lmdeploy/pytorch/configurations/chatglm.py
@@ -19,12 +19,17 @@ def build(cls, hf_config, model_path: str = None):
        if bos_token_id is None:
            bos_token_id = hf_config.pad_token_id
        init_kwargs = dict(empty_init=False)
-        return ModelConfig(hidden_size=hf_config.hidden_size,
-                           num_layers=hf_config.num_layers,
-                           num_attention_heads=hf_config.num_attention_heads,
-                           num_key_value_heads=hf_config.multi_query_group_num,
-                           bos_token_id=bos_token_id,
-                           eos_token_id=hf_config.eos_token_id,
-                           head_dim=head_dim,
-                           vocab_size=hf_config.padded_vocab_size,
-                           init_kwargs=init_kwargs)
+        cfg = ModelConfig(hidden_size=hf_config.hidden_size,
+                          num_layers=hf_config.num_layers,
+                          num_attention_heads=hf_config.num_attention_heads,
+                          num_key_value_heads=hf_config.multi_query_group_num,
+                          bos_token_id=bos_token_id,
+                          eos_token_id=hf_config.eos_token_id,
+                          head_dim=head_dim,
+                          vocab_size=hf_config.padded_vocab_size,
+                          init_kwargs=init_kwargs)
+        # glm-4v
+        if hasattr(hf_config, 'vision_config'):
+            cfg.unused_modules = ['transformer.vision']
+            cfg.cogvlm_style = True
+        return cfg
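For context on the new branch: `unused_modules = ['transformer.vision']` tells the PyTorch engine to skip the glm-4v vision tower when loading weights, since the image encoder runs in the separate vision pipeline (as with CogVLM), and `cogvlm_style = True` opts the model into the CogVLM-style image-token handling. A hedged sketch of the prefix-filter idea follows; the tensor names are representative, not taken from the diff.

```python
# Illustrative only: which checkpoint tensors a prefix filter like
# unused_modules would skip for glm-4v.
unused_modules = ['transformer.vision']

def is_skipped(weight_name: str) -> bool:
    """True if the tensor belongs to a module the language engine ignores."""
    return any(weight_name.startswith(prefix) for prefix in unused_modules)

print(is_skipped('transformer.vision.patch_embedding.proj.weight'))  # True
print(is_skipped('transformer.encoder.layers.0.self_attention.query_key_value.weight'))  # False
```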
1 change: 1 addition & 0 deletions lmdeploy/pytorch/configurations/cogvlm.py
@@ -19,6 +19,7 @@ def build(cls, hf_config, model_path: str = None):
        if getattr(hf_config, 'num_multi_query_heads', None):
            cfg.num_key_value_heads = hf_config.num_multi_query_heads
        cfg.unused_modules = ['model.vision']
+        cfg.cogvlm_style = True
        torch_dtype = 'bfloat16' if torch.cuda.is_bf16_supported(
        ) else 'float16'
        hf_config.torch_dtype = torch_dtype
2 changes: 1 addition & 1 deletion lmdeploy/pytorch/engine/engine.py
@@ -429,7 +429,7 @@ def __get_vlm_embeddings():
            history_image_nums = None
            history_image_token_lengths = None
            # only for cogvlm
-            if self.model_config.model_arch == 'CogVLMForCausalLM':
+            if self.model_config.cogvlm_style:
                (history_image_nums,
                 history_image_token_lengths) = __get_cogvlm_image_info()
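The engine change swaps a hard-coded architecture check for the new `cogvlm_style` flag, so any model whose config builder sets the flag (CogVLM, and now glm-4v) reuses the same image-token bookkeeping without further engine edits. A small illustrative sketch with a stand-in config class, not the real `ModelConfig`:

```python
# Illustrative only: DemoConfig stands in for ModelConfig to show how the
# flag-based check generalizes across architectures.
from dataclasses import dataclass

@dataclass
class DemoConfig:
    model_arch: str
    cogvlm_style: bool = False

for cfg in (DemoConfig('CogVLMForCausalLM', cogvlm_style=True),  # CogVLM
            DemoConfig('ChatGLMModel', cogvlm_style=True),       # glm-4v
            DemoConfig('LlamaForCausalLM')):                      # plain LLM
    # mirrors the engine branch: if self.model_config.cogvlm_style: ...
    print(cfg.model_arch, 'uses image bookkeeping' if cfg.cogvlm_style else 'skips it')
```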