黄腾 aopstudio committed on
Commit
06a1df0
·
1 Parent(s): 73da86b

add support for Baidu yiyan (#2049)

Browse files

### What problem does this PR solve?

add support for Baidu yiyan

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>

api/apps/llm_app.py CHANGED
@@ -140,7 +140,11 @@ def add_llm():
140
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
141
  elif factory =="XunFei Spark":
142
  llm_name = req["llm_name"]
143
- api_key = req.get("spark_api_password","")
 
 
 
 
144
  else:
145
  llm_name = req["llm_name"]
146
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
@@ -157,7 +161,7 @@ def add_llm():
157
  msg = ""
158
  if llm["model_type"] == LLMType.EMBEDDING.value:
159
  mdl = EmbeddingModel[factory](
160
- key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
161
  model_name=llm["llm_name"],
162
  base_url=llm["api_base"])
163
  try:
@@ -168,7 +172,7 @@ def add_llm():
168
  msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
169
  elif llm["model_type"] == LLMType.CHAT.value:
170
  mdl = ChatModel[factory](
171
- key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate","XunFei Spark"] else None,
172
  model_name=llm["llm_name"],
173
  base_url=llm["api_base"]
174
  )
@@ -182,7 +186,9 @@ def add_llm():
182
  e)
183
  elif llm["model_type"] == LLMType.RERANK:
184
  mdl = RerankModel[factory](
185
- key=None, model_name=llm["llm_name"], base_url=llm["api_base"]
 
 
186
  )
187
  try:
188
  arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
@@ -193,7 +199,9 @@ def add_llm():
193
  e)
194
  elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
195
  mdl = CvModel[factory](
196
- key=llm["api_key"] if factory in ["OpenAI-API-Compatible"] else None, model_name=llm["llm_name"], base_url=llm["api_base"]
 
 
197
  )
198
  try:
199
  img_url = (
 
140
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
141
  elif factory =="XunFei Spark":
142
  llm_name = req["llm_name"]
143
+ api_key = req.get("spark_api_password","xxxxxxxxxxxxxxx")
144
+ elif factory == "BaiduYiyan":
145
+ llm_name = req["llm_name"]
146
+ api_key = '{' + f'"yiyan_ak": "{req.get("yiyan_ak", "")}", ' \
147
+ f'"yiyan_sk": "{req.get("yiyan_sk", "")}"' + '}'
148
  else:
149
  llm_name = req["llm_name"]
150
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
 
161
  msg = ""
162
  if llm["model_type"] == LLMType.EMBEDDING.value:
163
  mdl = EmbeddingModel[factory](
164
+ key=llm['api_key'],
165
  model_name=llm["llm_name"],
166
  base_url=llm["api_base"])
167
  try:
 
172
  msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
173
  elif llm["model_type"] == LLMType.CHAT.value:
174
  mdl = ChatModel[factory](
175
+ key=llm['api_key'],
176
  model_name=llm["llm_name"],
177
  base_url=llm["api_base"]
178
  )
 
186
  e)
187
  elif llm["model_type"] == LLMType.RERANK:
188
  mdl = RerankModel[factory](
189
+ key=llm["api_key"],
190
+ model_name=llm["llm_name"],
191
+ base_url=llm["api_base"]
192
  )
193
  try:
194
  arr, tc = mdl.similarity("Hello~ Ragflower!", ["Hi, there!"])
 
199
  e)
200
  elif llm["model_type"] == LLMType.IMAGE2TEXT.value:
201
  mdl = CvModel[factory](
202
+ key=llm["api_key"],
203
+ model_name=llm["llm_name"],
204
+ base_url=llm["api_base"]
205
  )
206
  try:
207
  img_url = (
conf/llm_factories.json CHANGED
@@ -3201,6 +3201,13 @@
3201
  "tags": "LLM",
3202
  "status": "1",
3203
  "llm": []
 
 
 
 
 
 
 
3204
  }
3205
  ]
3206
  }
 
3201
  "tags": "LLM",
3202
  "status": "1",
3203
  "llm": []
