KevinHuSh committed on
Commit
4d5f1c9
·
1 Parent(s): d11e82f

add support for deepseek (#668)

Browse files

### What problem does this PR solve?

#666

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

Files changed (3)
  1. api/db/init_data.py +21 -1
  2. rag/llm/__init__.py +2 -1
  3. rag/llm/chat_model.py +20 -51
api/db/init_data.py CHANGED
@@ -123,7 +123,12 @@ factory_infos = [{
123
  "name": "Youdao",
124
  "logo": "",
125
  "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
126
- "status": "1",
 
 
 
 
 
127
  },
128
  # {
129
  # "name": "文心一言",
@@ -331,6 +336,21 @@ def init_llm_factory():
331
  "max_tokens": 512,
332
  "model_type": LLMType.EMBEDDING.value
333
  },
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
  ]
335
  for info in factory_infos:
336
  try:
 
123
  "name": "Youdao",
124
  "logo": "",
125
  "tags": "LLM,TEXT EMBEDDING,SPEECH2TEXT,MODERATION",
126
+ "status": "1",
127
+ },{
128
+ "name": "DeepSeek",
129
+ "logo": "",
130
+ "tags": "LLM",
131
+ "status": "1",
132
  },
133
  # {
134
  # "name": "文心一言",
 
336
  "max_tokens": 512,
337
  "model_type": LLMType.EMBEDDING.value
338
  },
339
+ # ------------------------ DeepSeek -----------------------
340
+ {
341
+ "fid": factory_infos[8]["name"],
342
+ "llm_name": "deepseek-chat",
343
+ "tags": "LLM,CHAT,",
344
+ "max_tokens": 32768,
345
+ "model_type": LLMType.CHAT.value
346
+ },
347
+ {
348
+ "fid": factory_infos[8]["name"],
349
+ "llm_name": "deepseek-coder",
350
+ "tags": "LLM,CHAT,",
351
+ "max_tokens": 16385,
352
+ "model_type": LLMType.CHAT.value
353
+ },
354
  ]
355
  for info in factory_infos:
356
  try:
rag/llm/__init__.py CHANGED
@@ -45,6 +45,7 @@ ChatModel = {
45
  "Tongyi-Qianwen": QWenChat,
46
  "Ollama": OllamaChat,
47
  "Xinference": XinferenceChat,
48
- "Moonshot": MoonshotChat
 
49
  }
50
 
 
45
  "Tongyi-Qianwen": QWenChat,
46
  "Ollama": OllamaChat,
47
  "Xinference": XinferenceChat,
48
+ "Moonshot": MoonshotChat,
49
+ "DeepSeek": DeepSeekChat
50
  }
51
 
rag/llm/chat_model.py CHANGED
@@ -24,16 +24,7 @@ from rag.utils import num_tokens_from_string
24
 
25
 
26
  class Base(ABC):
27
- def __init__(self, key, model_name):
28
- pass
29
-
30
- def chat(self, system, history, gen_conf):
31
- raise NotImplementedError("Please implement encode method!")
32
-
33
-
34
- class GptTurbo(Base):
35
- def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
36
- if not base_url: base_url="https://api.openai.com/v1"
37
  self.client = OpenAI(api_key=key, base_url=base_url)
38
  self.model_name = model_name
39
 
@@ -54,28 +45,28 @@ class GptTurbo(Base):
54
  return "**ERROR**: " + str(e), 0
55
 
56
 
57
- class MoonshotChat(GptTurbo):
 
 
 
 
 
 
58
  def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
59
  if not base_url: base_url="https://api.moonshot.cn/v1"
60
- self.client = OpenAI(
61
- api_key=key, base_url=base_url)
62
- self.model_name = model_name
63
 
64
- def chat(self, system, history, gen_conf):
65
- if system:
66
- history.insert(0, {"role": "system", "content": system})
67
- try:
68
- response = self.client.chat.completions.create(
69
- model=self.model_name,
70
- messages=history,
71
- **gen_conf)
72
- ans = response.choices[0].message.content.strip()
73
- if response.choices[0].finish_reason == "length":
74
- ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
75
- [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
76
- return ans, response.usage.total_tokens
77
- except openai.APIError as e:
78
- return "**ERROR**: " + str(e), 0
79
 
80
 
81
  class QWenChat(Base):
@@ -157,25 +148,3 @@ class OllamaChat(Base):
157
  except Exception as e:
158
  return "**ERROR**: " + str(e), 0
159
 
160
-
161
- class XinferenceChat(Base):
162
- def __init__(self, key=None, model_name="", base_url=""):
163
- self.client = OpenAI(api_key="xxx", base_url=base_url)
164
- self.model_name = model_name
165
-
166
- def chat(self, system, history, gen_conf):
167
- if system:
168
- history.insert(0, {"role": "system", "content": system})
169
- try:
170
- response = self.client.chat.completions.create(
171
- model=self.model_name,
172
- messages=history,
173
- **gen_conf)
174
- ans = response.choices[0].message.content.strip()
175
- if response.choices[0].finish_reason == "length":
176
- ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
177
- [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
178
- return ans, response.usage.total_tokens
179
- except openai.APIError as e:
180
- return "**ERROR**: " + str(e), 0
181
-
 
24
 
25
 
26
  class Base(ABC):
27
+ def __init__(self, key, model_name, base_url):
 
 
 
 
 
 
 
 
 
28
  self.client = OpenAI(api_key=key, base_url=base_url)
29
  self.model_name = model_name
30
 
 
45
  return "**ERROR**: " + str(e), 0
46
 
47
 
48
+ class GptTurbo(Base):
49
+ def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
50
+ if not base_url: base_url="https://api.openai.com/v1"
51
+ super().__init__(key, model_name, base_url)
52
+
53
+
54
+ class MoonshotChat(Base):
55
  def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
56
  if not base_url: base_url="https://api.moonshot.cn/v1"
57
+ super().__init__(key, model_name, base_url)
 
 
58
 
59
+
60
+ class XinferenceChat(Base):
61
+ def __init__(self, key=None, model_name="", base_url=""):
62
+ key = "xxx"
63
+ super().__init__(key, model_name, base_url)
64
+
65
+
66
+ class DeepSeekChat(Base):
67
+ def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
68
+ if not base_url: base_url="https://api.deepseek.com/v1"
69
+ super().__init__(key, model_name, base_url)
 
 
 
 
70
 
71
 
72
  class QWenChat(Base):
 
148
  except Exception as e:
149
  return "**ERROR**: " + str(e), 0
150