黄腾 aopstudio commited on
Commit
2e1c73c
·
1 Parent(s): 07dead3

add support for cohere (#1849)

Browse files

### What problem does this PR solve?

_Briefly describe what this PR aims to solve. Include background context
that will help reviewers understand the purpose of the PR._

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

---------

Co-authored-by: Zhedong Cen <[email protected]>

conf/llm_factories.json CHANGED
@@ -2216,6 +2216,116 @@
2216
  "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
2217
  "status": "1",
2218
  "llm": []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2219
  }
2220
  ]
2221
  }
 
2216
  "tags": "LLM,TEXT EMBEDDING,IMAGE2TEXT",
2217
  "status": "1",
2218
  "llm": []
2219
+ },
2220
+ {
2221
+ "name": "cohere",
2222
+ "logo": "",
2223
+ "tags": "LLM,TEXT EMBEDDING, TEXT RE-RANK",
2224
+ "status": "1",
2225
+ "llm": [
2226
+ {
2227
+ "llm_name": "command-r-plus",
2228
+ "tags": "LLM,CHAT,128k",
2229
+ "max_tokens": 131072,
2230
+ "model_type": "chat"
2231
+ },
2232
+ {
2233
+ "llm_name": "command-r",
2234
+ "tags": "LLM,CHAT,128k",
2235
+ "max_tokens": 131072,
2236
+ "model_type": "chat"
2237
+ },
2238
+ {
2239
+ "llm_name": "command",
2240
+ "tags": "LLM,CHAT,4k",
2241
+ "max_tokens": 4096,
2242
+ "model_type": "chat"
2243
+ },
2244
+ {
2245
+ "llm_name": "command-nightly",
2246
+ "tags": "LLM,CHAT,128k",
2247
+ "max_tokens": 131072,
2248
+ "model_type": "chat"
2249
+ },
2250
+ {
2251
+ "llm_name": "command-light",
2252
+ "tags": "LLM,CHAT,4k",
2253
+ "max_tokens": 4096,
2254
+ "model_type": "chat"
2255
+ },
2256
+ {
2257
+ "llm_name": "command-light-nightly",
2258
+ "tags": "LLM,CHAT,4k",
2259
+ "max_tokens": 4096,
2260
+ "model_type": "chat"
2261
+ },
2262
+ {
2263
+ "llm_name": "embed-english-v3.0",
2264
+ "tags": "TEXT EMBEDDING",
2265
+ "max_tokens": 512,
2266
+ "model_type": "embedding"
2267
+ },
2268
+ {
2269
+ "llm_name": "embed-english-light-v3.0",
2270
+ "tags": "TEXT EMBEDDING",
2271
+ "max_tokens": 512,
2272
+ "model_type": "embedding"
2273
+ },
2274
+ {
2275
+ "llm_name": "embed-multilingual-v3.0",
2276
+ "tags": "TEXT EMBEDDING",
2277
+ "max_tokens": 512,
2278
+ "model_type": "embedding"
2279
+ },
2280
+ {
2281
+ "llm_name": "embed-multilingual-light-v3.0",
2282
+ "tags": "TEXT EMBEDDING",
2283
+ "max_tokens": 512,
2284
+ "model_type": "embedding"
2285
+ },
2286
+ {
2287
+ "llm_name": "embed-english-v2.0",
2288
+ "tags": "TEXT EMBEDDING",
2289
+ "max_tokens": 512,
2290
+ "model_type": "embedding"
2291
+ },
2292
+ {
2293
+ "llm_name": "embed-english-light-v2.0",
2294
+ "tags": "TEXT EMBEDDING",
2295
+ "max_tokens": 512,
2296
+ "model_type": "embedding"
2297
+ },
2298
+ {
2299
+ "llm_name": "embed-multilingual-v2.0",
2300
+ "tags": "TEXT EMBEDDING",
2301
+ "max_tokens": 256,
2302
+ "model_type": "embedding"
2303
+ },
2304
+ {
2305
+ "llm_name": "rerank-english-v3.0",
2306
+ "tags": "RE-RANK,4k",
2307
+ "max_tokens": 4096,
2308
+ "model_type": "rerank"
2309
+ },
2310
+ {
2311
+ "llm_name": "rerank-multilingual-v3.0",
2312
+ "tags": "RE-RANK,4k",
2313
+ "max_tokens": 4096,
2314
+ "model_type": "rerank"
2315
+ },
2316
+ {
2317
+ "llm_name": "rerank-english-v2.0",
2318
+ "tags": "RE-RANK,512",
2319
+ "max_tokens": 8196,
2320
+ "model_type": "rerank"
2321
+ },
2322
+ {
2323
+ "llm_name": "rerank-multilingual-v2.0",
2324
+ "tags": "RE-RANK,512",
2325
+ "max_tokens": 512,
2326
+ "model_type": "rerank"
2327
+ }
2328
+ ]
2329
  }