3204
+ },
3205
+ {
3206
+ "name": "BaiduYiyan",
3207
+ "logo": "",
3208
+ "tags": "LLM",
3209
+ "status": "1",
3210
+ "llm": []
3211
  }
3212
  ]
3213
  }
docs/quickstart.mdx CHANGED
@@ -119,7 +119,7 @@ This section provides instructions on setting up the RAGFlow server on Linux. If
119
  ```
120
 
121
  :::note
122
- If the above steps does not work, consider using [this workaround](https://github.com/docker/for-mac/issues/7047#issuecomment-1791912053), which employs a container and does not require manual editing of the macOS settings.
123
  :::
124
 
125
  </TabItem>
 
119
  ```
120
 
121
  :::note
122
+ If the above steps do not work, consider using [this workaround](https://github.com/docker/for-mac/issues/7047#issuecomment-1791912053), which employs a container and does not require manual editing of the macOS settings.
123
  :::
124
 
125
  </TabItem>
rag/llm/__init__.py CHANGED
@@ -43,7 +43,8 @@ EmbeddingModel = {
43
  "PerfXCloud": PerfXCloudEmbed,
44
  "Upstage": UpstageEmbed,
45
  "SILICONFLOW": SILICONFLOWEmbed,
46
- "Replicate": ReplicateEmbed
 
47
  }
48
 
49
 
@@ -101,7 +102,8 @@ ChatModel = {
101
  "01.AI": YiChat,
102
  "Replicate": ReplicateChat,
103
  "Tencent Hunyuan": HunyuanChat,
104
- "XunFei Spark": SparkChat
 
105
  }
106
 
107
 
@@ -115,7 +117,8 @@ RerankModel = {
115
  "OpenAI-API-Compatible": OpenAI_APIRerank,
116
  "cohere": CoHereRerank,
117
  "TogetherAI": TogetherAIRerank,
118
- "SILICONFLOW": SILICONFLOWRerank
 
119
  }
120
 
121
 
 
43
  "PerfXCloud": PerfXCloudEmbed,
44
  "Upstage": UpstageEmbed,
45
  "SILICONFLOW": SILICONFLOWEmbed,
46
+ "Replicate": ReplicateEmbed,
47
+ "BaiduYiyan": BaiduYiyanEmbed
48
  }
49
 
50
 
 
102
  "01.AI": YiChat,
103
  "Replicate": ReplicateChat,
104
  "Tencent Hunyuan": HunyuanChat,
105
+ "XunFei Spark": SparkChat,
106
+ "BaiduYiyan": BaiduYiyanChat
107
  }
108
 
109
 
 
117
  "OpenAI-API-Compatible": OpenAI_APIRerank,
118
  "cohere": CoHereRerank,
119
  "TogetherAI": TogetherAIRerank,
120
+ "SILICONFLOW": SILICONFLOWRerank,
121
+ "BaiduYiyan": BaiduYiyanRerank
122
  }
123
 
124
 
rag/llm/chat_model.py CHANGED
@@ -1185,3 +1185,69 @@ class SparkChat(Base):
1185
  }
1186
  model_version = model2version[model_name]
1187
  super().__init__(key, model_version, base_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1185
  }
1186
  model_version = model2version[model_name]
1187
  super().__init__(key, model_version, base_url)
