Kevin Hu commited on
Commit
258e6bf
·
1 Parent(s): aebd986

Make spark model robuster to model name (#3514)

Browse files

### What problem does this PR solve?


### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (2) hide show
  1. conf/llm_factories.json +4 -4
  2. rag/llm/chat_model.py +5 -1
conf/llm_factories.json CHANGED
@@ -3,7 +3,7 @@
3
  {
4
  "name": "OpenAI",
5
  "logo": "",
6
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
7
  "status": "1",
8
  "llm": [
9
  {
@@ -89,7 +89,7 @@
89
  {
90
  "name": "Tongyi-Qianwen",
91
  "logo": "",
92
- "tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,SPEECH2TEXT,MODERATION",
93
  "status": "1",
94
  "llm": [
95
  {
@@ -352,7 +352,7 @@
352
  {
353
  "name": "Xinference",
354
  "logo": "",
355
- "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION,TEXT RE-RANK",
356
  "status": "1",
357
  "llm": []
358
  },
@@ -2303,7 +2303,7 @@
2303
  {
2304
  "name": "XunFei Spark",
2305
  "logo": "",
2306
- "tags": "LLM",
2307
  "status": "1",
2308
  "llm": []
2309
  },
 
3
  {
4
  "name": "OpenAI",
5
  "logo": "",
6
+ "tags": "LLM,TEXT EMBEDDING,TTS,TEXT RE-RANK,SPEECH2TEXT,MODERATION",
7
  "status": "1",
8
  "llm": [
9
  {
 
89
  {
90
  "name": "Tongyi-Qianwen",
91
  "logo": "",
92
+ "tags": "LLM,TEXT EMBEDDING,TEXT RE-RANK,TTS,SPEECH2TEXT,MODERATION",
93
  "status": "1",
94
  "llm": [
95
  {
 
352
  {
353
  "name": "Xinference",
354
  "logo": "",
355
+ "tags": "LLM,TEXT EMBEDDING,TTS,SPEECH2TEXT,MODERATION,TEXT RE-RANK",
356
  "status": "1",
357
  "llm": []
358
  },
 
2303
  {
2304
  "name": "XunFei Spark",
2305
  "logo": "",
2306
+ "tags": "LLM,TTS",
2307
  "status": "1",
2308
  "llm": []
2309
  },
rag/llm/chat_model.py CHANGED
@@ -1164,7 +1164,11 @@ class SparkChat(Base):
1164
  "Spark-Pro-128K": "pro-128k",
1165
  "Spark-4.0-Ultra": "4.0Ultra",
1166
  }
1167
- model_version = model2version[model_name]
 
 
 
 
1168
  super().__init__(key, model_version, base_url)
1169
 
1170
 
 
1164
  "Spark-Pro-128K": "pro-128k",
1165
  "Spark-4.0-Ultra": "4.0Ultra",
1166
  }
1167
+ version2model = {v: k for k, v in model2version.items()}
1168
+ assert model_name in model2version or model_name in version2model, f"The given model name is not supported yet. Support: {list(model2version.keys())}"
1169
+ if model_name in model2version:
1170
+ model_version = model2version[model_name]
1171
+ else: model_version = model_name
1172
  super().__init__(key, model_version, base_url)
1173
 
1174