'''AntGLM Chat-model data format.
格式化 AntGLM 以及各种开源模型的符号系统:
- 确定 Chat 模型依赖的文件数据结构协议
- 确定单轮/多轮的统一结构
- 确定 Chat 符号系统的协议, 包括角色定义、分隔符等
- 方便做开源模型依赖的 prompt 转换
- 支持工具、代码、推理等支持
参考 FastChat Conversation 对象的设计思路.
Reference: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
'''
import copy
import dataclasses
import logging
import re
import uuid
from copy import deepcopy
from enum import IntEnum, auto
from typing import Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
class PromptStyle(IntEnum):
    '''Enumeration of the supported prompt formats.'''

    # Original AntGLM format: single-turn instructions are unstructured,
    # multi-turn uses `第1轮\n用户: xx\n机器人: xx\n`.
    ANTGLM_RAW = 1
    # Unified chat format for both single-turn and multi-turn samples.
    ANTGLM_CHAT = 2
    # Single-turn stays unstructured; only multi-turn uses the chat format.
    ANTGLM_ONLY_MULTITURN_CHAT = 3
    # OpenAI ChatML format (also used by Qwen).
    CHATML = 4
    # LLaMA-2 format.
    LLAMA2 = 5
    # ChatGLM 1/2 format.
    CHATGLM = 6
    # ChatGLM3 format.
    CHATGLM3 = 7
    # Baichuan 2 format.
    BAICHUAN2 = 8
@dataclasses.dataclass
class Chat:
    '''Chat symbol-system container for AntGLM and common open-source formats.

    Responsibilities:
    - define the file-level data protocol chat models rely on
    - unify single-turn and multi-turn structure
    - define the chat symbol protocol (roles, separators, turn markers)
    - ease prompt conversion between open-source model formats
    - support tool / code / reasoning extensions

    Design follows FastChat's `Conversation` object:
    https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py

    Examples:
    ```python
    >>> from antllm.data.chat_format import Chat
    >>> # Build a chat object from a json payload using the raw AntGLM format
    >>> input_json = {
    ...     "messages": [
    ...         {"role": "HUMAN", "content": "讲一个笑话"},
    ...         {"role": "ASSISTANT", "content": "为什么猪不能上网?因为它们会被网上的“猪”骗!哈哈哈!"},
    ...         {"role": "HUMAN", "content": "不好笑,换个程序员的笑话"}
    ...     ],
    ... }
    >>> chat = Chat.from_json(input_json, name='antglm_raw')
    >>> pack_data = chat.prompt_pack          # packed training data
    >>> data = chat.prompt_inout              # {"input": ..., "output": ...}
    >>> prompt = chat.prompt_str              # inference prompt string
    >>> chat = Chat.from_inout(data, name='antglm_raw')
    >>> chats = Chat.from_pack(pack_data, name='antglm_raw')
    >>> print(chat.turns_num)                 # number of human turns
    >>> data_json = chat.to_json()
    >>> chat.append_message(chat.role_assistant, '...')
    >>> openai_messages = chat.to_openai_api_messages()
    >>> chat_new = chat.copy()
    ```
    '''

    # Unique id; auto-generated in __post_init__ when not supplied.
    id: Optional[str] = None
    # Symbol-system name. Supported: antglm_raw, antglm_chat, chatglm1,
    # chatglm2, llama2, qwen, baichuan2.
    name: Optional[str] = None
    # Prompt style; derived from `name` when omitted.
    prompt_style: Optional['PromptStyle'] = None
    # System template and message.
    system_template: str = 'SYSTEM{}'
    system_message: str = ''
    # Role definitions.
    role_human: str = 'HUMAN'
    role_assistant: str = 'ASSISTANT'
    role_observation: str = 'OBSERVATION'
    role_template: str = '{}'
    # Per-turn symbol definitions.
    turn_start: str = ''
    human_end: str = ''
    assistant_start: str = ''
    assistant_end: str = ''
    assistant_end_ids: Optional[List[int]] = None
    general_role_end: str = ''
    # Agent-related templates.
    tool_template: str = '{}'
    # BUG FIX: the original source carried an unterminated single-quoted
    # string spanning two physical lines here; the literal content between
    # the quotes was '{}' plus a newline, restored as a valid literal.
    code_template: str = '{}\n'
    # NOTE: attribute name keeps its historical typo ("arithemetic_templte")
    # for backward compatibility with existing callers.
    arithemetic_templte: str = '{}'
    image_template: str = '{}'
    # All messages. Each item is (role, message).
    messages: List[Tuple[str, str]] = dataclasses.field(default_factory=list)
    # Offset into `messages` reserved for few-shot examples.
    offset: int = 0
    # Other metadata.
    source: Optional[str] = None
    lang: Optional[str] = None
    topic: Optional[str] = None
    # The original json payload this object was built from, if any.
    origin_json: Optional[dict] = None