1188
+
1189
+
1190
class BaiduYiyanChat(Base):
    """Chat model backed by Baidu ERNIE (Yiyan) through the qianfan SDK.

    ``key`` is a JSON string of the form
    ``{"yiyan_ak": "...", "yiyan_sk": "..."}`` as assembled by the
    add_llm endpoint.
    """

    def __init__(self, key, model_name, base_url=None):
        import qianfan

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
        # qianfan expects lower-case model identifiers.
        self.model_name = model_name.lower()
        self.system = ""

    def chat(self, system, history, gen_conf):
        """Run one non-streaming completion; returns (answer, total_tokens).

        On failure the partial answer plus an ``**ERROR**`` marker is
        returned with a token count of 0.  Note: ``gen_conf`` is mutated
        in place to translate OpenAI-style penalties/max_tokens into
        qianfan's ``penalty_score``/``max_output_tokens``.
        """
        if system:
            self.system = system
        gen_conf["penalty_score"] = (
            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
        ) + 1
        if "max_tokens" in gen_conf:
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""

        try:
            response = self.client.do(
                model=self.model_name,
                messages=history,
                system=self.system,
                **gen_conf
            ).body
            ans = response['result']
            return ans, response["usage"]["total_tokens"]

        except Exception as e:
            return ans + "\n**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        """Stream a completion, yielding the accumulated answer after each
        chunk and finally the total token count (an int) as the last item.
        """
        if system:
            self.system = system
        gen_conf["penalty_score"] = (
            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
        ) + 1
        if "max_tokens" in gen_conf:
            gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
        ans = ""
        total_tokens = 0

        try:
            response = self.client.do(
                model=self.model_name,
                messages=history,
                system=self.system,
                stream=True,
                **gen_conf
            )
            for resp in response:
                resp = resp.body
                ans += resp['result']
                total_tokens = resp["usage"]["total_tokens"]

                yield ans

        except Exception as e:
            # BUG FIX: this was `return ans + "\n**ERROR**: " + str(e), 0`.
            # A `return <value>` inside a generator only sets the
            # StopIteration value and never reaches a `for` consumer, so
            # stream errors were silently swallowed.  Yield the error text
            # like the other chat models do.
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens
rag/llm/embedding_model.py CHANGED
@@ -32,6 +32,7 @@ import asyncio
32
  from api.utils.file_utils import get_home_cache_dir
33
  from rag.utils import num_tokens_from_string, truncate
34
  import google.generativeai as genai
 
35
 
36
  class Base(ABC):
37
  def __init__(self, key, model_name):
@@ -591,11 +592,34 @@ class ReplicateEmbed(Base):
591
  self.client = Client(api_token=key)
592
 
593
  def encode(self, texts: list, batch_size=32):
594
- from json import dumps
595
-
596
- res = self.client.run(self.model_name, input={"texts": dumps(texts)})
597
  return np.array(res), sum([num_tokens_from_string(text) for text in texts])
598
 
599
  def encode_queries(self, text):
600
  res = self.client.embed(self.model_name, input={"texts": [text]})
601
  return np.array(res), num_tokens_from_string(text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  from api.utils.file_utils import get_home_cache_dir
33
  from rag.utils import num_tokens_from_string, truncate
34
  import google.generativeai as genai
35
+ import json
36
 
37
  class Base(ABC):
38
  def __init__(self, key, model_name):
 
592
  self.client = Client(api_token=key)
593
 
594
  def encode(self, texts: list, batch_size=32):
595
+ res = self.client.run(self.model_name, input={"texts": json.dumps(texts)})
 
 
596
  return np.array(res), sum([num_tokens_from_string(text) for text in texts])
597
 
598
  def encode_queries(self, text):
599
  res = self.client.embed(self.model_name, input={"texts": [text]})
600
  return np.array(res), num_tokens_from_string(text)
601
+
602
+
603
class BaiduYiyanEmbed(Base):
    """Embedding model served by Baidu ERNIE (Yiyan) via the qianfan SDK.

    ``key`` is a JSON string holding the credential pair:
    ``{"yiyan_ak": "...", "yiyan_sk": "..."}``.
    """

    def __init__(self, key, model_name, base_url=None):
        import qianfan

        creds = json.loads(key)
        self.client = qianfan.Embedding(
            ak=creds.get("yiyan_ak", ""),
            sk=creds.get("yiyan_sk", ""),
        )
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed a list of texts; returns (ndarray of vectors, total_tokens).

        NOTE(review): ``batch_size`` is accepted for interface parity but
        unused — all texts are sent in a single request.
        """
        body = self.client.do(model=self.model_name, texts=texts).body
        vectors = [item["embedding"] for item in body["data"]]
        return np.array(vectors), body["usage"]["total_tokens"]

    def encode_queries(self, text):
        """Embed a single query string; returns (ndarray, total_tokens)."""
        body = self.client.do(model=self.model_name, texts=[text]).body
        vectors = [item["embedding"] for item in body["data"]]
        return np.array(vectors), body["usage"]["total_tokens"]