2330
  ]
2331
  }
rag/llm/__init__.py CHANGED
@@ -37,7 +37,8 @@ EmbeddingModel = {
37
  "Gemini": GeminiEmbed,
38
  "NVIDIA": NvidiaEmbed,
39
  "LM-Studio": LmStudioEmbed,
40
- "OpenAI-API-Compatible": OpenAI_APIEmbed
 
41
  }
42
 
43
 
@@ -81,7 +82,8 @@ ChatModel = {
81
  "StepFun": StepFunChat,
82
  "NVIDIA": NvidiaChat,
83
  "LM-Studio": LmStudioChat,
84
- "OpenAI-API-Compatible": OpenAI_APIChat
 
85
  }
86
 
87
 
@@ -92,7 +94,8 @@ RerankModel = {
92
  "Xinference": XInferenceRerank,
93
  "NVIDIA": NvidiaRerank,
94
  "LM-Studio": LmStudioRerank,
95
- "OpenAI-API-Compatible": OpenAI_APIRerank
 
96
  }
97
 
98
 
 
37
  "Gemini": GeminiEmbed,
38
  "NVIDIA": NvidiaEmbed,
39
  "LM-Studio": LmStudioEmbed,
40
+ "OpenAI-API-Compatible": OpenAI_APIEmbed,
41
+ "cohere": CoHereEmbed
42
  }
43
 
44
 
 
82
  "StepFun": StepFunChat,
83
  "NVIDIA": NvidiaChat,
84
  "LM-Studio": LmStudioChat,
85
+ "OpenAI-API-Compatible": OpenAI_APIChat,
86
+ "cohere": CoHereChat
87
  }
88
 
89
 
 
94
  "Xinference": XInferenceRerank,
95
  "NVIDIA": NvidiaRerank,
96
  "LM-Studio": LmStudioRerank,
97
+ "OpenAI-API-Compatible": OpenAI_APIRerank,
98
+ "cohere": CoHereRerank
99
  }
100
 
101
 
rag/llm/chat_model.py CHANGED
@@ -900,3 +900,84 @@ class OpenAI_APIChat(Base):
900
  base_url = os.path.join(base_url, "v1")
901
  model_name = model_name.split("___")[0]
902
  super().__init__(key, model_name, base_url)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
900
  base_url = os.path.join(base_url, "v1")
901
  model_name = model_name.split("___")[0]
902
  super().__init__(key, model_name, base_url)