@property
def support_names(self) -> Dict[str, str]:
'''支持的数据对象名称.'''
return {
'antglm_raw': '原始 antglm format 格式, 单轮指令没有结构, 多轮 `第1轮\\n用户:xx\\n机器人xx\\n`',
'antglm_chat': 'Chat format 格式, 单轮多轮统一为 chat format 格式',
'chatglm1': 'chatglm1 format',
'chatglm2': 'chatglm2 format',
'llama2': 'llama2 format',
'qwen': '千问 format',
'baichuan2': '百川 2 format',
}
@classmethod
def from_json(
    cls,
    input: dict,
    name: Optional[str] = None,
    prompt_style: Optional[PromptStyle] = None,
):
    '''Build a Chat object from a file-level json dict.

    Params:
        input: `dict`, the json payload. Either a `messages` list of
            {"role": "HUMAN"|"OBSERVATION"|"ASSISTANT", "content": ...}
            items, or a `turns` list of
            {"HUMAN": ..., "OBSERVATION": ..., "ASSISTANT": ...} dicts.
            May also carry `id`, `name`, `source`, `lang`, `topic`,
            `system_template` and `system_message`.
        name: `Optional[str]`, symbol-system name. Supported: antglm_raw,
            antglm_chat, chatglm1, chatglm2, llama2, qwen, baichuan2.
            Overrides the `name` field inside `input` when given.
        prompt_style: `Optional[PromptStyle]`, prompt style; defaults to
            the style matching `name`.

    Returns:
        `Chat` object.

    Raises:
        ValueError: when a message carries an unsupported role.
    '''
    # Optional system fields are only forwarded when present.
    extra = {
        key: input[key]
        for key in ('system_template', 'system_message')
        if key in input
    }
    chat = cls(
        id=input.get('id'),
        name=name if name else input.get('name'),
        prompt_style=prompt_style,
        source=input.get('source'),
        lang=input.get('lang'),
        topic=input.get('topic'),
        origin_json=deepcopy(input),
        **extra,
    )
    # Map dataset role tags onto the chat's configured role symbols.
    role_by_tag = {
        'HUMAN': chat.role_human,
        'OBSERVATION': chat.role_observation,
        'ASSISTANT': chat.role_assistant,
    }
    if 'messages' in input:
        for msg in input['messages']:
            if msg['role'] not in role_by_tag:
                raise ValueError(f'不支持数据集中的 role: {msg["role"]}')
            chat.append_message(role_by_tag[msg['role']], msg['content'])
    elif 'turns' in input:
        # Within one turn the fixed order is HUMAN, OBSERVATION, ASSISTANT.
        for turn in input['turns']:
            for tag in ('HUMAN', 'OBSERVATION', 'ASSISTANT'):
                if tag in turn:
                    chat.append_message(role_by_tag[tag], turn[tag])
    return chat