rag/llm/rerank_model.py CHANGED
@@ -24,6 +24,7 @@ from abc import ABC
24
  import numpy as np
25
  from api.utils.file_utils import get_home_cache_dir
26
  from rag.utils import num_tokens_from_string, truncate
 
27
 
28
  def sigmoid(x):
29
  return 1 / (1 + np.exp(-x))
@@ -288,3 +289,25 @@ class SILICONFLOWRerank(Base):
288
  rank[indexs],
289
  response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
290
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  import numpy as np
25
  from api.utils.file_utils import get_home_cache_dir
26
  from rag.utils import num_tokens_from_string, truncate
27
+ import json
28
 
29
  def sigmoid(x):
30
  return 1 / (1 + np.exp(-x))
 
289
  rank[indexs],
290
  response["meta"]["tokens"]["input_tokens"] + response["meta"]["tokens"]["output_tokens"],
291
  )
292
+
293
+
294
class BaiduYiyanRerank(Base):
    """Reranker backed by Baidu ERNIE (Yiyan) via the qianfan SDK.

    ``key`` is a JSON string: ``{"yiyan_ak": "...", "yiyan_sk": "..."}``.
    """

    def __init__(self, key, model_name, base_url=None):
        from qianfan.resources import Reranker

        key = json.loads(key)
        ak = key.get("yiyan_ak", "")
        sk = key.get("yiyan_sk", "")
        self.client = Reranker(ak=ak, sk=sk)
        self.model_name = model_name

    def similarity(self, query: str, texts: list):
        """Score ``texts`` against ``query``.

        Returns (scores, total_tokens) where ``scores[i]`` is the relevance
        of ``texts[i]``.
        """
        res = self.client.do(
            model=self.model_name,
            query=query,
            documents=texts,
            top_n=len(texts),
        ).body
        # BUG FIX: the original did `rank[indexs]`, which gathers the
        # returned scores by original index instead of scattering each score
        # back to its document's position — misaligning scores with `texts`
        # whenever the API returns results sorted by relevance.  Each result
        # carries its original position in "index"; scatter by it so the
        # output order is correct regardless of the API's result order.
        scores = np.zeros(len(texts), dtype=float)
        for d in res["results"]:
            scores[d["index"]] = d["relevance_score"]
        return scores, res["usage"]["total_tokens"]
requirements.txt CHANGED
@@ -62,6 +62,7 @@ pytest==8.2.2
62
  python-dotenv==1.0.1
63
  python_dateutil==2.8.2
64
  python_pptx==0.6.23
 
65
  readability_lxml==0.8.1
66
  redis==5.0.3
67
  Requests==2.32.2
 
62
  python-dotenv==1.0.1
63
  python_dateutil==2.8.2
64
  python_pptx==0.6.23
65
+ qianfan==0.4.6
66
  readability_lxml==0.8.1
67
  redis==5.0.3
68
  Requests==2.32.2
requirements_arm.txt CHANGED
@@ -100,6 +100,7 @@ python-docx==1.1.0
100
  python-dotenv==1.0.1
101
  python-pptx==0.6.23
102
  PyYAML==6.0.1
 
103
  redis==5.0.3
104
  regex==2023.12.25
105
  replicate==0.31.0
 
100
  python-dotenv==1.0.1
101
  python-pptx==0.6.23
102
  PyYAML==6.0.1
103
+ qianfan==0.4.6
104
  redis==5.0.3
105
  regex==2023.12.25
106
  replicate==0.31.0
