Kevin Hu
committed on
Commit
·
32dd133
1
Parent(s):
16e3fae
add stream chat with TTS (#2228)
Browse files
### What problem does this PR solve?
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
api/apps/conversation_app.py
CHANGED
@@ -196,8 +196,8 @@ def tts():
|
|
196 |
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
197 |
def stream_audio():
|
198 |
try:
|
199 |
-
for chunk in tts_mdl.tts(text):
|
200 |
-
yield chunk
|
201 |
except Exception as e:
|
202 |
yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
203 |
"data": {"answer": "**ERROR**: "+str(e)}},
|
|
|
196 |
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
197 |
def stream_audio():
|
198 |
try:
|
199 |
+
for chunk in tts_mdl.tts(text):
|
200 |
+
yield chunk
|
201 |
except Exception as e:
|
202 |
yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
203 |
"data": {"answer": "**ERROR**: "+str(e)}},
|
api/db/services/dialog_service.py
CHANGED
@@ -13,6 +13,7 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
|
|
16 |
import os
|
17 |
import json
|
18 |
import re
|
@@ -120,6 +121,9 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
120 |
|
121 |
prompt_config = dialog.prompt_config
|
122 |
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
|
|
|
|
|
|
123 |
# try to use sql if field mapping is good to go
|
124 |
if field_map:
|
125 |
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
@@ -168,7 +172,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
168 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
169 |
|
170 |
if not knowledges and prompt_config.get("empty_response"):
|
171 |
-
|
|
|
172 |
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
173 |
|
174 |
kwargs["knowledge"] = "\n".join(knowledges)
|
@@ -214,16 +219,26 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
214 |
return {"answer": answer, "reference": refs, "prompt": prompt}
|
215 |
|
216 |
if stream:
|
|
|
217 |
answer = ""
|
218 |
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
219 |
answer = ans
|
220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
yield decorate_answer(answer)
|
222 |
else:
|
223 |
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
224 |
chat_logger.info("User: {}|Assistant: {}".format(
|
225 |
msg[-1]["content"], answer))
|
226 |
-
|
|
|
|
|
227 |
|
228 |
|
229 |
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
@@ -392,3 +407,12 @@ def rewrite(tenant_id, llm_id, question):
|
|
392 |
"""
|
393 |
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
|
394 |
return ans
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
#
|
16 |
+
import binascii
|
17 |
import os
|
18 |
import json
|
19 |
import re
|
|
|
121 |
|
122 |
prompt_config = dialog.prompt_config
|
123 |
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
124 |
+
tts_mdl = None
|
125 |
+
if prompt_config.get("tts"):
|
126 |
+
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
127 |
# try to use sql if field mapping is good to go
|
128 |
if field_map:
|
129 |
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
|
|
172 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
173 |
|
174 |
if not knowledges and prompt_config.get("empty_response"):
|
175 |
+
empty_res = prompt_config["empty_response"]
|
176 |
+
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
|
177 |
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
178 |
|
179 |
kwargs["knowledge"] = "\n".join(knowledges)
|
|
|
219 |
return {"answer": answer, "reference": refs, "prompt": prompt}
|
220 |
|
221 |
if stream:
|
222 |
+
last_ans = ""
|
223 |
answer = ""
|
224 |
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
225 |
answer = ans
|
226 |
+
delta_ans = ans[len(last_ans):]
|
227 |
+
if num_tokens_from_string(delta_ans) < 12:
|
228 |
+
continue
|
229 |
+
last_ans = answer
|
230 |
+
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
231 |
+
delta_ans = answer[len(last_ans):]
|
232 |
+
if delta_ans:
|
233 |
+
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
234 |
yield decorate_answer(answer)
|
235 |
else:
|
236 |
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
237 |
chat_logger.info("User: {}|Assistant: {}".format(
|
238 |
msg[-1]["content"], answer))
|
239 |
+
res = decorate_answer(answer)
|
240 |
+
res["audio_binary"] = tts(tts_mdl, answer)
|
241 |
+
yield res
|
242 |
|
243 |
|
244 |
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
|
407 |
"""
|
408 |
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
|
409 |
return ans
|
410 |
+
|
411 |
+
|
412 |
+
def tts(tts_mdl, text):
    """Synthesize ``text`` to speech and return the audio as a hex string.

    Args:
        tts_mdl: Object exposing ``tts(text)``, an iterable of bytes chunks,
            or a falsy value (e.g. ``None``) when no TTS model is configured.
        text: The text to synthesize.

    Returns:
        The concatenated chunks hex-encoded as a ``str``, or ``None`` when
        either the model or the text is missing/empty.
    """
    # Nothing to synthesize without both a model and some text.
    if not tts_mdl or not text:
        return None
    # Join all streamed chunks in one pass instead of repeated ``bytes +=``.
    audio = b"".join(tts_mdl.tts(text))
    # Hex-encode so the binary payload can travel inside a JSON response.
    return binascii.hexlify(audio).decode("utf-8")
|