@classmethod
def from_pack(
    cls,
    packs: Dict[str, List[str]],
    name: str,
    prompt_style: Optional[PromptStyle] = None,
) -> list:
    '''Build a list of Chat objects from packed training data.

    Params:
        packs: `dict`, packed samples. NOTE: the keys are singular —
            `_format_packs` (and `prompt_pack`) consume `input`/`output`:
            {
                'input': ['xx', 'xx'],
                'output': ['xx', 'xx'],
            }
        name: `str`, symbol-system name
        prompt_style: `Optional[PromptStyle]`, prompt style; defaults to the
            style matching `name`

    Returns:
        `list` of `Chat`, one per conversation found in the pack.
    '''
    chat = cls(name=name, prompt_style=prompt_style)
    packs = cls._format_packs(packs)
    # Patterns for the system header, turn markers and each role tag.
    sys_pattern = re.compile(chat.system_template.format(r'(.*?)'), re.DOTALL)
    turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL)
    human_pattern = re.compile(chat.role_template.format(chat.role_human).strip(), re.DOTALL)
    observe_pattern = re.compile(chat.role_template.format(chat.role_observation).strip(), re.DOTALL)
    assistant_pattern = re.compile(chat.role_template.format(chat.role_assistant).strip(), re.DOTALL)
    chats = []
    for input, output in zip(packs['input'], packs['output']):
        # system message
        sys_match = sys_pattern.search(input)
        if sys_match and sys_match.group(0):
            # A system segment only appears in a first turn: flush the
            # previous conversation and start a new chat object.
            if len(chat.messages) > 0:
                chats.append(chat)
                chat = cls(name=name, prompt_style=prompt_style)
            input = input[sys_match.end() :]
            chat.system_message = sys_match.group(1)
        # turn start
        turn_match = turn_pattern.search(input)
        if turn_match and turn_match.group(0):
            # Seeing the first round number again means a new conversation
            # starts. NOTE: the raw `name` param is matched here ('antglm'
            # included) because it is not normalized to 'antglm_raw' the way
            # `chat.name` is by __post_init__.
            if name in ['antglm', 'antglm_raw', 'chatglm2']:
                round_start = 1
            else:
                round_start = 0
            if all(
                [
                    len(turn_match.groups()) > 0,
                    int(turn_match.group(1)) == round_start,
                    len(chat.messages) > 0,
                ]
            ):
                chats.append(chat)
                chat = cls(name=name, prompt_style=prompt_style)
            input = input[turn_match.end() :]
        human_iter = human_pattern.finditer(input)
        observe_iter = observe_pattern.finditer(input)
        assistant_iter = assistant_pattern.finditer(input)
        human_match = next(human_iter, None)
        observe_match = next(observe_iter, None)
        assistant_match = next(assistant_iter, None)
        if not human_match and not observe_match:
            # No role formatting at all: treat the whole input as human text.
            chat.append_message(chat.role_human, input)
        while human_match or observe_match:
            next_human_match = next(human_iter, None)
            next_observe_match = next(observe_iter, None)
            # NOTE(review): `input` is re-sliced here while the iterators keep
            # scanning the ORIGINAL string, so later match spans refer to the
            # pre-slice offsets — confirm against real pack data.
            input = cls._append_human_observation(
                chat,
                input,
                human_match=human_match,
                next_human_match=next_human_match,
                observe_match=observe_match,
                next_observe_match=next_observe_match,
                assistant_match=assistant_match,
            )
            human_match = next_human_match
            observe_match = next_observe_match
            # NOTE(review): these two extra `next(...)` calls advance both
            # iterators a second time per pass, and their results are
            # overwritten at the top of the loop — this looks like it can
            # skip matches; verify intent.
            next_human_match = next(human_iter, None)
            next_observe_match = next(observe_iter, None)
        if output:
            chat.append_message(chat.role_assistant, output)
    if chat.messages:
        chats.append(chat)
    return chats
