yungongzi OnePieceMan committed on
Commit
d7bf446
·
1 Parent(s): 6cb617c

Added support for Baichuan LLM (#934)

Browse files

### What problem does this PR solve?

- Added support for Baichuan LLM

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: 海贼宅 <[email protected]>

api/db/init_data.py CHANGED
@@ -137,7 +137,12 @@ factory_infos = [{
137
  "logo": "",
138
  "tags": "LLM, TEXT EMBEDDING",
139
  "status": "1",
140
- }
 
 
 
 
 
141
  # {
142
  # "name": "文心一言",
143
  # "logo": "",
@@ -392,6 +397,49 @@ def init_llm_factory():
392
  "max_tokens": 4096,
393
  "model_type": LLMType.CHAT.value
394
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
395
  ]
396
  for info in factory_infos:
397
  try:
 
137
  "logo": "",
138
  "tags": "LLM, TEXT EMBEDDING",
139
  "status": "1",
140
+ },{
141
+ "name": "BaiChuan",
142
+ "logo": "",
143
+ "tags": "LLM,TEXT EMBEDDING",
144
+ "status": "1",
145
+ },
146
  # {
147
  # "name": "文心一言",
148
  # "logo": "",
 
397
  "max_tokens": 4096,
398
  "model_type": LLMType.CHAT.value
399
  },
400
+ # ------------------------ BaiChuan -----------------------
401
+ {
402
+ "fid": factory_infos[10]["name"],
403
+ "llm_name": "Baichuan2-Turbo",
404
+ "tags": "LLM,CHAT,32K",
405
+ "max_tokens": 32768,
406
+ "model_type": LLMType.CHAT.value
407
+ },
408
+ {
409
+ "fid": factory_infos[10]["name"],
410
+ "llm_name": "Baichuan2-Turbo-192k",
411
+ "tags": "LLM,CHAT,192K",
412
+ "max_tokens": 196608,
413
+ "model_type": LLMType.CHAT.value
414
+ },
415
+ {
416
+ "fid": factory_infos[10]["name"],
417
+ "llm_name": "Baichuan3-Turbo",
418
+ "tags": "LLM,CHAT,32K",
419
+ "max_tokens": 32768,
420
+ "model_type": LLMType.CHAT.value
421
+ },
422
+ {
423
+ "fid": factory_infos[10]["name"],
424
+ "llm_name": "Baichuan3-Turbo-128k",
425
+ "tags": "LLM,CHAT,128K",
426
+ "max_tokens": 131072,
427
+ "model_type": LLMType.CHAT.value
428
+ },
429
+ {
430
+ "fid": factory_infos[10]["name"],
431
+ "llm_name": "Baichuan4",
432
+ "tags": "LLM,CHAT,128K",
433
+ "max_tokens": 131072,
434
+ "model_type": LLMType.CHAT.value
435
+ },
436
+ {
437
+ "fid": factory_infos[10]["name"],
438
+ "llm_name": "Baichuan-Text-Embedding",
439
+ "tags": "TEXT EMBEDDING",
440
+ "max_tokens": 512,
441
+ "model_type": LLMType.EMBEDDING.value
442
+ },
443
  ]
444
  for info in factory_infos:
445
  try:
rag/llm/__init__.py CHANGED
@@ -26,7 +26,8 @@ EmbeddingModel = {
26
  "ZHIPU-AI": ZhipuEmbed,
27
  "FastEmbed": FastEmbed,
28
  "Youdao": YoudaoEmbed,
29
- "DeepSeek": DefaultEmbedding
 
30
  }
31
 
32
 
@@ -47,6 +48,7 @@ ChatModel = {
47
  "Ollama": OllamaChat,
48
  "Xinference": XinferenceChat,
49
  "Moonshot": MoonshotChat,
50
- "DeepSeek": DeepSeekChat
 
51
  }
52
 
 
26
  "ZHIPU-AI": ZhipuEmbed,
27
  "FastEmbed": FastEmbed,
28
  "Youdao": YoudaoEmbed,
29
+ "DeepSeek": DefaultEmbedding,
30
+ "BaiChuan": BaiChuanEmbed
31
  }
32
 
33
 
 
48
  "Ollama": OllamaChat,
49
  "Xinference": XinferenceChat,
50
  "Moonshot": MoonshotChat,
51
+ "DeepSeek": DeepSeekChat,
52
+ "BaiChuan": BaiChuanChat
53
  }
54
 
rag/llm/chat_model.py CHANGED
@@ -95,6 +95,84 @@ class DeepSeekChat(Base):
95
  super().__init__(key, model_name, base_url)
96
 
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  class QWenChat(Base):
99
  def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
100
  import dashscope
 
95
  super().__init__(key, model_name, base_url)
96
 