web/src/assets/svg/llm/yiyan.svg ADDED
web/src/locales/en.ts CHANGED
@@ -528,6 +528,11 @@ The above is the content you need to summarize.`,
528
  SparkModelNameMessage: 'Please select Spark model',
529
  addSparkAPIPassword: 'Spark APIPassword',
530
  SparkAPIPasswordMessage: 'please input your APIPassword',
 
 
 
 
 
531
  },
532
  message: {
533
  registered: 'Registered!',
 
528
  SparkModelNameMessage: 'Please select Spark model',
529
  addSparkAPIPassword: 'Spark APIPassword',
530
  SparkAPIPasswordMessage: 'please input your APIPassword',
531
+ yiyanModelNameMessage: 'Please input model name',
532
+ addyiyanAK: 'yiyan API KEY',
533
+ yiyanAKMessage: 'Please input your API KEY',
534
+ addyiyanSK: 'yiyan Secret KEY',
535
+ yiyanSKMessage: 'Please input your Secret KEY',
536
  },
537
  message: {
538
  registered: 'Registered!',
web/src/locales/zh-traditional.ts CHANGED
@@ -491,6 +491,11 @@ export default {
491
  SparkModelNameMessage: '請選擇星火模型!',
492
  addSparkAPIPassword: '星火 APIPassword',
493
  SparkAPIPasswordMessage: '請輸入 APIPassword',
 
 
 
 
 
494
  },
495
  message: {
496
  registered: '註冊成功',
 
491
  SparkModelNameMessage: '請選擇星火模型!',
492
  addSparkAPIPassword: '星火 APIPassword',
493
  SparkAPIPasswordMessage: '請輸入 APIPassword',
494
+ yiyanModelNameMessage: '輸入模型名稱',
495
+ addyiyanAK: '一言 API KEY',
496
+ yiyanAKMessage: '請輸入 API KEY',
497
+ addyiyanSK: '一言 Secret KEY',
498
+ yiyanSKMessage: '請輸入 Secret KEY',
499
  },
500
  message: {
501
  registered: '註冊成功',
web/src/locales/zh.ts CHANGED
@@ -508,6 +508,11 @@ export default {
508
  SparkModelNameMessage: '请选择星火模型!',
509
  addSparkAPIPassword: '星火 APIPassword',
510
  SparkAPIPasswordMessage: '请输入 APIPassword',
 
 
 
 
 
511
  },
512
  message: {
513
  registered: '注册成功',
 
508
  SparkModelNameMessage: '请选择星火模型!',
509
  addSparkAPIPassword: '星火 APIPassword',
510
  SparkAPIPasswordMessage: '请输入 APIPassword',
511
+ yiyanModelNameMessage: '请输入模型名称',
512
+ addyiyanAK: '一言 API KEY',
513
+ yiyanAKMessage: '请输入 API KEY',
514
+ addyiyanSK: '一言 Secret KEY',
515
+ yiyanSKMessage: '请输入 Secret KEY',
516
  },
517
  message: {
518
  registered: '注册成功',
web/src/pages/user-setting/setting-model/constant.ts CHANGED
@@ -34,6 +34,7 @@ export const IconMap = {
34
  Replicate: 'replicate',
35
  'Tencent Hunyuan': 'hunyuan',
36
  'XunFei Spark': 'spark',
 
37
  };
38
 
39
  export const BedrockRegionList = [
 
34
  Replicate: 'replicate',
35
  'Tencent Hunyuan': 'hunyuan',
36
  'XunFei Spark': 'spark',
37
+ BaiduYiyan: 'yiyan',
38
  };
39
 
40
  export const BedrockRegionList = [
web/src/pages/user-setting/setting-model/hooks.ts CHANGED
@@ -217,6 +217,33 @@ export const useSubmitSpark = () => {
217
  };
218
  };
219
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
220
  export const useSubmitBedrock = () => {
221
  const { addLlm, loading } = useAddLlm();
222
  const {
 
217
  };
218
  };
219
 
220
+ export const useSubmityiyan = () => {
221
+ const { addLlm, loading } = useAddLlm();
222
+ const {
223
+ visible: yiyanAddingVisible,
224
+ hideModal: hideyiyanAddingModal,
225
+ showModal: showyiyanAddingModal,
226
+ } = useSetModalState();
227
+
228
+ const onyiyanAddingOk = useCallback(
229
+ async (payload: IAddLlmRequestBody) => {
230
+ const ret = await addLlm(payload);
231
+ if (ret === 0) {
232
+ hideyiyanAddingModal();
233
+ }
234
+ },
235
+ [hideyiyanAddingModal, addLlm],
236
+ );
237
+
238
+ return {
239
+ yiyanAddingLoading: loading,
240
+ onyiyanAddingOk,
241
+ yiyanAddingVisible,
242
+ hideyiyanAddingModal,
243
+ showyiyanAddingModal,
244
+ };
245
+ };
246
+
247
  export const useSubmitBedrock = () => {
248
  const { addLlm, loading } = useAddLlm();
249
  const {
web/src/pages/user-setting/setting-model/index.tsx CHANGED
@@ -39,6 +39,7 @@ import {
39
  useSubmitSpark,
40
  useSubmitSystemModelSetting,
41
  useSubmitVolcEngine,
 
42
  } from './hooks';
43
  import HunyuanModal from './hunyuan-modal';
44
  import styles from './index.less';
@@ -46,6 +47,7 @@ import OllamaModal from './ollama-modal';
46
  import SparkModal from './spark-modal';
47
  import SystemModelSettingModal from './system-model-setting-modal';
48
  import VolcEngineModal from './volcengine-modal';
 
49
 
50
  const LlmIcon = ({ name }: { name: string }) => {
51
  const icon = IconMap[name as keyof typeof IconMap];
@@ -95,7 +97,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {
95
  {isLocalLlmFactory(item.name) ||
96
  item.name === 'VolcEngine' ||
97
  item.name === 'Tencent Hunyuan' ||
98
- item.name === 'XunFei Spark'
 
99
  ? t('addTheModel')
100
  : 'API-Key'}
101
  <SettingOutlined />
@@ -185,6 +188,14 @@ const UserSettingModel = () => {
185
  SparkAddingLoading,
186
  } = useSubmitSpark();
187
 
 
 
 
 
 
 
 
 
188
  const {
189
  bedrockAddingLoading,
190
  onBedrockAddingOk,
@@ -199,12 +210,14 @@ const UserSettingModel = () => {
199
  VolcEngine: showVolcAddingModal,
200
  'Tencent Hunyuan': showHunyuanAddingModal,
201
  'XunFei Spark': showSparkAddingModal,
 
202
  }),
203
  [
204
  showBedrockAddingModal,
205
  showVolcAddingModal,
206
  showHunyuanAddingModal,
207
  showSparkAddingModal,
 
208
  ],
209
  );
210
 
@@ -330,6 +343,13 @@ const UserSettingModel = () => {
330
  loading={SparkAddingLoading}
331
  llmFactory={'XunFei Spark'}
332
  ></SparkModal>
 
 
 
 
 
 
 
333
  <BedrockModal
334
  visible={bedrockAddingVisible}
335
  hideModal={hideBedrockAddingModal}
 
39
  useSubmitSpark,
40
  useSubmitSystemModelSetting,
41
  useSubmitVolcEngine,
42
+ useSubmityiyan,
43
  } from './hooks';
44
  import HunyuanModal from './hunyuan-modal';
45
  import styles from './index.less';
 
47
  import SparkModal from './spark-modal';
48
  import SystemModelSettingModal from './system-model-setting-modal';
49
  import VolcEngineModal from './volcengine-modal';
50
+ import YiyanModal from './yiyan-modal';
51
 
52
  const LlmIcon = ({ name }: { name: string }) => {
53
  const icon = IconMap[name as keyof typeof IconMap];
 
97
  {isLocalLlmFactory(item.name) ||
98
  item.name === 'VolcEngine' ||
99
  item.name === 'Tencent Hunyuan' ||
100
+ item.name === 'XunFei Spark' ||
101
+ item.name === 'BaiduYiyan'
102
  ? t('addTheModel')
103
  : 'API-Key'}
104
  <SettingOutlined />
 
188
  SparkAddingLoading,
189
  } = useSubmitSpark();
190
 
191
+ const {
192
+ yiyanAddingVisible,
193
+ hideyiyanAddingModal,
194
+ showyiyanAddingModal,
195
+ onyiyanAddingOk,
196
+ yiyanAddingLoading,
197
+ } = useSubmityiyan();
198
+
199
  const {
200
  bedrockAddingLoading,
201
  onBedrockAddingOk,
 
210
  VolcEngine: showVolcAddingModal,
211
  'Tencent Hunyuan': showHunyuanAddingModal,
212
  'XunFei Spark': showSparkAddingModal,
213
+ BaiduYiyan: showyiyanAddingModal,
214
  }),
215
  [
216
  showBedrockAddingModal,
217
  showVolcAddingModal,
218
  showHunyuanAddingModal,
219
  showSparkAddingModal,
220
+ showyiyanAddingModal,
221
  ],
222
  );
223
 
 
343
  loading={SparkAddingLoading}
344
  llmFactory={'XunFei Spark'}
345
  ></SparkModal>
346
+ <YiyanModal
347
+ visible={yiyanAddingVisible}
348
+ hideModal={hideyiyanAddingModal}
349
+ onOk={onyiyanAddingOk}
350
+ loading={yiyanAddingLoading}
351
+ llmFactory={'BaiduYiyan'}
352
+ ></YiyanModal>
353
  <BedrockModal
354
  visible={bedrockAddingVisible}
355
  hideModal={hideBedrockAddingModal}
web/src/pages/user-setting/setting-model/yiyan-modal/index.tsx ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import { useTranslate } from '@/hooks/common-hooks';
import { IModalProps } from '@/interfaces/common';
import { IAddLlmRequestBody } from '@/interfaces/request/llm';
import { Form, Input, Modal, Select } from 'antd';
import omit from 'lodash/omit';

type FieldType = IAddLlmRequestBody & {
  vision: boolean;
  yiyan_ak: string;
  yiyan_sk: string;
};

const { Option } = Select;

// Modal for registering a BaiduYiyan (ERNIE) model: collects the model
// type, model name, and the qianfan API-key / secret-key pair.
const YiyanModal = ({
  visible,
  hideModal,
  onOk,
  loading,
  llmFactory,
}: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
  const [form] = Form.useForm<FieldType>();

  const { t } = useTranslate('setting');

  const handleOk = async () => {
    const values = await form.validateFields();
    // NOTE(review): no "vision" control is rendered below, so this branch is
    // currently inert; kept for parity with the other add-model modals.
    const modelType =
      values.model_type === 'chat' && values.vision
        ? 'image2text'
        : values.model_type;

    const data = {
      ...omit(values, ['vision']),
      model_type: modelType,
      llm_factory: llmFactory,
    };
    // BUG FIX: removed leftover `console.info(data)` debug logging that
    // printed submitted credentials to the browser console.

    onOk?.(data);
  };

  return (
    <Modal
      title={t('addLlmTitle', { name: llmFactory })}
      open={visible}
      onOk={handleOk}
      onCancel={hideModal}
      okButtonProps={{ loading }}
      confirmLoading={loading}
    >
      <Form
        name="basic"
        style={{ maxWidth: 600 }}
        autoComplete="off"
        layout={'vertical'}
        form={form}
      >
        <Form.Item<FieldType>
          label={t('modelType')}
          name="model_type"
          initialValue={'chat'}
          rules={[{ required: true, message: t('modelTypeMessage') }]}
        >
          <Select placeholder={t('modelTypeMessage')}>
            <Option value="chat">chat</Option>
            <Option value="embedding">embedding</Option>
            <Option value="rerank">rerank</Option>
          </Select>
        </Form.Item>
        <Form.Item<FieldType>
          label={t('modelName')}
          name="llm_name"
          rules={[{ required: true, message: t('yiyanModelNameMessage') }]}
        >
          <Input placeholder={t('yiyanModelNameMessage')} />
        </Form.Item>
        <Form.Item<FieldType>
          label={t('addyiyanAK')}
          name="yiyan_ak"
          rules={[{ required: true, message: t('yiyanAKMessage') }]}
        >
          <Input placeholder={t('yiyanAKMessage')} />
        </Form.Item>
        <Form.Item<FieldType>
          label={t('addyiyanSK')}
          name="yiyan_sk"
          rules={[{ required: true, message: t('yiyanSKMessage') }]}
        >
          <Input placeholder={t('yiyanSKMessage')} />
        </Form.Item>
      </Form>
    </Modal>
  );
};

export default YiyanModal;