@classmethod
def _append_human_observation(
cls,
chat,
input: str,
human_match: Optional[re.Match] = None,
next_human_match: Optional[re.Match] = None,
observe_match: Optional[re.Match] = None,
next_observe_match: Optional[re.Match] = None,
assistant_match: Optional[re.Match] = None,
) -> str:
'''给 chat 对象增加 human/observation message.'''
if observe_match:
# observation 在 human 之后
if observe_match.span()[0] > observe_match.span()[0]:
human_str = input[observe_match.span()[1] : observe_match.span()[0]]
observe_str = input[observe_match.span()[1] : assistant_match.span()[0]]
chat.append_message(chat.role_human, human_str.strip())
input_end = observe_match.span()[1]
if observe_match.span()[0] < next_human_match.span()[0]:
chat.append_message(chat.role_observation, observe_str.strip())
input_end = assistant_match.span()[1]
else:
# observation 在 human 之前
human_str = input[observe_match.span()[1] : assistant_match.span()[0]]
observe_str = input[observe_match.span()[1] : observe_match.span()[0]]
chat.append_message(chat.role_observation, observe_str.strip())
input_end = observe_match.span()[1]
if observe_match.span()[0] < next_observe_match.span()[0]:
chat.append_message(chat.role_human, human_str.strip())
input_end = assistant_match.span()[1]
else:
if assistant_match:
human_str = input[human_match.span()[1] : assistant_match.span()[0]]
input_end = assistant_match.span()[1]
else:
human_str = input[human_match.span()[1] :]
input_end = len(input)
chat.append_message(chat.role_human, human_str.strip())
return input[input_end:]
@classmethod
def from_inout(
    cls,
    sample: Dict[str, str],
    name: str,
    prompt_style: Optional[PromptStyle] = None,
):
    '''Build one Chat object from a single input/output training sample.

    Params:
        sample: `Dict[str, str]`, the sample:
            {
                "input": "xxx",
                "output": "xxx",
            }
        name: `str`, symbol-system name
        prompt_style: `Optional[PromptStyle]`, prompt style; defaults to the
            style matching `name`
    '''
    chat = cls(name=name, prompt_style=prompt_style)
    input = sample['input']
    output = sample['output']
    # Patterns for the system header, turn markers and each role tag.
    sys_pattern = re.compile(chat.system_template.format(r'(.*?)'), re.DOTALL)
    turn_pattern = re.compile(chat.turn_start.format(r'(\d+)'), re.DOTALL)
    human_pattern = re.compile(chat.role_template.format(chat.role_human).strip(), re.DOTALL)
    observe_pattern = re.compile(chat.role_template.format(chat.role_observation).strip(), re.DOTALL)
    assistant_pattern = re.compile(chat.role_template.format(chat.role_assistant).strip(), re.DOTALL)
    # Drop all turn markers (e.g. 第x轮 / [Round x]) before parsing.
    input = turn_pattern.sub('', input)
    # system message search
    sys_match = sys_pattern.search(input)
    if sys_match and sys_match.group(0):
        input = input[sys_match.end() :]
        chat.system_message = sys_match.group(1)
    human_iter = human_pattern.finditer(input)
    observe_iter = observe_pattern.finditer(input)
    assistant_iter = assistant_pattern.finditer(input)
    human_match = next(human_iter, None)
    observe_match = next(observe_iter, None)
    assistant_match = next(assistant_iter, None)
    next_human_match = next(human_iter, None)
    next_observe_match = next(observe_iter, None)
    while any(
        [
            human_match,
            observe_match,
            assistant_match,
        ]
    ):
        # human/observation may appear in either order and may repeat;
        # consume every human/observation tag preceding the next assistant.
        # NOTE(review): each condition dereferences `assistant_match.span()`
        # — if a human/observation tag remains after the last assistant tag,
        # `assistant_match` is None here and this raises AttributeError.
        while any(
            [
                human_match and human_match.span()[0] < assistant_match.span()[0],
                observe_match and observe_match.span()[0] < assistant_match.span()[0],
                next_human_match and next_human_match.span()[0] < assistant_match.span()[0],
                next_observe_match and next_observe_match.span()[0] < assistant_match.span()[0],
            ]
        ):
            if not input:
                break
            # The sliced remainder returned by the helper is deliberately
            # discarded, so all match spans stay valid against `input`.
            cls._append_human_observation(
                chat,
                input,
                human_match=human_match,
                next_human_match=next_human_match,
                observe_match=observe_match,
                next_observe_match=next_observe_match,
                assistant_match=assistant_match,
            )
            human_match = next_human_match
            observe_match = next_observe_match
            next_human_match = next(human_iter, None)
            next_observe_match = next(observe_iter, None)
        # assistant message: content runs up to the next human/observation tag.
        # NOTE(review): when `observe_match` exists but starts after
        # `human_match`, no branch below assigns `assistant_str`, which can
        # leave it unbound (NameError) on the first iteration — verify.
        if assistant_match and assistant_match.span():
            if observe_match:
                if observe_match.span() and observe_match.span()[0] < human_match.span()[0]:
                    assistant_str = input[assistant_match.span()[1] : observe_match.span()[0]]
            elif human_match:
                if human_match.span():
                    assistant_str = input[assistant_match.span()[1] : human_match.span()[0]]
            else:
                assistant_str = input[assistant_match.span()[1] :]
            if assistant_str:
                chat.append_message(chat.role_assistant, assistant_str)
            assistant_match = next(assistant_iter, None)
    if output:
        chat.append_message(chat.role_assistant, output)
    return chat
