KevinHuSh committed
Commit 4825b73 · 1 Parent(s): a3ebd45

add support for mistral (#1153)


### What problem does this PR solve?

#433

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/db/init_data.py CHANGED
@@ -157,6 +157,11 @@ factory_infos = [{
     "logo": "",
     "tags": "LLM,TEXT EMBEDDING",
     "status": "1",
+},{
+    "name": "Mistral",
+    "logo": "",
+    "tags": "LLM,TEXT EMBEDDING",
+    "status": "1",
 }
 # {
 #     "name": "文心一言",
@@ -584,6 +589,63 @@ def init_llm_factory():
             "max_tokens": 8192,
             "model_type": LLMType.CHAT.value
         },
+        # ------------------------ Mistral -----------------------
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mixtral-8x22b",
+            "tags": "LLM,CHAT,64k",
+            "max_tokens": 64000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mixtral-8x7b",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "open-mistral-7b",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-large-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-small-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-medium-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "codestral-latest",
+            "tags": "LLM,CHAT,32k",
+            "max_tokens": 32000,
+            "model_type": LLMType.CHAT.value
+        },
+        {
+            "fid": factory_infos[14]["name"],
+            "llm_name": "mistral-embed",
+            "tags": "LLM,CHAT,8k",
+            "max_tokens": 8192,
+            "model_type": LLMType.EMBEDDING
+        },
     ]
     for info in factory_infos:
         try:
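Each new row ties itself to the factory through `factory_infos[14]["name"]`, which only holds while "Mistral" stays the fifteenth entry in `factory_infos`. A small illustrative check (not part of this PR) that resolves the factory by name instead of by position, assuming `factory_infos` as defined above:

```python
# Illustrative only: resolve the Mistral factory entry by name rather than
# relying on the hard-coded factory_infos[14] position, which silently
# depends on list order.
idx = next(i for i, f in enumerate(factory_infos) if f["name"] == "Mistral")
assert factory_infos[idx]["name"] == "Mistral"
mistral_fid = factory_infos[idx]["name"]  # usable as the "fid" of each model row
```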
rag/llm/__init__.py CHANGED
@@ -29,7 +29,8 @@ EmbeddingModel = {
     "Youdao": YoudaoEmbed,
     "BaiChuan": BaiChuanEmbed,
     "Jina": JinaEmbed,
-    "BAAI": DefaultEmbedding
+    "BAAI": DefaultEmbedding,
+    "Mistral": MistralEmbed
 }


@@ -52,7 +53,8 @@ ChatModel = {
     "Moonshot": MoonshotChat,
     "DeepSeek": DeepSeekChat,
     "BaiChuan": BaiChuanChat,
-    "MiniMax": MiniMaxChat
+    "MiniMax": MiniMaxChat,
+    "Mistral": MistralChat
 }
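With the two registries extended, any caller that resolves a provider by factory name picks up Mistral without further changes. A minimal lookup sketch; the API key and model names below are placeholders, not values from this PR:

```python
from rag.llm import ChatModel, EmbeddingModel

# Placeholder key: a real deployment supplies it via the tenant/LLM config.
api_key = "<MISTRAL_API_KEY>"

chat_mdl = ChatModel["Mistral"](api_key, "mistral-small-latest")
embd_mdl = EmbeddingModel["Mistral"](api_key, "mistral-embed")
```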
rag/llm/chat_model.py CHANGED
@@ -472,3 +472,57 @@ class MiniMaxChat(Base):
         if not base_url:
             base_url="https://api.minimax.chat/v1/text/chatcompletion_v2"
         super().__init__(key, model_name, base_url)
+
+
+class MistralChat(Base):
+
+    def __init__(self, key, model_name, base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def chat(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        try:
+            response = self.client.chat(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            ans = response.choices[0].message.content
+            if response.choices[0].finish_reason == "length":
+                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+            return ans, response.usage.total_tokens
+        except openai.APIError as e:
+            return "**ERROR**: " + str(e), 0
+
+    def chat_streamly(self, system, history, gen_conf):
+        if system:
+            history.insert(0, {"role": "system", "content": system})
+        for k in list(gen_conf.keys()):
+            if k not in ["temperature", "top_p", "max_tokens"]:
+                del gen_conf[k]
+        ans = ""
+        total_tokens = 0
+        try:
+            response = self.client.chat_stream(
+                model=self.model_name,
+                messages=history,
+                **gen_conf)
+            for resp in response:
+                if not resp.choices or not resp.choices[0].delta.content:continue
+                ans += resp.choices[0].delta.content
+                total_tokens += 1
+                if resp.choices[0].finish_reason == "length":
+                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
+                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
+                yield ans
+
+        except openai.APIError as e:
+            yield ans + "\n**ERROR**: " + str(e)
+
+        yield total_tokens
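A minimal usage sketch for the new class, assuming a valid Mistral API key; the prompt, history, and generation settings are illustrative only. Note that `chat_streamly` yields progressively longer answer strings and finishes by yielding the accumulated token count:

```python
from rag.llm.chat_model import MistralChat

mdl = MistralChat("<MISTRAL_API_KEY>", "open-mistral-7b")  # placeholder key

history = [{"role": "user", "content": "Say hello in one sentence."}]
gen_conf = {"temperature": 0.2, "max_tokens": 64}

# Blocking call: returns (answer, total_tokens). chat() inserts the system
# prompt into the history list it is given, so pass a fresh copy each time.
answer, used_tokens = mdl.chat("You are a concise assistant.", list(history), gen_conf)

# Streaming call: string chunks first, then the final token count (an int).
for chunk in mdl.chat_streamly("You are a concise assistant.", list(history), gen_conf):
    if isinstance(chunk, str):
        print(chunk)
    else:
        used_tokens = chunk
```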
rag/llm/embedding_model.py CHANGED
@@ -343,4 +343,24 @@ class InfinityEmbed(Base):
     def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
         # Using the internal tokenizer to encode the texts and get the total
         # number of tokens
-        return self.encode([text])
+        return self.encode([text])
+
+
+class MistralEmbed(Base):
+    def __init__(self, key, model_name="mistral-embed",
+                 base_url=None):
+        from mistralai.client import MistralClient
+        self.client = MistralClient(api_key=key)
+        self.model_name = model_name
+
+    def encode(self, texts: list, batch_size=32):
+        texts = [truncate(t, 8196) for t in texts]
+        res = self.client.embeddings(input=texts,
+                                     model=self.model_name)
+        return np.array([d.embedding for d in res.data]
+                        ), res.usage.total_tokens
+
+    def encode_queries(self, text):
+        res = self.client.embeddings(input=[truncate(text, 8196)],
+                                     model=self.model_name)
+        return np.array(res.data[0].embedding), res.usage.total_tokens
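A matching sketch for the embedding class; the key and texts below are placeholders. `encode` returns a matrix of document vectors plus the token count reported by the API, and `encode_queries` returns a single query vector:

```python
import numpy as np
from rag.llm.embedding_model import MistralEmbed

embd = MistralEmbed("<MISTRAL_API_KEY>")  # model_name defaults to "mistral-embed"

vectors, used_tokens = embd.encode(["first chunk", "second chunk"])
query_vec, q_tokens = embd.encode_queries("what does the second chunk say?")

# Cosine similarity between the query and each stored chunk vector.
scores = vectors @ query_vec / (
    np.linalg.norm(vectors, axis=1) * np.linalg.norm(query_vec))
```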