903
+
904
+
905
+ class CoHereChat(Base):
906
+ def __init__(self, key, model_name, base_url=""):
907
+ from cohere import Client
908
+
909
+ self.client = Client(api_key=key)
910
+ self.model_name = model_name
911
+
912
+ def chat(self, system, history, gen_conf):
913
+ if system:
914
+ history.insert(0, {"role": "system", "content": system})
915
+ if "top_p" in gen_conf:
916
+ gen_conf["p"] = gen_conf.pop("top_p")
917
+ if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
918
+ gen_conf.pop("presence_penalty")
919
+ for item in history:
920
+ if "role" in item and item["role"] == "user":
921
+ item["role"] = "USER"
922
+ if "role" in item and item["role"] == "assistant":
923
+ item["role"] = "CHATBOT"
924
+ if "content" in item:
925
+ item["message"] = item.pop("content")
926
+ mes = history.pop()["message"]
927
+ ans = ""
928
+ try:
929
+ response = self.client.chat(
930
+ model=self.model_name, chat_history=history, message=mes, **gen_conf
931
+ )
932
+ ans = response.text
933
+ if response.finish_reason == "MAX_TOKENS":
934
+ ans += (
935
+ "...\nFor the content length reason, it stopped, continue?"
936
+ if is_english([ans])
937
+ else "······\n由于长度的原因,回答被截断了,要继续吗?"
938
+ )
939
+ return (
940
+ ans,
941
+ response.meta.tokens.input_tokens + response.meta.tokens.output_tokens,
942
+ )
943
+ except Exception as e:
944
+ return ans + "\n**ERROR**: " + str(e), 0
945
+
946
+ def chat_streamly(self, system, history, gen_conf):
947
+ if system:
948
+ history.insert(0, {"role": "system", "content": system})
949
+ if "top_p" in gen_conf:
950
+ gen_conf["p"] = gen_conf.pop("top_p")
951
+ if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
952
+ gen_conf.pop("presence_penalty")
953
+ for item in history:
954
+ if "role" in item and item["role"] == "user":
955
+ item["role"] = "USER"
956
+ if "role" in item and item["role"] == "assistant":
957
+ item["role"] = "CHATBOT"
958
+ if "content" in item:
959
+ item["message"] = item.pop("content")
960
+ mes = history.pop()["message"]
961
+ ans = ""
962
+ total_tokens = 0
963
+ try:
964
+ response = self.client.chat_stream(
965
+ model=self.model_name, chat_history=history, message=mes, **gen_conf
966
+ )
967
+ for resp in response:
968
+ if resp.event_type == "text-generation":
969
+ ans += resp.text
970
+ total_tokens += num_tokens_from_string(resp.text)
971
+ elif resp.event_type == "stream-end":
972
+ if resp.finish_reason == "MAX_TOKENS":
973
+ ans += (
974
+ "...\nFor the content length reason, it stopped, continue?"
975
+ if is_english([ans])
976
+ else "······\n由于长度的原因,回答被截断了,要继续吗?"
977
+ )
978
+ yield ans
979
+
980
+ except Exception as e:
981
+ yield ans + "\n**ERROR**: " + str(e)
982
+
983
+ yield total_tokens
rag/llm/embedding_model.py CHANGED
@@ -522,4 +522,34 @@ class OpenAI_APIEmbed(OpenAIEmbed):
522
  if base_url.split("/")[-1] != "v1":
523
  base_url = os.path.join(base_url, "v1")
524
  self.client = OpenAI(api_key=key, base_url=base_url)
525
- self.model_name = model_name.split("___")[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
522
  if base_url.split("/")[-1] != "v1":
523
  base_url = os.path.join(base_url, "v1")
524
  self.client = OpenAI(api_key=key, base_url=base_url)
525
+ self.model_name = model_name.split("___")[0]
526
+
527
+
528
+ class CoHereEmbed(Base):
529
+ def __init__(self, key, model_name, base_url=None):
530
+ from cohere import Client
531
+
532
+ self.client = Client(api_key=key)
533
+ self.model_name = model_name
534
+
535
+ def encode(self, texts: list, batch_size=32):
536
+ res = self.client.embed(
537
+ texts=texts,
538
+ model=self.model_name,
539
+ input_type="search_query",
540
+ embedding_types=["float"],
541
+ )
542
+ return np.array([d for d in res.embeddings.float]), int(
543
+ res.meta.billed_units.input_tokens
544
+ )
545
+
546
+ def encode_queries(self, text):
547
+ res = self.client.embed(
548
+ texts=[text],
549
+ model=self.model_name,
550
+ input_type="search_query",
551
+ embedding_types=["float"],
552
+ )
553
+ return np.array([d for d in res.embeddings.float]), int(
554
+ res.meta.billed_units.input_tokens
555
+ )
rag/llm/rerank_model.py CHANGED
@@ -203,7 +203,9 @@ class NvidiaRerank(Base):
203
  "top_n": len(texts),
204
  }
