"""
Conversation prompt templates.

We kindly request that you import fastchat instead of copying this file if you wish to use it.
If you have changes in mind, please contribute back so the community can benefit collectively
and continue to maintain these valuable templates.
"""

import dataclasses
from enum import IntEnum, auto
from typing import Any, Dict, List, Optional, Tuple, Union


class SeparatorStyle(IntEnum):
    """Separator styles.

    Each member selects one branch of ``Conversation.get_prompt``.
    """

    ADD_COLON_SINGLE = auto()
    ADD_COLON_TWO = auto()
    ADD_COLON_SPACE_SINGLE = auto()
    NO_COLON_SINGLE = auto()
    NO_COLON_TWO = auto()
    ADD_NEW_LINE_SINGLE = auto()
    LLAMA2 = auto()
    CHATGLM = auto()
    CHATML = auto()
    CHATINTERN = auto()
    DOLLY = auto()
    RWKV = auto()
    PHOENIX = auto()
    ROBIN = auto()
    FALCON_CHAT = auto()
    CHATGLM3 = auto()
    INTERNVL_ZH = auto()


@dataclasses.dataclass
class Conversation:
    """A class that manages prompt templates and keeps all conversation history."""

    # The name of this template
    name: str
    # The template of the system prompt
    system_template: str = '{system_message}'
    # The system message
    system_message: str = ''
    # The names of two roles
    roles: Tuple[str, str] = ('USER', 'ASSISTANT')
    # All messages. Each item is [role, message]. A fresh list is created per
    # instance so that append_message() works on a directly constructed
    # Conversation and instances never share history.
    messages: List[List[Optional[str]]] = dataclasses.field(default_factory=list)
    # The number of few shot examples
    offset: int = 0
    # The separator style and configurations
    sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE
    sep: str = '\n'
    sep2: Optional[str] = None
    # Stop criteria (the default one is EOS token)
    stop_str: Optional[Union[str, List[str]]] = None
    # Stops generation if meeting any token in this list
    stop_token_ids: Optional[List[int]] = None

    def get_prompt(self) -> str:
        """Get the prompt for generation.

        The system prompt is produced by formatting ``system_template`` with
        ``system_message``; the messages are then rendered according to
        ``sep_style``. A falsy message (``None`` or ``''``) is rendered as the
        bare role prefix with no trailing separator.

        Returns:
            The fully rendered prompt string.

        Raises:
            ValueError: If ``sep_style`` is not a handled style.
        """
        system_prompt = self.system_template.format(system_message=self.system_message)
        style = self.sep_style

        if style == SeparatorStyle.ADD_COLON_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'
            return ret

        if style == SeparatorStyle.ADD_COLON_TWO:
            # Separators alternate by message index: sep, sep2, sep, ...
            seps = [self.sep, self.sep2]
            ret = system_prompt + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret

        if style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ': '  # must end with a space
            return ret

        if style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
            ret = '' if system_prompt == '' else system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep
                else:
                    ret += role + '\n'
            return ret

        if style == SeparatorStyle.NO_COLON_SINGLE:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + message + self.sep
                else:
                    ret += role
            return ret

        if style == SeparatorStyle.NO_COLON_TWO:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + message + seps[i % 2]
                else:
                    ret += role
            return ret

        if style == SeparatorStyle.RWKV:
            ret = system_prompt
            for role, message in self.messages:
                if message:
                    # Normalize CRLF to LF, then collapse each blank line once.
                    ret += (
                        role
                        + ': '
                        + message.replace('\r\n', '\n').replace('\n\n', '\n')
                    )
                    ret += '\n\n'
                else:
                    ret += role + ':'
            return ret

        if style == SeparatorStyle.LLAMA2:
            seps = [self.sep, self.sep2]
            if self.system_message:
                ret = system_prompt
            else:
                ret = '[INST] '
            for i, (role, message) in enumerate(self.messages):
                # The tag comes from self.roles by position, not from the
                # role stored alongside the message.
                tag = self.roles[i % 2]
                if message:
                    if i == 0:
                        # The first message is appended without a tag.
                        ret += message + ' '
                    else:
                        ret += tag + ' ' + message + seps[i % 2]
                else:
                    ret += tag
            return ret

        if style == SeparatorStyle.CHATGLM:
            # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
            # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
            # Round numbering starts at 1 for the 'chatglm2' template, else 0.
            round_add_n = 1 if self.name == 'chatglm2' else 0
            if system_prompt:
                ret = system_prompt + self.sep
            else:
                ret = ''
            for i, (role, message) in enumerate(self.messages):
                if i % 2 == 0:
                    ret += f'[Round {i//2 + round_add_n}]{self.sep}'
                if message:
                    ret += f'{role}:{message}{self.sep}'
                else:
                    ret += f'{role}:'
            return ret

        if style == SeparatorStyle.CHATML:
            ret = '' if system_prompt == '' else system_prompt + self.sep + '\n'
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + message + self.sep + '\n'
                else:
                    ret += role + '\n'
            return ret

        if style == SeparatorStyle.CHATGLM3:
            ret = ''
            if self.system_message:
                ret += system_prompt
            for role, message in self.messages:
                if message:
                    ret += role + '\n' + ' ' + message
                else:
                    ret += role
            return ret

        if style == SeparatorStyle.CHATINTERN:
            # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ':' + message + seps[i % 2] + '\n'
                else:
                    ret += role + ':'
            return ret

        if style == SeparatorStyle.DOLLY:
            seps = [self.sep, self.sep2]
            ret = system_prompt
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ':\n' + message + seps[i % 2]
                    if i % 2 == 1:
                        # Extra blank line after every odd-indexed message.
                        ret += '\n\n'
                else:
                    ret += role + ':\n'
            return ret

        if style == SeparatorStyle.PHOENIX:
            ret = system_prompt
            for role, message in self.messages:
                # NOTE(review): the empty-string operands contribute nothing;
                # they look like placeholders for begin/end markers - confirm
                # the intended values against the upstream template.
                if message:
                    ret += role + ': ' + '' + message + ''
                else:
                    ret += role + ': ' + ''
            return ret

        if style == SeparatorStyle.ROBIN:
            ret = system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ':\n' + message + self.sep
                else:
                    ret += role + ':\n'
            return ret

        if style == SeparatorStyle.FALCON_CHAT:
            ret = ''
            if self.system_message:
                ret += system_prompt + self.sep
            for role, message in self.messages:
                if message:
                    ret += role + ': ' + message + self.sep
                else:
                    ret += role + ':'
            return ret

        if style == SeparatorStyle.INTERNVL_ZH:
            # Separator order is swapped relative to the other two-separator
            # styles, and the raw system_message is used rather than the
            # formatted system_prompt.
            seps = [self.sep2, self.sep]
            ret = self.system_message + seps[0]
            for i, (role, message) in enumerate(self.messages):
                if message:
                    ret += role + ': ' + message + seps[i % 2]
                else:
                    ret += role + ':'
            return ret

        raise ValueError(f'Invalid style: {self.sep_style}')

    def set_system_message(self, system_message: str):
        """Set the system message."""
        self.system_message = system_message

    def append_message(self, role: str, message: Optional[str]):
        """Append a new message.

        Args:
            role: The role string stored with the message.
            message: The message text, or ``None`` as a placeholder to be
                filled in later via ``update_last_message``.
        """
        self.messages.append([role, message])

    def update_last_message(self, message: str):
        """Update the last output.

        The last message is typically set to be None when constructing the prompt,
        so we need to update it in-place after getting the response from a model.
        """
        self.messages[-1][1] = message

    def to_gradio_chatbot(self):
        """Convert the conversation to gradio chatbot format.

        Returns:
            A list of ``[first, second]`` pairs built from the messages after
            ``offset``: even-positioned messages open a pair (second slot
            ``None``), odd-positioned messages fill the second slot.
        """
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def to_openai_api_messages(self):
        """Convert the conversation to OpenAI chat completion format.

        Returns:
            A list of ``{'role', 'content'}`` dicts starting with the system
            message. Messages after ``offset`` alternate user/assistant by
            position; assistant entries whose content is ``None`` are skipped.
        """
        ret = [{'role': 'system', 'content': self.system_message}]
        for i, (_, msg) in enumerate(self.messages[self.offset :]):
            if i % 2 == 0:
                ret.append({'role': 'user', 'content': msg})
            else:
                if msg is not None:
                    ret.append({'role': 'assistant', 'content': msg})
        return ret

    def copy(self):
        """Return a new Conversation with the same settings.

        Each message is rebuilt as a new two-item list, so appending to or
        updating the copy's history does not affect the original.
        """
        return Conversation(
            name=self.name,
            system_template=self.system_template,
            system_message=self.system_message,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            stop_str=self.stop_str,
            stop_token_ids=self.stop_token_ids,
        )

    def dict(self):
        """Return a dict with the template name, system message, roles, messages and offset."""
        return {
            'template_name': self.name,
            'system_message': self.system_message,
            'roles': self.roles,
            'messages': self.messages,
            'offset': self.offset,
        }


# A global registry for all conversation templates
conv_templates: Dict[str, Conversation] = {}


def register_conv_template(template: Conversation, override: bool = False):
    """Register a new conversation template.

    Args:
        template: The template to store under ``template.name``.
        override: When False, registering an already-registered name fails an
            assertion; when True, the existing entry is replaced.
    """
    if not override:
        assert (
            template.name not in conv_templates
        ), f'{template.name} has been registered.'

    conv_templates[template.name] = template


def get_conv_template(name: str) -> Conversation:
    """Get a conversation template.

    Returns a copy, so callers can append messages without mutating the
    registered template. Raises ``KeyError`` if ``name`` is not registered.
    """
    return conv_templates[name].copy()


# InternVL-Chat-V1-1 template
register_conv_template(
    Conversation(
        name='internvl_zh',
        system_template='',
        roles=('', ''),
        sep_style=SeparatorStyle.INTERNVL_ZH,
        sep='',
        sep2=' ',
    )
)