def __hash__(self):
'''数据对象的 hash 函数.'''
return hash(self.id)
def __post_init__(self):
    '''Resolve the symbol system (roles, separators, templates) after init.

    Behavior:
    - keeps a caller-supplied `id`, generating a uuid4 only when absent
    - normalizes the legacy name 'antglm' to 'antglm_raw'
    - configures role/turn templates from `name` and/or `prompt_style`

    Raises:
        ValueError: when neither `name` nor `prompt_style` is given.
    '''
    # BUG FIX: the original unconditionally overwrote `id`, discarding ids
    # explicitly passed in (e.g. by `from_json`).
    if not self.id:
        self.id = str(uuid.uuid4())
    if not self.messages:
        self.messages = []
    if not self.name and not self.prompt_style:
        logger.error('构造 Chat 对象至少包含以下一个入参: `name/prompt_style`.\n\n' '`name` 支持以下 format 名称:')
        logger.error('\n'.join([f'{k}: {v}' for k, v in self.support_names.items()]))
        logger.error('\n`prompt_style` 参考 antllm.data.chat_format.PromptStyle')
        raise ValueError
    if self.name == 'antglm':
        # Legacy alias: 'antglm' means the raw format - 第1轮\n用户: xx\n机器人: xx\n
        self.name = 'antglm_raw'
    if not self.name and self.prompt_style == PromptStyle.ANTGLM_CHAT:
        logger.info(
            'Chat 对象入参没有 `name`, 默认使用 `ANTGLM_CHAT`, format:\n'
            f'role_human: {self.role_human}\n'
            f'role_assistant: {self.role_assistant}\n'
            f'role_observation: {self.role_observation}\n'
            f'role_template: {self.role_template}\n'
            f'turn_start: {self.turn_start}\n'
            f'human_end: {self.human_end}\n'
            f'assistant_start: {self.assistant_start}\n'
            f'assistant_end: {self.assistant_end}\n'
            f'assistant_end_ids: {self.assistant_end_ids}\n'
            f'general_role_end: {self.general_role_end}\n'
            f'tool_template: {self.tool_template}\n'
            f'code_template: {self.code_template}\n'
            f'arithemetic_templte: {self.arithemetic_templte}\n'
            f'image_template: {self.image_template}\n'
            f'\n入参 `name` 支持: ``'
        )
        return
    if self.name == 'antglm_raw' or self.prompt_style == PromptStyle.ANTGLM_RAW:
        self.prompt_style = PromptStyle.ANTGLM_RAW
        self.role_template = '{}'
        self.role_human = '用户: '
        self.role_assistant = '机器人: '
        self.turn_start = '第{}轮\n'
        self.general_role_end = '\n'
    # NOTE: intentionally `if` (not `elif`) to preserve the original
    # precedence when both a name and a conflicting prompt_style are given.
    if self.name in ['chatglm1', 'chatglm2'] or self.prompt_style == PromptStyle.CHATGLM:
        self.prompt_style = PromptStyle.CHATGLM
        self.role_template = '{}'
        self.role_human = '问:'
        self.role_assistant = '答:'
        self.turn_start = '[Round {}]\n'
        # chatglm1 separates turns with one newline, chatglm2 with two.
        if self.name == 'chatglm1':
            self.general_role_end = '\n'
        else:
            self.general_role_end = '\n\n'
    elif self.name == 'chatglm3' or self.prompt_style == PromptStyle.CHATGLM3:
        self.prompt_style = PromptStyle.CHATGLM3
        self.system_template = '<|system|>\n {}'
        self.role_human = '<|user|>\n '
        self.role_assistant = '<|assistant|>\n '
        self.role_template = '{}'
    elif self.name == 'llama2' or self.prompt_style == PromptStyle.LLAMA2:
        self.prompt_style = PromptStyle.LLAMA2
        self.role_template = '{}'
        self.system_template = '[INST] <>\n{}\n<>\n\n'
        self.role_human = '[INST] '
        self.role_assistant = '[/INST] '
        self.human_end = ' '
        self.assistant_end = ' '
    elif self.name == 'qwen':
        self.prompt_style = PromptStyle.CHATML
        self.role_template = '{}'
        self.system_template = '<|im_start|>system\n{}'
        if not self.system_message:
            self.system_message = 'You are a helpful assistant.'
        self.role_human = '<|im_start|>user\n'
        self.role_assistant = '<|im_start|>assistant\n'
        self.general_role_end = '<|im_end|>\n'
    elif self.name in ('baichuan', 'baichuan2'):
        # BUG FIX: `support_names` documents 'baichuan2' but only 'baichuan'
        # was accepted; both are accepted now. The dead
        # `if not self.system_template` check (it had just been set above)
        # was removed.
        self.prompt_style = PromptStyle.BAICHUAN2
        self.role_template = '{}'
        self.system_template = '{}'
        self.role_human = ''
        self.role_assistant = ''
