黄腾 aopstudio commited on
Commit
9e27b49
·
1 Parent(s): 9cded99

add support for XunFei Spark (#2017)

Browse files

### What problem does this PR solve?

#1853 add support for XunFei Spark

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Co-authored-by: Zhedong Cen <[email protected]>

api/apps/llm_app.py CHANGED
@@ -138,6 +138,9 @@ def add_llm():
138
  elif factory == "OpenAI-API-Compatible":
139
  llm_name = req["llm_name"]+"___OpenAI-API"
140
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
 
 
 
141
  else:
142
  llm_name = req["llm_name"]
143
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
@@ -165,7 +168,7 @@ def add_llm():
165
  msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
166
  elif llm["model_type"] == LLMType.CHAT.value:
167
  mdl = ChatModel[factory](
168
- key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate"] else None,
169
  model_name=llm["llm_name"],
170
  base_url=llm["api_base"]
171
  )
 
138
  elif factory == "OpenAI-API-Compatible":
139
  llm_name = req["llm_name"]+"___OpenAI-API"
140
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
141
+ elif factory =="XunFei Spark":
142
+ llm_name = req["llm_name"]
143
+ api_key = req.get("spark_api_password","")
144
  else:
145
  llm_name = req["llm_name"]
146
  api_key = req.get("api_key","xxxxxxxxxxxxxxx")
 
168
  msg += f"\nFail to access embedding model({llm['llm_name']})." + str(e)
169
  elif llm["model_type"] == LLMType.CHAT.value:
170
  mdl = ChatModel[factory](
171
+ key=llm['api_key'] if factory in ["VolcEngine", "Bedrock","OpenAI-API-Compatible","Replicate","XunFei Spark"] else None,
172
  model_name=llm["llm_name"],
173
  base_url=llm["api_base"]
174
  )
conf/llm_factories.json CHANGED
@@ -3194,6 +3194,13 @@
3194
  "model_type": "image2text"
3195
  }
3196
  ]
 
 
 
 
 
 
 
3197
  }
3198
  ]
3199
  }
 
3194
  "model_type": "image2text"
3195
  }
3196
  ]
3197
+ },
3198
+ {
3199
+ "name": "XunFei Spark",
3200
+ "logo": "",
3201
+ "tags": "LLM",
3202
+ "status": "1",
3203
+ "llm": []
3204
  }
3205
  ]
3206
  }
rag/llm/__init__.py CHANGED
@@ -100,7 +100,8 @@ ChatModel = {
100
  "SILICONFLOW": SILICONFLOWChat,
101
  "01.AI": YiChat,
102
  "Replicate": ReplicateChat,
103
- "Tencent Hunyuan": HunyuanChat
 
104
  }
105
 
106
 
 
100
  "SILICONFLOW": SILICONFLOWChat,
101
  "01.AI": YiChat,
102
  "Replicate": ReplicateChat,
103
+ "Tencent Hunyuan": HunyuanChat,
104
+ "XunFei Spark": SparkChat
105
  }
106
 
107
 
rag/llm/chat_model.py CHANGED
@@ -1133,12 +1133,12 @@ class HunyuanChat(Base):
1133
  from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
1134
  TencentCloudSDKException,
1135
  )
1136
-
1137
  _gen_conf = {}
1138
  _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
1139
  if system:
1140
  _history.insert(0, {"Role": "system", "Content": system})
1141
-
1142
  if "temperature" in gen_conf:
1143
  _gen_conf["Temperature"] = gen_conf["temperature"]
1144
  if "top_p" in gen_conf:
@@ -1168,3 +1168,20 @@ class HunyuanChat(Base):
1168
  yield ans + "\n**ERROR**: " + str(e)
1169
 
1170
  yield total_tokens
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1133
  from tencentcloud.common.exception.tencent_cloud_sdk_exception import (
1134
  TencentCloudSDKException,
1135
  )
1136
+
1137
  _gen_conf = {}
1138
  _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
1139
  if system:
1140
  _history.insert(0, {"Role": "system", "Content": system})
1141
+
1142
  if "temperature" in gen_conf:
1143
  _gen_conf["Temperature"] = gen_conf["temperature"]
1144
  if "top_p" in gen_conf:
 
1168
  yield ans + "\n**ERROR**: " + str(e)
1169
 
1170
  yield total_tokens
