Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove session_len and deprecated short names of the chat templates #2105

Merged
Merged 3 commits on Jul 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 15 additions & 108 deletions lmdeploy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,7 @@ def from_json(cls, file_or_string):
class BaseModel:
"""Base model."""

def __init__(self,
session_len=2048,
capability='chat',
stop_words=None,
**kwargs):
self.session_len = session_len
def __init__(self, capability='chat', stop_words=None, **kwargs):
self.stop_words = stop_words
self.capability = capability

Expand Down Expand Up @@ -361,8 +356,8 @@ def match(cls, model_path: str) -> Optional[str]:
class MiniGemini(Vicuna):
"""Chat template of vicuna model."""

def __init__(self, session_len=4096, **kwargs):
super().__init__(session_len=session_len, **kwargs)
def __init__(self, **kwargs):
super().__init__(**kwargs)

def get_prompt(self, prompt, sequence_start=True):
return super().get_prompt(prompt, sequence_start)[:-1]
Expand All @@ -384,8 +379,6 @@ def match(cls, model_path: str) -> Optional[str]:
return 'mini-gemini-vicuna'


@MODELS.register_module(name='internlm-chat')
@MODELS.register_module(name='internlm-chat-7b')
@MODELS.register_module(name='internlm')
class InternLMChat7B(BaseChatTemplate):
"""Chat template of InternLM model."""
Expand Down Expand Up @@ -429,48 +422,11 @@ def match(cls, model_path: str) -> Optional[str]:
return 'internlm'


@MODELS.register_module(name='internlm-chat-20b')
@MODELS.register_module(name='internlm-chat-7b-8k')
class InternLMChat7B8K(InternLMChat7B):
"""Chat template and generation parameters of InternLM-Chat-7B-8K and
InternLM-Chat-20B models."""

def __init__(self, session_len=8192, **kwargs):
super(InternLMChat7B8K, self).__init__(**kwargs)
self.session_len = session_len


@MODELS.register_module(name='internlm-20b')
class InternLMBaseModel20B(BaseChatTemplate):
"""Generation parameters of InternLM-20B-Base model."""

def __init__(self, session_len=4096, capability='completion', **kwargs):
super().__init__(session_len=session_len,
capability=capability,
**kwargs)


@MODELS.register_module(
name=['internlm2-1_8b', 'internlm2-7b', 'internlm2-20b'])
class InternLM2BaseModel7B(BaseChatTemplate):
"""Generation parameters of InternLM2-7B-Base model."""

def __init__(self, session_len=32768, capability='completion', **kwargs):
super().__init__(session_len=session_len,
capability=capability,
**kwargs)


@MODELS.register_module(name=[
'internlm2-chat', 'internlm2-chat-1_8b', 'internlm2-chat-7b',
'internlm2-chat-20b'
])
@MODELS.register_module(name='internlm2')
class InternLM2Chat7B(InternLMChat7B):
"""Chat template and generation parameters of InternLM2-Chat-7B."""