def readable_messages(self) -> str:
'''将 messages 输出为人类可读的字符串, 方便分析数据.'''
pass
@property
def prompt_str(self) -> str:
    '''Full prompt string: the input part immediately followed by the output.'''
    # PERF FIX: the original accessed the `prompt_inout` property twice,
    # rebuilding the whole pack on each access; compute it once.
    inout = self.prompt_inout
    return f'{inout["input"]}{inout["output"]}'
@classmethod
def _format_packs(cls, packs: Dict[str, List[str]]) -> Dict[str, List[str]]:
'''格式化 pack 样本, 输出相同 pack inputs, outputs 个数.'''
_packs = copy.deepcopy(packs)
if len(_packs['input']) - 1 == len(_packs['output']):
_packs['output'].append('')
if len(_packs['input']) != len(_packs['output']):
print(packs)
raise ValueError(
'输入 input 和 output 数量不匹配, '
f'input num: {len(packs["input"])}, '
f'output num: {len(packs["output"])}'
)
return _packs
@property
def prompt_inout(self) -> Dict[str, str]:
    '''Render the conversation as one (input, output) prompt pair.

    Returns:
        `Dict[str, str]`, for example:
        {
            "input": "SYSTEMxxxxHUMAN你好ASSISTANT你好,有什么可以帮您?ASSISTANT", # noqa
            "output": "你好,有什么可以帮您?"
        }
    '''
    packs = self._format_packs(self.prompt_pack)
    inputs, outputs = packs['input'], packs['output']
    # Compatibility: legacy antglm_raw joins turns with a single space.
    if self.prompt_style == PromptStyle.ANTGLM_RAW:
        inputs = [f'{item} ' for item in inputs]
    # Interleave all but the last pair, then append the final input.
    pieces = []
    for idx in range(len(inputs) - 1):
        pieces.append(inputs[idx])
        pieces.append(outputs[idx])
    pieces.append(inputs[-1])
    prompt_input = ''.join(pieces)
    prompt_output = outputs[-1]
    # Compatibility: legacy antglm_raw strips surrounding whitespace.
    if self.prompt_style == PromptStyle.ANTGLM_RAW:
        prompt_input = prompt_input.strip()
    return {
        'input': prompt_input,
        'output': prompt_output,
    }