97
 
98
+ class BaiChuanChat(Base):
99
+ def __init__(self, key, model_name="Baichuan3-Turbo", base_url="https://api.baichuan-ai.com/v1"):
100
+ if not base_url:
101
+ base_url = "https://api.baichuan-ai.com/v1"
102
+ super().__init__(key, model_name, base_url)
103
+
104
+ @staticmethod
105
+ def _format_params(params):
106
+ return {
107
+ "temperature": params.get("temperature", 0.3),
108
+ "max_tokens": params.get("max_tokens", 2048),
109
+ "top_p": params.get("top_p", 0.85),
110
+ }
111
+
112
+ def chat(self, system, history, gen_conf):
113
+ if system:
114
+ history.insert(0, {"role": "system", "content": system})
115
+ try:
116
+ response = self.client.chat.completions.create(
117
+ model=self.model_name,
118
+ messages=history,
119
+ extra_body={
120
+ "tools": [{
121
+ "type": "web_search",
122
+ "web_search": {
123
+ "enable": True,
124
+ "search_mode": "performance_first"
125
+ }
126
+ }]
127
+ },
128
+ **self._format_params(gen_conf))
129
+ ans = response.choices[0].message.content.strip()
130
+ if response.choices[0].finish_reason == "length":
131
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
132
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
133
+ return ans, response.usage.total_tokens
134
+ except openai.APIError as e:
135
+ return "**ERROR**: " + str(e), 0
136
+
137
+ def chat_streamly(self, system, history, gen_conf):
138
+ if system:
139
+ history.insert(0, {"role": "system", "content": system})
140
+ ans = ""
141
+ total_tokens = 0
142
+ try:
143
+ response = self.client.chat.completions.create(
144
+ model=self.model_name,
145
+ messages=history,
146
+ extra_body={
147
+ "tools": [{
148
+ "type": "web_search",
149
+ "web_search": {
150
+ "enable": True,
151
+ "search_mode": "performance_first"
152
+ }
153
+ }]
154
+ },
155
+ stream=True,
156
+ **self._format_params(gen_conf))
157
+ for resp in response:
158
+ if resp.choices[0].finish_reason == "stop":
159
+ if not resp.choices[0].delta.content:
160
+ continue
161
+ total_tokens = resp.usage.get('total_tokens', 0)
162
+ if not resp.choices[0].delta.content:
163
+ continue
164
+ ans += resp.choices[0].delta.content
165
+ if resp.choices[0].finish_reason == "length":
166
+ ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
167
+ [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
168
+ yield ans
169
+
170
+ except Exception as e:
171
+ yield ans + "\n**ERROR**: " + str(e)
172
+
173
+ yield total_tokens
174
+
175
+
176
  class QWenChat(Base):
177
  def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
178
  import dashscope
rag/llm/embedding_model.py CHANGED
@@ -104,6 +104,15 @@ class OpenAIEmbed(Base):
104
  return np.array(res.data[0].embedding), res.usage.total_tokens
105
 
106
 
 
 
 
 
 
 
 
 
 
107
  class QWenEmbed(Base):
108
  def __init__(self, key, model_name="text_embedding_v2", **kwargs):
109
  dashscope.api_key = key
 
104
  return np.array(res.data[0].embedding), res.usage.total_tokens
105
 
106
 
107
class BaiChuanEmbed(OpenAIEmbed):
    """Embedding client for Baichuan's OpenAI-compatible endpoint.

    Baichuan serves embeddings through an OpenAI-style API, so all request
    logic is inherited from ``OpenAIEmbed``; only the defaults differ.
    """

    def __init__(self, key,
                 model_name='Baichuan-Text-Embedding',
                 base_url='https://api.baichuan-ai.com/v1'):
        # An explicitly falsy base_url (None/"") still resolves to the default.
        super().__init__(key, model_name, base_url or "https://api.baichuan-ai.com/v1")
114
+
115
+
116
  class QWenEmbed(Base):
117
  def __init__(self, key, model_name="text_embedding_v2", **kwargs):
118
  dashscope.api_key = key
web/src/assets/svg/llm/baichuan.svg ADDED
web/src/pages/user-setting/setting-model/index.tsx CHANGED
@@ -55,6 +55,7 @@ const IconMap = {
55
  Xinference: 'xinference',
56
  DeepSeek: 'deepseek',
57
  VolcEngine: 'volc_engine',
 
58
  };
59
 
60
  const LlmIcon = ({ name }: { name: string }) => {
 
55
  Xinference: 'xinference',
56
  DeepSeek: 'deepseek',
57
  VolcEngine: 'volc_engine',
58
+ BaiChuan: 'baichuan',
59
  };
60
 
61
  const LlmIcon = ({ name }: { name: string }) => {