JobSmithManipulation committed
Commit fe03167 · 1 Parent(s): 7b1a8ac
solve knowledgegraph issue when calling gemini model (#2738)
### What problem does this PR solve?
#2720
### Type of change
- [x] Bug Fix (non-breaking change which fixes an issue)
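
The heart of the fix is in `GeminiChat.chat` and `GeminiChat.chat_streamly` (see the diff below): the history passed to `generate_content` is rewritten so that `assistant` turns become `model`, a `system` turn (which the knowledge-graph prompts can inject) is downgraded to `user` in `chat`, and `content` is moved into `parts`. A minimal sketch of that conversion, pulled out into a standalone helper purely for illustration (`to_gemini_history` is a hypothetical name, not from the codebase):

```python
# Illustrative only: in the PR this logic lives inline in GeminiChat.chat /
# GeminiChat.chat_streamly; the helper name to_gemini_history is hypothetical.
def to_gemini_history(history):
    """Rewrite OpenAI-style messages into the role/parts shape Gemini accepts."""
    converted = []
    for item in history:
        item = dict(item)  # copy so the caller's messages are not mutated
        if item.get('role') == 'assistant':
            item['role'] = 'model'   # Gemini's name for the assistant role
        if item.get('role') == 'system':
            item['role'] = 'user'    # Gemini history only takes user/model turns
        if 'content' in item:
            item['parts'] = item.pop('content')  # Gemini expects "parts", not "content"
        converted.append(item)
    return converted


if __name__ == "__main__":
    print(to_gemini_history([
        {"role": "system", "content": "Extract a knowledge graph from the text."},
        {"role": "user", "content": "RAGFlow is developed by InfiniFlow."},
    ]))
    # [{'role': 'user', 'parts': 'Extract a knowledge graph from the text.'},
    #  {'role': 'user', 'parts': 'RAGFlow is developed by InfiniFlow.'}]
```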
rag/llm/chat_model.py CHANGED (+64 -62)
@@ -23,7 +23,7 @@ from ollama import Client
 from rag.nlp import is_english
 from rag.utils import num_tokens_from_string
 from groq import Groq
-import os
+import os
 import json
 import requests
 import asyncio
@@ -62,17 +62,17 @@ class Base(ABC):
                 stream=True,
                 **gen_conf)
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
-                    resp.choices[0].delta.content = ""
+                    resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
-
-
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage") or not resp.usage
-                    else resp.usage.get("total_tokens",total_tokens)
+                    else resp.usage.get("total_tokens", total_tokens)
                 )
                 if resp.choices[0].finish_reason == "length":
                     ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
@@ -87,13 +87,13 @@ class Base(ABC):
 
 class GptTurbo(Base):
     def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
-        if not base_url: base_url="https://api.openai.com/v1"
+        if not base_url: base_url = "https://api.openai.com/v1"
         super().__init__(key, model_name, base_url)
 
 
 class MoonshotChat(Base):
     def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
-        if not base_url: base_url="https://api.moonshot.cn/v1"
+        if not base_url: base_url = "https://api.moonshot.cn/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -108,7 +108,7 @@ class XinferenceChat(Base):
 
 class DeepSeekChat(Base):
     def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
-        if not base_url: base_url="https://api.deepseek.com/v1"
+        if not base_url: base_url = "https://api.deepseek.com/v1"
         super().__init__(key, model_name, base_url)
 
 
@@ -178,14 +178,14 @@ class BaiChuanChat(Base):
                 stream=True,
                 **self._format_params(gen_conf))
             for resp in response:
-                if not resp.choices:continue
+                if not resp.choices: continue
                 if not resp.choices[0].delta.content:
-                    resp.choices[0].delta.content = ""
+                    resp.choices[0].delta.content = ""
                 ans += resp.choices[0].delta.content
                 total_tokens = (
                     (
-
-
+                        total_tokens
+                        + num_tokens_from_string(resp.choices[0].delta.content)
                     )
                     if not hasattr(resp, "usage")
                     else resp.usage["total_tokens"]
@@ -252,7 +252,8 @@ class QWenChat(Base):
                             [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                     yield ans
                 else:
-                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
+                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
+                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
@@ -298,7 +299,7 @@ class ZhipuChat(Base):
                 **gen_conf
             )
             for resp in response:
-                if not resp.choices[0].delta.content:continue
+                if not resp.choices[0].delta.content: continue
                 delta = resp.choices[0].delta.content
                 ans += delta
                 if resp.choices[0].finish_reason == "length":
@@ -411,7 +412,7 @@ class LocalLLM(Base):
         self.client = Client(port=12345, protocol="grpc", asyncio=True)
 
     def _prepare_prompt(self, system, history, gen_conf):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         if system:
             history.insert(0, {"role": "system", "content": system})
         if "max_tokens" in gen_conf:
@@ -419,7 +420,7 @@ class LocalLLM(Base):
         return Prompt(message=history, gen_conf=gen_conf)
 
     def _stream_response(self, endpoint, prompt):
-        from rag.svr.jina_server import Prompt,Generation
+        from rag.svr.jina_server import Prompt, Generation
         answer = ""
         try:
             res = self.client.stream_doc(
@@ -463,10 +464,10 @@ class VolcEngineChat(Base):
 
 class MiniMaxChat(Base):
     def __init__(
-
-
-
-
+        self,
+        key,
+        model_name,
+        base_url="https://api.minimax.chat/v1/text/chatcompletion_v2",
     ):
         if not base_url:
             base_url = "https://api.minimax.chat/v1/text/chatcompletion_v2"
@@ -583,7 +584,7 @@ class MistralChat(Base):
                 messages=history,
                 **gen_conf)
             for resp in response:
-                if not resp.choices or not resp.choices[0].delta.content:continue
+                if not resp.choices or not resp.choices[0].delta.content: continue
                 ans += resp.choices[0].delta.content
                 total_tokens += 1
                 if resp.choices[0].finish_reason == "length":
@@ -620,9 +621,8 @@ class BedrockChat(Base):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
-
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
 
         try:
             # Send the message to the model, using a basic inference configuration.
@@ -630,9 +630,9 @@ class BedrockChat(Base):
                 modelId=self.model_name,
                 messages=history,
                 inferenceConfig=gen_conf,
-                system=[{"text": (system if system else "Answer the user's message.")}]
+                system=[{"text": (system if system else "Answer the user's message.")}],
            )
-
+
             # Extract and print the response text.
             ans = response["output"]["message"]["content"][0]["text"]
             return ans, num_tokens_from_string(ans)
@@ -652,9 +652,9 @@ class BedrockChat(Base):
             gen_conf["topP"] = gen_conf["top_p"]
             _ = gen_conf.pop("top_p")
         for item in history:
-            if not isinstance(item["content"],list) and not isinstance(item["content"],tuple):
-                item["content"] = [{"text":item["content"]}]
-
+            if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
+                item["content"] = [{"text": item["content"]}]
+
         if self.model_name.split('.')[0] == 'ai21':
             try:
                 response = self.client.converse(
@@ -684,7 +684,7 @@ class BedrockChat(Base):
                 if "contentBlockDelta" in resp:
                     ans += resp["contentBlockDelta"]["delta"]["text"]
                     yield ans
-
+
        except (ClientError, Exception) as e:
             yield ans + f"ERROR: Can't invoke '{self.model_name}'. Reason: {e}"
 
@@ -693,22 +693,21 @@ class BedrockChat(Base):
 
 class GeminiChat(Base):
 
-    def __init__(self, key, model_name,base_url=None):
-        from google.generativeai import client,GenerativeModel
-
+    def __init__(self, key, model_name, base_url=None):
+        from google.generativeai import client, GenerativeModel
+
         client.configure(api_key=key)
         _client = client.get_default_generative_client()
         self.model_name = 'models/' + model_name
         self.model = GenerativeModel(model_name=self.model_name)
         self.model._client = _client
-
-
-    def chat(self,system,history,gen_conf):
+
+    def chat(self, system, history, gen_conf):
         from google.generativeai.types import content_types
-
+
         if system:
             self.model._system_instruction = content_types.to_content(system)
-
+
         if 'max_tokens' in gen_conf:
             gen_conf['max_output_tokens'] = gen_conf['max_tokens']
         for k in list(gen_conf.keys()):
@@ -717,9 +716,11 @@ class GeminiChat(Base):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if
+            if 'role' in item and item['role'] == 'system':
+                item['role'] = 'user'
+            if 'content' in item:
                 item['parts'] = item.pop('content')
-
+
         try:
             response = self.model.generate_content(
                 history,
@@ -731,7 +732,7 @@ class GeminiChat(Base):
 
     def chat_streamly(self, system, history, gen_conf):
         from google.generativeai.types import content_types
-
+
         if system:
             self.model._system_instruction = content_types.to_content(system)
         if 'max_tokens' in gen_conf:
@@ -742,13 +743,13 @@ class GeminiChat(Base):
         for item in history:
             if 'role' in item and item['role'] == 'assistant':
                 item['role'] = 'model'
-            if
+            if 'content' in item:
                 item['parts'] = item.pop('content')
         ans = ""
         try:
             response = self.model.generate_content(
                 history,
-                generation_config=gen_conf,stream=True)
+                generation_config=gen_conf, stream=True)
             for resp in response:
                 ans += resp.text
                 yield ans
@@ -756,11 +757,11 @@ class GeminiChat(Base):
         except Exception as e:
             yield ans + "\n**ERROR**: " + str(e)
 
-        yield
+        yield response._chunks[-1].usage_metadata.total_token_count
 
 
 class GroqChat:
-    def __init__(self, key, model_name,base_url=''):
+    def __init__(self, key, model_name, base_url=''):
         self.client = Groq(api_key=key)
         self.model_name = model_name
 
@@ -942,7 +943,7 @@ class CoHereChat(Base):
 class LeptonAIChat(Base):
     def __init__(self, key, model_name, base_url=None):
         if not base_url:
-            base_url = os.path.join("https://"+model_name+".lepton.run","api","v1")
+            base_url = os.path.join("https://" + model_name + ".lepton.run", "api", "v1")
         super().__init__(key, model_name, base_url)
 
 
@@ -1058,7 +1059,7 @@ class HunyuanChat(Base):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items()
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
         if "temperature" in gen_conf:
@@ -1084,7 +1085,7 @@ class HunyuanChat(Base):
         )
 
         _gen_conf = {}
-        _history = [{k.capitalize(): v for k, v in item.items()
+        _history = [{k.capitalize(): v for k, v in item.items()} for item in history]
         if system:
             _history.insert(0, {"Role": "system", "Content": system})
 
@@ -1121,7 +1122,7 @@ class HunyuanChat(Base):
 
 class SparkChat(Base):
     def __init__(
-
+            self, key, model_name, base_url="https://spark-api-open.xf-yun.com/v1"
     ):
         if not base_url:
             base_url = "https://spark-api-open.xf-yun.com/v1"
@@ -1141,9 +1142,9 @@ class BaiduYiyanChat(Base):
         import qianfan
 
         key = json.loads(key)
-        ak = key.get("yiyan_ak","")
-        sk = key.get("yiyan_sk","")
-        self.client = qianfan.ChatCompletion(ak=ak,sk=sk)
+        ak = key.get("yiyan_ak", "")
+        sk = key.get("yiyan_sk", "")
+        self.client = qianfan.ChatCompletion(ak=ak, sk=sk)
         self.model_name = model_name.lower()
         self.system = ""
 
@@ -1151,16 +1152,17 @@ class BaiduYiyanChat(Base):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-
-
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                0)) / 2
+        ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
 
         try:
             response = self.client.do(
-                model=self.model_name,
-                messages=history,
+                model=self.model_name,
+                messages=history,
                 system=self.system,
                 **gen_conf
             ).body
@@ -1174,8 +1176,9 @@ class BaiduYiyanChat(Base):
         if system:
             self.system = system
         gen_conf["penalty_score"] = (
-
-
+            (gen_conf.get("presence_penalty", 0) + gen_conf.get("frequency_penalty",
+                                                                0)) / 2
+        ) + 1
         if "max_tokens" in gen_conf:
             gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
         ans = ""
@@ -1183,8 +1186,8 @@ class BaiduYiyanChat(Base):
 
         try:
             response = self.client.do(
-                model=self.model_name,
-                messages=history,
+                model=self.model_name,
+                messages=history,
                 system=self.system,
                 stream=True,
                 **gen_conf
@@ -1415,4 +1418,3 @@ class GoogleChat(Base):
             yield ans + "\n**ERROR**: " + str(e)
 
         yield response._chunks[-1].usage_metadata.total_token_count
-
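
Aside from the Gemini fix itself, the pattern this diff touches repeatedly is the streaming token count in `Base.chat_streamly` and `BaiChuanChat.chat_streamly`: count each streamed delta with `num_tokens_from_string`, but prefer the provider-reported `total_tokens` whenever `resp.usage` carries one. A rough sketch of that fallback, with a stand-in tokenizer since `rag.utils.num_tokens_from_string` is not importable outside the repo:

```python
# Sketch of the usage-fallback accumulation used in the streaming handlers above.
# num_tokens_from_string is a stand-in; ragflow's real helper lives in rag.utils.
def num_tokens_from_string(text):
    return max(1, len(text) // 4)  # crude approximation of a tokenizer count


def accumulate_total_tokens(total_tokens, delta, usage):
    """Prefer the provider-reported total; otherwise count the streamed delta."""
    if not usage:  # mirrors: not hasattr(resp, "usage") or not resp.usage
        return total_tokens + num_tokens_from_string(delta)
    return usage.get("total_tokens", total_tokens)


# Chunks without usage are counted locally; a final chunk that reports usage wins.
total = 0
for delta, usage in [("Hello ", None), ("world!", None), ("", {"total_tokens": 42})]:
    total = accumulate_total_tokens(total, delta, usage)
print(total)  # 42
```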