JobSmithManipulation committed on
Commit fe03167 · 1 Parent(s): 7b1a8ac

solve knowledge graph issue when calling gemini model (#2738)


### What problem does this PR solve?
#2720

### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

Files changed (1)
  1. rag/llm/chat_model.py +64 -62
rag/llm/chat_model.py CHANGED
@@ -23,7 +23,7 @@ from ollama import Client
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
-import os
+import os
 import json
 import requests
 import asyncio
@@ -62,17 +62,17 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
                         total_tokens
                         + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage") or not resp.usage
-                    else resp.usage.get("total_tokens",total_tokens)
+                    else resp.usage.get("total_tokens", total_tokens)
                 )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
@@ -87,13 +87,13 @@ class Base(ABC):
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url="https://api.openai.com/v1"
+        if not base_url: base_url = "https://api.openai.com/v1"
         super().__init__(key, model_name, base_url)
 
 
 class MoonshotChat(Base):
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
-        if not base_url: base_url="https://api.moonshot.cn/v1"
+        if not base_url: base_url = "https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -108,7 +108,7 @@ class XinferenceChat(Base):
 
 class DeepSeekChat(Base):
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
-        if not base_url: base_url="https://api.deepseek.com/v1"
+        if not base_url: base_url = "https://api.deepseek.com/v1"
         super().__init__(key, model_name, base_url)
 
 
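The repeated `if not base_url` guard looks redundant next to the parameter default, but it is not: factory code can pass `base_url=None` or `""` explicitly, which bypasses the default. The same idea as a tiny standalone helper (hypothetical, not part of the diff):

```python
def resolve_base_url(base_url, default="https://api.deepseek.com/v1"):
    # `or` treats None and "" as falsy, exactly like the `if not base_url` guard.
    return base_url or default


assert resolve_base_url(None) == "https://api.deepseek.com/v1"
assert resolve_base_url("") == "https://api.deepseek.com/v1"
assert resolve_base_url("http://localhost:8000/v1") == "http://localhost:8000/v1"
```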
@@ -178,14 +178,14 @@ class BaiChuanChat(Base):
                 stream=True,
                 **self._format_params(gen_conf))
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
                     resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
                         total_tokens
                         + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage")
                     else resp.usage["total_tokens"]
@@ -252,7 +252,8 @@ class QWenChat(Base):
                         [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     yield ans
                 else:
-                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find("Access")<0 else "Out of credit. Please set the API key in **settings > Model providers.**"
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
+                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
@@ -298,7 +299,7 @@ class ZhipuChat(Base):
                 **gen_conf
             )
             for resp in response:
-                if not resp.choices[0].delta.content:continue
+                if not resp.choices[0].delta.content: continue
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
@@ -411,7 +412,7 @@ class LocalLLM(Base):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
@@ -419,7 +420,7 @@ class LocalLLM(Base):
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         answer = ""
         try:
             res = self.client.stream_doc(
@@ -463,10 +464,10 @@ class VolcEngineChat(Base):
 
 class MiniMaxChat(Base):
     def __init__(
-        self,
-        key,
-        model_name,
-        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
+            self,
+            key,
+            model_name,
+            base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
     ):
         if not base_url:
             base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@@ -583,7 +584,7 @@ class MistralChat(Base):
                 messages=history,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content:continue
+                if not resp.choices or not resp.choices[0].delta.content: continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
@@ -620,9 +621,8 @@ class BedrockChat(Base):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
-
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
 
         try:
             # Send the message to the model, using a basic inference configuration.
@@ -630,9 +630,9 @@ class BedrockChat(Base):
                 modelId=self.model_name,
                 messages=history,
                 inferenceConfig=gen_conf,
-                system=[{"text": (system if system else "Answer the user's message.")}] ,
+                system=[{"text": (system if system else "Answer the user's message.")}],
             )
 
             # Extract and print the response text.
            ans = response["output"]["message"]["content"][0]["text"]
             return ans, num_tokens_from_string(ans)
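Both `chat` and `chat_streamly` normalize plain-string message content into the block list that Bedrock's Converse API expects before calling `client.converse`. A small illustration of the transformation (the sample history is hypothetical):

```python
history = [
    {"role": "user", "content": "hi"},
    {"role": "assistant", "content": [{"text": "hello"}]},  # already normalized
]

for item in history:
    # isinstance with a tuple of types is equivalent to the diff's two
    # separate checks: not a list and not a tuple.
    if not isinstance(item["content"], (list, tuple)):
        item["content"] = [{"text": item["content"]}]

assert history[0]["content"] == [{"text": "hi"}]     # wrapped
assert history[1]["content"] == [{"text": "hello"}]  # left untouched
```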
@@ -652,9 +652,9 @@ class BedrockChat(Base):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
-
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
+
         if self.model_name.split('.')[0] == 'ai21':
             try:
                 response = self.client.converse(
@@ -684,7 +684,7 @@ class BedrockChat(Base):
                 if "contentBlockDelta" in resp:
                     ans += resp["contentBlockDelta"]["delta"]["text"]
                     yield ans
-
+
         except (ClientError, Exception) as e:
             yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
 
@@ -693,22 +693,21 @@
 
 class GeminiChat(Base):
 
-    def __init__(self, key, model_name,base_url=None):
-        from google.generativeai import client,GenerativeModel
-
+    def __init__(self, key, model_name, base_url=None):
+        from google.generativeai import client, GenerativeModel
+
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = 'models/' + model_name
         self.model = GenerativeModel(model_name=self.model_name)
         self.model._client = _client
-
-
-    def chat(self,system,history,gen_conf):
+
+    def chat(self, system, history, gen_conf):
         from google.generativeai.types import content_types
-
+
         if system:
             self.model._system_instruction = content_types.to_content(system)
-
+
         if 'max_tokens' in gen_conf:
             gen_conf['max_output_tokens'] = gen_conf['max_tokens']
         for k in list(gen_conf.keys()):
@@ -717,9 +716,11 @@ class GeminiChat(Base):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if 'content' in item :
+            if 'role' in item and item['role'] == 'system':
+                item['role'] = 'user'
+            if 'content' in item:
                 item['parts'] = item.pop('content')
-
+
         try:
             response = self.model.generate_content(
                 history,
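This hunk is the heart of the fix: Gemini's `generate_content` accepts only the roles `user` and `model` inside history and expects a `parts` key rather than `content`, so a stray `system` entry, which the knowledge graph pipeline injects, previously made the Gemini call fail. A standalone sketch of the conversion, with hypothetical sample messages:

```python
def to_gemini_history(history):
    """Rewrite OpenAI-style messages into the shape Gemini accepts."""
    for item in history:
        if item.get("role") == "assistant":
            item["role"] = "model"   # Gemini's name for the assistant role
        if item.get("role") == "system":
            item["role"] = "user"    # Gemini forbids 'system' inside history
        if "content" in item:
            item["parts"] = item.pop("content")  # Gemini expects 'parts'
    return history


messages = [
    {"role": "system", "content": "Extract entities and relations."},
    {"role": "user", "content": "Alice works at Acme."},
]
to_gemini_history(messages)
# -> roles are now user/user, and each message carries 'parts' instead of 'content'
```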
@@ -731,7 +732,7 @@ class GeminiChat(Base):
 
     def chat_streamly(self, system, history, gen_conf):
         from google.generativeai.types import content_types
-
+
         if system:
             self.model._system_instruction = content_types.to_content(system)
         if 'max_tokens' in gen_conf:
@@ -742,13 +743,13 @@ class GeminiChat(Base):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if 'content' in item :
+            if 'content' in item:
                 item['parts'] = item.pop('content')
         ans = ""
         try:
             response = self.model.generate_content(
                 history,
-                generation_config=gen_conf,stream=True)
+                generation_config=gen_conf, stream=True)
             for resp in response:
                 ans += resp.text
                 yield ans
@@ -756,11 +757,11 @@ class GeminiChat(Base):
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
         yield response._chunks[-1].usage_metadata.total_token_count
 
 
 class GroqChat:
-    def __init__(self, key, model_name,base_url=''):
+    def __init__(self, key, model_name, base_url=''):
         self.client = Groq(api_key=key)
         self.model_name = model_name
 
@@ -942,7 +943,7 @@ class CoHereChat(Base):
 class LeptonAIChat(Base):
     def __init__(self, key, model_name, base_url=None):
         if not base_url:
-            base_url = os.path.join("https://"+model_name+".lepton.run","api","v1")
+            base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
         super().__init__(key, model_name, base_url)
 
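A side note on the reformatted line: `os.path.join` is a filesystem helper, but it happens to produce a valid URL here because every segment already uses forward slashes (on POSIX; Windows would join with backslashes). With a hypothetical model name:

```python
import os

model_name = "llama3-8b"  # hypothetical model name for illustration
base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
print(base_url)  # https://llama3-8b.lepton.run/api/v1 on POSIX
```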
@@ -1058,7 +1059,7 @@ class HunyuanChat(Base):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
         if "temperature" in gen_conf:
@@ -1084,7 +1085,7 @@ class HunyuanChat(Base):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items() } for item in history]
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
 
@@ -1121,7 +1122,7 @@ class HunyuanChat(Base):
 
 class SparkChat(Base):
     def __init__(
-        self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
+            self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
     ):
         if not base_url:
             base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1141,9 +1142,9 @@ class BaiduYiyanChat(Base):
         import qianfan
 
         key = json.loads(key)
-        ak = key.get("yiyan_ak","")
-        sk = key.get("yiyan_sk","")
-        self.client = qianfan.ChatCompletion(ak=ak,sk=sk)
+        ak = key.get("yiyan_ak", "")
+        sk = key.get("yiyan_sk", "")
+        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
         self.model_name = model_name.lower()
         self.system = ""
 
@@ -1151,16 +1152,17 @@ class BaiduYiyanChat(Base):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
-        ) + 1
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                0)) / 2
+        ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
 
         try:
             response = self.client.do(
                 model=self.model_name,
                 messages=history,
                 system=self.system,
                 **gen_conf
             ).body
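The wrapped expression maps OpenAI's two penalty knobs onto Qianfan's single `penalty_score`: average them and shift by 1, so absent penalties yield the neutral value 1.0. A quick check with hypothetical values:

```python
gen_conf = {"presence_penalty": 0.4, "frequency_penalty": 0.6}  # hypothetical

penalty_score = (
    (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
) + 1
print(penalty_score)  # 1.5; with both penalties missing it would be 1.0
```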
@@ -1174,8 +1176,9 @@ class BaiduYiyanChat(Base):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty", 0)) / 2
-        ) + 1
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                0)) / 2
+        ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
@@ -1183,8 +1186,8 @@ class BaiduYiyanChat(Base):
 
         try:
             response = self.client.do(
-                model=self.model_name,
-                messages=history,
+                model=self.model_name,
+                messages=history,
                 system=self.system,
                 stream=True,
                 **gen_conf
@@ -1415,4 +1418,3 @@ class GoogleChat(Base):
             yield ans + "\n**ERROR**: " + str(e)
 
         yield response._chunks[-1].usage_metadata.total_token_count
-
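Note the shared convention visible in both `GeminiChat.chat_streamly` and `GoogleChat.chat_streamly`: after streaming string chunks, the generator's final yield is the integer total token count. A minimal consumption sketch, assuming `model` is a configured chat instance and `history`/`gen_conf` are prepared as above:

```python
ans, total_tokens = "", 0
for chunk in model.chat_streamly("You are helpful.", history, gen_conf):
    if isinstance(chunk, int):  # final yield: total token count
        total_tokens = chunk
    else:                       # every other yield: the accumulated answer so far
        ans = chunk
print(ans, total_tokens)
```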
 