205
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
206
- return (np.array([d["logit"] for d in res["rankings"]]), token_count)
 
 
207
 
208
 
209
  class LmStudioRerank(Base):
@@ -220,3 +222,26 @@ class OpenAI_APIRerank(Base):
220
 
221
  def similarity(self, query: str, texts: list):
222
  raise NotImplementedError("The api has not been implement")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  "top_n": len(texts),
204
  }
205
  res = requests.post(self.base_url, headers=self.headers, json=data).json()
206
+ rank = np.array([d["logit"] for d in res["rankings"]])
207
+ indexs = [d["index"] for d in res["rankings"]]
208
+ return rank[indexs], token_count
209
 
210
 
211
  class LmStudioRerank(Base):
 
222
 
223
  def similarity(self, query: str, texts: list):
224
  raise NotImplementedError("The api has not been implement")
225
+
226
+
227
+ class CoHereRerank(Base):
228
+ def __init__(self, key, model_name, base_url=None):
229
+ from cohere import Client
230
+
231
+ self.client = Client(api_key=key)
232
+ self.model_name = model_name
233
+
234
+ def similarity(self, query: str, texts: list):
235
+ token_count = num_tokens_from_string(query) + sum(
236
+ [num_tokens_from_string(t) for t in texts]
237
+ )
238
+ res = self.client.rerank(
239
+ model=self.model_name,
240
+ query=query,
241
+ documents=texts,
242
+ top_n=len(texts),
243
+ return_documents=False,
244
+ )
245
+ rank = np.array([d.relevance_score for d in res.results])
246
+ indexs = [d.index for d in res.results]
247
+ return rank[indexs], token_count
requirements.txt CHANGED
@@ -7,6 +7,7 @@ botocore==1.34.140
7
  cachetools==5.3.3
8
  chardet==5.2.0
9
  cn2an==0.5.22
 
10
  dashscope==1.14.1
11
  datrie==0.8.2
12
  demjson3==3.0.6
 
7
  cachetools==5.3.3
8
  chardet==5.2.0
9
  cn2an==0.5.22
10
+ cohere==5.6.2
11
  dashscope==1.14.1
12
  datrie==0.8.2
13
  demjson3==3.0.6
requirements_arm.txt CHANGED
@@ -14,6 +14,7 @@ certifi==2024.7.4
14
  cffi==1.16.0
15
  charset-normalizer==3.3.2
16
  click==8.1.7
 
17
  coloredlogs==15.0.1
18
  cryptography==42.0.5
19
  dashscope==1.14.1
 
14
  cffi==1.16.0
15
  charset-normalizer==3.3.2
16
  click==8.1.7
17
+ cohere==5.6.2
18
  coloredlogs==15.0.1
19
  cryptography==42.0.5
20
  dashscope==1.14.1
requirements_dev.txt CHANGED
@@ -14,6 +14,7 @@ certifi==2024.7.4
14
  cffi==1.16.0
15
  charset-normalizer==3.3.2
16
  click==8.1.7
 
17
  coloredlogs==15.0.1
18
  cryptography==42.0.5
19
  dashscope==1.14.1
 
14
  cffi==1.16.0
15
  charset-normalizer==3.3.2
16
  click==8.1.7
17
+ cohere==5.6.2
18
  coloredlogs==15.0.1
19
  cryptography==42.0.5
20
  dashscope==1.14.1
web/src/assets/svg/llm/cohere.svg ADDED
web/src/pages/user-setting/setting-model/constant.ts CHANGED
@@ -22,7 +22,8 @@ export const IconMap = {
22
  StepFun: 'stepfun',
23
  NVIDIA:'nvidia',
24
  'LM-Studio':'lm-studio',
25
- 'OpenAI-API-Compatible':'openai-api'
 
26
  };
27
 
28
  export const BedrockRegionList = [
 
22
  StepFun: 'stepfun',
23
  NVIDIA:'nvidia',
24
  'LM-Studio':'lm-studio',
25
+ 'OpenAI-API-Compatible':'openai-api',
26
+ 'cohere':'cohere'
27
  };
28
 
29
  export const BedrockRegionList = [