1171
+
1172
+
1173
+ class SparkChat(Base):
1174
+ def __init__(
1175
+ self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
1176
+ ):
1177
+ if not base_url:
1178
+ base_url = "https://spark-api-open.xf-yun.com/v1"
1179
+ model2version = {
1180
+ "Spark-Max": "generalv3.5",
1181
+ "Spark-Lite": "general",
1182
+ "Spark-Pro": "generalv3",
1183
+ "Spark-Pro-128K": "pro-128k",
1184
+ "Spark-4.0-Ultra": "4.0Ultra",
1185
+ }
1186
+ model_version = model2version[model_name]
1187
+ super().__init__(key, model_version, base_url)
web/src/assets/svg/llm/spark.svg ADDED
web/src/locales/en.ts CHANGED
@@ -525,6 +525,9 @@ The above is the content you need to summarize.`,
525
  HunyuanSIDMessage: 'Please input your Secret ID',
526
  addHunyuanSK: 'Hunyuan Secret Key',
527
  HunyuanSKMessage: 'Please input your Secret Key',
 
 
 
528
  },
529
  message: {
530
  registered: 'Registered!',
 
525
  HunyuanSIDMessage: 'Please input your Secret ID',
526
  addHunyuanSK: 'Hunyuan Secret Key',
527
  HunyuanSKMessage: 'Please input your Secret Key',
528
+ SparkModelNameMessage: 'Please select Spark model',
529
+ addSparkAPIPassword: 'Spark APIPassword',
530
+ SparkAPIPasswordMessage: 'please input your APIPassword',
531
  },
532
  message: {
533
  registered: 'Registered!',
web/src/locales/zh-traditional.ts CHANGED
@@ -488,6 +488,9 @@ export default {
488
  HunyuanSIDMessage: '請輸入 Secret ID',
489
  addHunyuanSK: '混元 Secret Key',
490
  HunyuanSKMessage: '請輸入 Secret Key',
 
 
 
491
  },
492
  message: {
493
  registered: '註冊成功',
 
488
  HunyuanSIDMessage: '請輸入 Secret ID',
489
  addHunyuanSK: '混元 Secret Key',
490
  HunyuanSKMessage: '請輸入 Secret Key',
491
+ SparkModelNameMessage: '請選擇星火模型!',
492
+ addSparkAPIPassword: '星火 APIPassword',
493
+ SparkAPIPasswordMessage: '請輸入 APIPassword',
494
  },
495
  message: {
496
  registered: '註冊成功',
web/src/locales/zh.ts CHANGED
@@ -505,6 +505,9 @@ export default {
505
  HunyuanSIDMessage: '请输入 Secret ID',
506
  addHunyuanSK: '混元 Secret Key',
507
  HunyuanSKMessage: '请输入 Secret Key',
 
 
 
508
  },
509
  message: {
510
  registered: '注册成功',
 
505
  HunyuanSIDMessage: '请输入 Secret ID',
506
  addHunyuanSK: '混元 Secret Key',
507
  HunyuanSKMessage: '请输入 Secret Key',
508
+ SparkModelNameMessage: '请选择星火模型!',
509
+ addSparkAPIPassword: '星火 APIPassword',
510
+ SparkAPIPasswordMessage: '请输入 APIPassword',
511
  },
512
  message: {
513
  registered: '注册成功',
web/src/pages/user-setting/setting-model/constant.ts CHANGED
@@ -33,6 +33,7 @@ export const IconMap = {
33
  '01.AI': 'yi',
34
  Replicate: 'replicate',
35
  'Tencent Hunyuan': 'hunyuan',
 
36
  };
37
 
38
  export const BedrockRegionList = [
 
33
  '01.AI': 'yi',
34
  Replicate: 'replicate',
35
  'Tencent Hunyuan': 'hunyuan',
36
+ 'XunFei Spark': 'spark',
37
  };
38
 
39
  export const BedrockRegionList = [
web/src/pages/user-setting/setting-model/hooks.ts CHANGED
@@ -190,6 +190,33 @@ export const useSubmitHunyuan = () => {
190
  };
191
  };
192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  export const useSubmitBedrock = () => {
194
  const { addLlm, loading } = useAddLlm();
195
  const {
 
190
  };
191
  };
192
 
193
+ export const useSubmitSpark = () => {
194
+ const { addLlm, loading } = useAddLlm();
195
+ const {
196
+ visible: SparkAddingVisible,
197
+ hideModal: hideSparkAddingModal,
198
+ showModal: showSparkAddingModal,
199
+ } = useSetModalState();
200
+
201
+ const onSparkAddingOk = useCallback(
202
+ async (payload: IAddLlmRequestBody) => {
203
+ const ret = await addLlm(payload);
204
+ if (ret === 0) {
205
+ hideSparkAddingModal();
206
+ }
207
+ },
208
+ [hideSparkAddingModal, addLlm],
209
+ );
210
+
211
+ return {
212
+ SparkAddingLoading: loading,
213
+ onSparkAddingOk,
214
+ SparkAddingVisible,
215
+ hideSparkAddingModal,
216
+ showSparkAddingModal,
217
+ };
218
+ };
219
+
220
  export const useSubmitBedrock = () => {
221
  const { addLlm, loading } = useAddLlm();
222
  const {
web/src/pages/user-setting/setting-model/index.tsx CHANGED
@@ -36,12 +36,14 @@ import {
36
  useSubmitBedrock,
37
  useSubmitHunyuan,
38
  useSubmitOllama,
 
39
  useSubmitSystemModelSetting,
40
  useSubmitVolcEngine,
41
  } from './hooks';
42
  import HunyuanModal from './hunyuan-modal';
43
  import styles from './index.less';
44
  import OllamaModal from './ollama-modal';
 
45
  import SystemModelSettingModal from './system-model-setting-modal';
46
  import VolcEngineModal from './volcengine-modal';
47
 
@@ -92,7 +94,8 @@ const ModelCard = ({ item, clickApiKey }: IModelCardProps) => {
92
  <Button onClick={handleApiKeyClick}>
93
  {isLocalLlmFactory(item.name) ||
94
  item.name === 'VolcEngine' ||
95
- item.name === 'Tencent Hunyuan'
 
96
  ? t('addTheModel')
97
  : 'API-Key'}
98
  <SettingOutlined />
@@ -174,6 +177,14 @@ const UserSettingModel = () => {
174
  HunyuanAddingLoading,
175
  } = useSubmitHunyuan();
176
 
 
 
 
 
 
 
 
 
177
  const {
178
  bedrockAddingLoading,
179
  onBedrockAddingOk,
@@ -187,8 +198,14 @@ const UserSettingModel = () => {
187
  Bedrock: showBedrockAddingModal,
188
  VolcEngine: showVolcAddingModal,
189
  'Tencent Hunyuan': showHunyuanAddingModal,
 
190
  }),
191
- [showBedrockAddingModal, showVolcAddingModal, showHunyuanAddingModal],
 
 
 
 
 
192
  );
193
 
194
  const handleAddModel = useCallback(
@@ -306,6 +323,13 @@ const UserSettingModel = () => {
306
  loading={HunyuanAddingLoading}
307
  llmFactory={'Tencent Hunyuan'}
308
  ></HunyuanModal>
 
 
 
 
 
 
 
309
  <BedrockModal
310
  visible={bedrockAddingVisible}
311
  hideModal={hideBedrockAddingModal}
 
36
  useSubmitBedrock,
37
  useSubmitHunyuan,
38
  useSubmitOllama,
39
+ useSubmitSpark,
40
  useSubmitSystemModelSetting,
41
  useSubmitVolcEngine,
42
  } from './hooks';
43
  import HunyuanModal from './hunyuan-modal';
44
  import styles from './index.less';
45
  import OllamaModal from './ollama-modal';
46
+ import SparkModal from './spark-modal';
47
  import SystemModelSettingModal from './system-model-setting-modal';
48
  import VolcEngineModal from './volcengine-modal';
49
 
 
94
  <Button onClick={handleApiKeyClick}>
95
  {isLocalLlmFactory(item.name) ||
96
  item.name === 'VolcEngine' ||
97
+ item.name === 'Tencent Hunyuan' ||
98
+ item.name === 'XunFei Spark'
99
  ? t('addTheModel')
100
  : 'API-Key'}
101
  <SettingOutlined />
 
177
  HunyuanAddingLoading,
178
  } = useSubmitHunyuan();
179
 
180
+ const {
181
+ SparkAddingVisible,
182
+ hideSparkAddingModal,
183
+ showSparkAddingModal,
184
+ onSparkAddingOk,
185
+ SparkAddingLoading,
186
+ } = useSubmitSpark();
187
+
188
  const {
189
  bedrockAddingLoading,
190
  onBedrockAddingOk,
 
198
  Bedrock: showBedrockAddingModal,
199
  VolcEngine: showVolcAddingModal,
200
  'Tencent Hunyuan': showHunyuanAddingModal,
201
+ 'XunFei Spark': showSparkAddingModal,
202
  }),
203
+ [
204
+ showBedrockAddingModal,
205
+ showVolcAddingModal,
206
+ showHunyuanAddingModal,
207
+ showSparkAddingModal,
208
+ ],
209
  );
210
 
211
  const handleAddModel = useCallback(
 
323
  loading={HunyuanAddingLoading}
324
  llmFactory={'Tencent Hunyuan'}
325
  ></HunyuanModal>
326
+ <SparkModal
327
+ visible={SparkAddingVisible}
328
+ hideModal={hideSparkAddingModal}
329
+ onOk={onSparkAddingOk}
330
+ loading={SparkAddingLoading}
331
+ llmFactory={'XunFei Spark'}
332
+ ></SparkModal>
333
  <BedrockModal
334
  visible={bedrockAddingVisible}
335
  hideModal={hideBedrockAddingModal}
web/src/pages/user-setting/setting-model/spark-modal/index.tsx ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { useTranslate } from '@/hooks/common-hooks';
2
+ import { IModalProps } from '@/interfaces/common';
3
+ import { IAddLlmRequestBody } from '@/interfaces/request/llm';
4
+ import { Form, Input, Modal, Select } from 'antd';
5
+ import omit from 'lodash/omit';
6
+
7
+ type FieldType = IAddLlmRequestBody & {
8
+ vision: boolean;
9
+ spark_api_password: string;
10
+ };
11
+
12
+ const { Option } = Select;
13
+
14
+ const SparkModal = ({
15
+ visible,
16
+ hideModal,
17
+ onOk,
18
+ loading,
19
+ llmFactory,
20
+ }: IModalProps<IAddLlmRequestBody> & { llmFactory: string }) => {
21
+ const [form] = Form.useForm<FieldType>();
22
+
23
+ const { t } = useTranslate('setting');
24
+
25
+ const handleOk = async () => {
26
+ const values = await form.validateFields();
27
+ const modelType =
28
+ values.model_type === 'chat' && values.vision
29
+ ? 'image2text'
30
+ : values.model_type;
31
+
32
+ const data = {
33
+ ...omit(values, ['vision']),
34
+ model_type: modelType,
35
+ llm_factory: llmFactory,
36
+ };
37
+ console.info(data);
38
+
39
+ onOk?.(data);
40
+ };
41
+
42
+ return (
43
+ <Modal
44
+ title={t('addLlmTitle', { name: llmFactory })}
45
+ open={visible}
46
+ onOk={handleOk}
47
+ onCancel={hideModal}
48
+ okButtonProps={{ loading }}
49
+ confirmLoading={loading}
50
+ >
51
+ <Form
52
+ name="basic"
53
+ style={{ maxWidth: 600 }}
54
+ autoComplete="off"
55
+ layout={'vertical'}
56
+ form={form}
57
+ >
58
+ <Form.Item<FieldType>
59
+ label={t('modelType')}
60
+ name="model_type"
61
+ initialValue={'chat'}
62
+ rules={[{ required: true, message: t('modelTypeMessage') }]}
63
+ >
64
+ <Select placeholder={t('modelTypeMessage')}>
65
+ <Option value="chat">chat</Option>
66
+ </Select>
67
+ </Form.Item>
68
+ <Form.Item<FieldType>
69
+ label={t('modelName')}
70
+ name="llm_name"
71
+ initialValue={'Spark-Max'}
72
+ rules={[{ required: true, message: t('SparkModelNameMessage') }]}
73
+ >
74
+ <Select placeholder={t('modelTypeMessage')}>
75
+ <Option value="Spark-Max">Spark-Max</Option>
76
+ <Option value="Spark-Lite">Spark-Lite</Option>
77
+ <Option value="Spark-Pro">Spark-Pro</Option>
78
+ <Option value="Spark-Pro-128K">Spark-Pro-128K</Option>
79
+ <Option value="Spark-4.0-Ultra">Spark-4.0-Ultra</Option>
80
+ </Select>
81
+ </Form.Item>
82
+ <Form.Item<FieldType>
83
+ label={t('addSparkAPIPassword')}
84
+ name="spark_api_password"
85
+ rules={[{ required: true, message: t('SparkPasswordMessage') }]}
86
+ >
87
+ <Input placeholder={t('SparkSIDMessage')} />
88
+ </Form.Item>
89
+ </Form>
90
+ </Modal>
91
+ );
92
+ };
93
+
94
+ export default SparkModal;