def __init__(self,
session_len=32768,
system='<|im_start|>system\n',
user='<|im_start|>user\n',
assistant='<|im_start|>assistant\n',
Expand All @@ -488,8 +444,7 @@ def __init__(self,
self.interpreter = interpreter
self.environment = environment
self.eoenv = eoenv
super(InternLM2Chat7B, self).__init__(session_len=session_len,
system=system,
super(InternLM2Chat7B, self).__init__(system=system,
user=user,
assistant=assistant,
eosys=eosys,
Expand Down Expand Up @@ -607,14 +562,12 @@ def match(cls, model_path: str) -> Optional[str]:
return 'internvl2-internlm2'


@MODELS.register_module(name='internlm-xcomposer2d5')
@MODELS.register_module(name='internlm-xcomposer2')
@MODELS.register_module(name=['internlm-xcomposer2', 'internlm-xcomposer2d5'])
class InternLMXComposer2Chat7B(InternLMChat7B):
"""Chat template and generation parameters of InternLM-XComposer2-7b."""

def __init__(
self,
session_len=4096,
system='[UNUSED_TOKEN_146]system\n',
meta_instruction="""You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).
- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.
Expand All @@ -628,8 +581,7 @@ def __init__(
separator='\n',
stop_words=['[UNUSED_TOKEN_145]'],
**kwargs):
super().__init__(session_len=session_len,
system=system,
super().__init__(system=system,
meta_instruction=meta_instruction,
user=user,
assistant=assistant,
Expand All @@ -654,18 +606,8 @@ def match(cls, model_path: str) -> Optional[str]:
return 'internlm-xcomposer2'


@MODELS.register_module(name='baichuan-7b')
@MODELS.register_module(name='baichuan-base')
class Baichuan7B(BaseChatTemplate):
"""Generation parameters of Baichuan-7B base model."""

def __init__(self, **kwargs):
super().__init__(**kwargs)


@MODELS.register_module(name='baichuan2-7b')
@MODELS.register_module(name='baichuan2')
class Baichuan2_7B(BaseChatTemplate):
class Baichuan2(BaseChatTemplate):
"""Chat template and generation parameters of Baichuan2-7B-Base and
Baichuan2-7B-Chat models."""

Expand Down Expand Up @@ -723,7 +665,7 @@ def match(cls, model_path: str) -> Optional[str]:
return 'puyu'


@MODELS.register_module(name=['llama2', 'llama-2', 'llama-2-chat'])
@MODELS.register_module(name='llama2')
class Llama2(BaseChatTemplate):
"""Chat template of LLaMA2 model."""

Expand Down Expand Up @@ -773,7 +715,6 @@ def __init__(self,
user='<|start_header_id|>user<|end_header_id|>\n\n',
eoh='<|eot_id|>',
stop_words=['<|eot_id|>', '<|end_of_text|>'],
session_len=8192,
**kwargs):
super().__init__(system=system,
meta_instruction=meta_instruction,
Expand All @@ -783,7 +724,6 @@ def __init__(self,
user=user,
eoh=eoh,
stop_words=stop_words,
session_len=session_len,
**kwargs)

def get_prompt(self, prompt, sequence_start=True):
Expand All @@ -809,14 +749,11 @@ def match(cls, model_path: str) -> Optional[str]:
return 'llama3'


@MODELS.register_module(name='qwen-14b')
@MODELS.register_module(name='qwen-7b')
@MODELS.register_module(name='qwen')
class Qwen7BChat(BaseChatTemplate):
"""Chat template for Qwen-7B-Chat."""

def __init__(self,
session_len=8192,
system='<|im_start|>system\n',
meta_instruction='You are a helpful assistant.',
eosys='<|im_end|>\n',
Expand All @@ -836,7 +773,6 @@ def __init__(self,
eoa=eoa,
separator=separator,
stop_words=stop_words,
session_len=session_len,
**kwargs)

@classmethod
Expand All @@ -855,20 +791,17 @@ class CodeLlama(Llama2):

def __init__(self,
meta_instruction='',
session_len=4096,
suffix_first=False,
stop_words=None,
**kwargs):
super().__init__(meta_instruction=meta_instruction,
session_len=session_len,
stop_words=stop_words,
**kwargs)
caps = ['completion', 'infilling', 'chat', 'python']
assert self.capability in caps, \
f'{self.capability} is not supported. ' \
f'The supported capabilities are: {caps}'
self.meta_instruction = meta_instruction
self.session_len = session_len
self.suffix_first = suffix_first
self.stop_words = stop_words
if self.capability == 'infilling':
Expand Down Expand Up @@ -921,7 +854,6 @@ def match(cls, model_path: str) -> Optional[str]:
return 'falcon'


@MODELS.register_module(name='chatglm2-6b')
@MODELS.register_module(name='chatglm')
class ChatGLM2(BaseModel):

Expand Down Expand Up @@ -979,7 +911,7 @@ def match(cls, model_path: str) -> Optional[str]:
return 'chatglm'


@MODELS.register_module(name=['solar', 'solar-70b'])
@MODELS.register_module(name='solar')
class SOLAR(BaseChatTemplate):
"""Chat template of SOLAR model.

Expand All @@ -993,7 +925,6 @@ def __init__(self,
eoh='\n\n',
assistant='### Assistant:\n',
meta_instruction='',
session_len=2048,
**kwargs):
super().__init__(**kwargs)
self.system = system
Expand All @@ -1002,7 +933,6 @@ def __init__(self,
self.eoh = eoh
self.assistant = assistant
self.meta_instruction = meta_instruction
self.session_len = session_len

@classmethod
def match(cls, model_path: str) -> Optional[str]:
Expand All @@ -1015,8 +945,7 @@ def match(cls, model_path: str) -> Optional[str]:
return 'solar'


@MODELS.register_module(name='ultracm')
@MODELS.register_module(name='ultralm')
@MODELS.register_module(name=['ultracm', 'ultralm'])
class UltraChat(BaseChatTemplate):
"""Template of UltraCM and UltraLM models.

Expand All @@ -1035,7 +964,6 @@ def __init__(
eoa='</s>',
separator='\n',
stop_words=['</s>'],
session_len=2048,
**kwargs):
super().__init__(system=system,
meta_instruction=meta_instruction,
Expand All @@ -1046,7 +974,6 @@ def __init__(
eoa=eoa,
separator=separator,
stop_words=stop_words,
session_len=session_len,
**kwargs)

@classmethod
Expand All @@ -1062,7 +989,7 @@ def match(cls, model_path: str) -> Optional[str]:
return 'ultralm'


@MODELS.register_module(name=['yi', 'yi-chat', 'yi-200k', 'yi-34b'])
@MODELS.register_module(name=['yi'])
class Yi(BaseChatTemplate):
"""Chat template of Yi model."""

Expand Down Expand Up @@ -1101,25 +1028,15 @@ def match(cls, model_path: str) -> Optional[str]:


@MODELS.register_module(name=['mistral', 'mixtral'])
@MODELS.register_module(name=['Mistral-7B-Instruct', 'Mixtral-8x7B-Instruct'])
class MistralChat(BaseChatTemplate):
"""Template of Mistral and Mixtral Instruct models.

`https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1`
`https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1`
"""

def __init__(self,
user='[INST] ',
eoh=' [/INST]',
eoa='</s>',
session_len=2048,
**kwargs):
super().__init__(user=user,
eoh=eoh,
eoa=eoa,
session_len=session_len,
**kwargs)
def __init__(self, user='[INST] ', eoh=' [/INST]', eoa='</s>', **kwargs):
super().__init__(user=user, eoh=eoh, eoa=eoa, **kwargs)

@classmethod
def match(cls, model_path: str) -> Optional[str]:
Expand Down Expand Up @@ -1168,7 +1085,6 @@ def match(cls, model_path: str) -> Optional[str]:
return 'gemma'


@MODELS.register_module(name=['deepseek-chat'])
@MODELS.register_module(name=['deepseek'])
class Deepseek(BaseChatTemplate):

Expand Down Expand Up @@ -1210,13 +1126,11 @@ def __init__(self,
eoh=' ',
assistant='<bot>: ',
eoa='</s>',
session_len=4096,
**kwargs):
super().__init__(user=user,
eoh=eoh,
assistant=assistant,
eoa=eoa,
session_len=session_len,
**kwargs)

def get_prompt(self, prompt, sequence_start=True):
Expand Down Expand Up @@ -1248,15 +1162,13 @@ def __init__(
eoh='\n\n',
assistant='Assistant: ',
eoa='<|end▁of▁sentence|>',
session_len=16384,
**kwargs):
super().__init__(meta_instruction=meta_instruction,
eosys=eosys,
user=user,
eoh=eoh,
assistant=assistant,
eoa=eoa,
session_len=session_len,
**kwargs)

def get_prompt(self, prompt, sequence_start=True):
Expand All @@ -1283,7 +1195,6 @@ class DeepSeek(BaseChatTemplate):

def __init__(
self,
session_len=4096,
system='',
meta_instruction="""You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\n""", # noqa: E501
eosys='',
Expand All @@ -1294,8 +1205,7 @@ def __init__(
separator='\n',
stop_words=['<|EOT|>'],
**kwargs):
super().__init__(session_len=session_len,
system=system,
super().__init__(system=system,
meta_instruction=meta_instruction,
eosys=eosys,
user=user,
Expand Down Expand Up @@ -1407,8 +1317,7 @@ def match(cls, model_path: str) -> Optional[str]:
return 'dbrx'


@MODELS.register_module(name=['internvl-zh-hermes2'])
@MODELS.register_module(name=['llava-chatml'])
@MODELS.register_module(name=['llava-chatml', 'internvl-zh-hermes2'])
class ChatmlDirect(BaseChatTemplate):

def __init__(self,
Expand All @@ -1420,7 +1329,6 @@ def __init__(self,
assistant='<|im_start|>assistant\n',
eoa='<|im_end|>',
separator='',
session_len=4096,
**kwargs):
super().__init__(system,
meta_instruction=meta_instruction,
Expand All @@ -1430,7 +1338,6 @@ def __init__(self,
assistant=assistant,
eoa=eoa,
separator=separator,
session_len=session_len,
**kwargs)

@classmethod
Expand Down
Loading
Loading