@property
def prompt_pack(self) -> Dict[str, List[str]]:
    '''Render the conversation as packed (input, output) prompt lists.

    Returns:
        `Dict[str, List[str]]`, for example:
        {
            "input": [
                "SYSTEMxxxxHUMAN你好ASSISTANT",
                "HUMAN讲个笑话ASSISTANT",
                "OBSERVATION{\"weather\": \"晴\"}ASSISTANT"
            ],
            "output": [
                "你好,有什么可以帮您?",
                "笑话 1",
                "今天天气 xxx"
            ]
        }
    '''
    inputs = []
    outputs = []
    # Leading system segment, if any.
    system_prompt = ''
    if self.system_message:
        system_prompt = self.system_template.format(self.system_message)
    if system_prompt:
        ret = system_prompt + self.general_role_end
    else:
        ret = ''
    # Some prompt styles keep single-turn instructions unformatted.
    if self.prompt_style in [
        PromptStyle.ANTGLM_RAW,
        PromptStyle.ANTGLM_ONLY_MULTITURN_CHAT,
    ]:
        if len(self.messages) <= 2:
            # BUG FIX: initialize `input` so an empty or assistant-only
            # history does not raise NameError on return.
            input = ret
            output = ''
            for role, message in self.messages:
                if role == self.role_assistant:
                    output = message
                else:
                    input = ret + message
            return {
                'input': [input],
                'output': [output],
            }
    # Multi-turn: round counters start at 1 for antglm_raw/chatglm2.
    if self.name in ['antglm_raw', 'chatglm2']:
        round_start = 1
    else:
        round_start = 0
    for i, (role, message) in enumerate(self.messages):
        # Turn header for formats that number their rounds.
        if self.name in ['antglm_raw', 'chatglm1', 'chatglm2']:
            if i % 2 == 0:
                ret += self.turn_start.format(i // 2 + round_start)
        # Role tag + content + role-end marker.
        role_end = self.general_role_end
        if role == self.role_assistant and self.assistant_end:
            role_end = self.assistant_end
        elif self.human_end:
            role_end = self.human_end
        ret += self.role_template.format(role) + message + role_end
        if role == self.role_assistant:
            # Outputs keep only the assistant content.
            if not message:
                outputs.append('')
            else:
                outputs.append(message + role_end)
            # The input gets the assistant role tag (without the reply).
            inputs[-1] += ret[: -len(message + role_end)]
            # BUG FIX: reset the accumulator; without this the assistant
            # reply leaked into the next packed input, duplicating every
            # reply when prompt_inout re-joined the pack (contradicting the
            # documented pack examples).
            ret = ''
        elif all(
            [
                role == self.role_observation,
                len(self.messages) > 1,
                self.messages[i - 1][0] != self.role_assistant,
            ]
        ):
            # Observation not preceded by an assistant reply: keep it merged
            # into the same accumulated input.
            continue
        else:
            inputs.append(ret)
            ret = ''
            # Trailing non-assistant message: append the assistant tag so
            # the model can generate from this prompt.
            if i == len(self.messages) - 1 and role != self.role_assistant:
                inputs[-1] += self.role_template.format(self.role_assistant).strip()
    # Compatibility: legacy antglm_raw strips trailing whitespace per input.
    if self.prompt_style == PromptStyle.ANTGLM_RAW:
        inputs = [item.strip() for item in inputs]
    return {
        'input': inputs,
        'output': outputs,
    }
@property
def turns_num(self) -> int:
'''和机器人的交互轮数, 以用户输出多少次为轮数个数.'''
return sum([1 if msg[0] == self.role_human else 0 for msg in self.messages])
def to_json(self) -> dict:
'''输出 chat json dict 格式, 包含不同角色和机器人交互的每轮信息.
Returns
`List[dict]`, {
"id": "xx",
"messages": [
{"role": "HUMAN", "content": "xxx"}
]
"turns": [
{"HUMAN": "xx", "OBSERVATION": "xx", "ASSISTANT": "xx"}
]
}
'''
turns = []
messages = []
turn = {}
for msg in self.messages:
if msg[0] == self.role_assistant:
messages.append({'role': 'ASSISTANT', 'content': msg[1]})
turn['ASSISTANT'] = msg[1]
turns.append(turn)
turn = {}
if msg[0] == self.role_human:
messages.append({'role': 'HUMAN', 'content': msg[1]})
turn['HUMAN'] = msg[1]
if msg[0] == self.role_observation:
messages.append({'role': 'OBSERVATION', 'content': msg[1]})
turn['OBSERVATION'] = msg[1]
if self.messages[-1][0] == self.role_human:
messages.append({'role': 'ASSISTANT', 'content': ''})
turn['ASSISTANT'] = ''
turns.append(turn)
result = self.origin_json or {}
result.update(
{
'id': self.id,
'name': self.name,
'source': self.source,
'lang': self.lang,
'topic': self.topic,
'system_template': self.system_template,
'system_message': self.system_message,
'turns': turns,
'messages': messages,
}
)
return result
def set_system_message(self, system_message: str):
'''Set the system message.'''
self.system_message = system_message
def append_message(self, role: str, message: str):
'''Append a new message.'''
if not message:
message = ''
self.messages.append([role, message])
def to_openai_api_messages(self) -> List[dict]:
'''Convert the conversation to OpenAI chat completion format.'''
ret = [{'role': 'system', 'content': self.system_message}]
for i, (_, msg) in enumerate(self.messages[self.offset :]):
if i % 2 == 0:
ret.append({'role': 'user', 'content': msg})
else:
if msg is not None:
ret.append({'role': 'assistant', 'content': msg})
return ret
def copy(self):
return copy.deepcopy(self)