Kevin Hu
committed on
Commit
·
32dd133
1
Parent(s):
16e3fae
add stream chat with TTS (#2228)
Browse files
### What problem does this PR solve?
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
api/apps/conversation_app.py
CHANGED
|
@@ -196,8 +196,8 @@ def tts():
|
|
| 196 |
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
| 197 |
def stream_audio():
|
| 198 |
try:
|
| 199 |
-
for chunk in tts_mdl.tts(text):
|
| 200 |
-
yield chunk
|
| 201 |
except Exception as e:
|
| 202 |
yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
| 203 |
"data": {"answer": "**ERROR**: "+str(e)}},
|
|
|
|
| 196 |
tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
|
| 197 |
def stream_audio():
|
| 198 |
try:
|
| 199 |
+
for chunk in tts_mdl.tts(text):
|
| 200 |
+
yield chunk
|
| 201 |
except Exception as e:
|
| 202 |
yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
|
| 203 |
"data": {"answer": "**ERROR**: "+str(e)}},
|
api/db/services/dialog_service.py
CHANGED
|
@@ -13,6 +13,7 @@
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
|
|
|
| 16 |
import os
|
| 17 |
import json
|
| 18 |
import re
|
|
@@ -120,6 +121,9 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 120 |
|
| 121 |
prompt_config = dialog.prompt_config
|
| 122 |
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
|
|
|
|
|
|
|
|
|
| 123 |
# try to use sql if field mapping is good to go
|
| 124 |
if field_map:
|
| 125 |
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
|
@@ -168,7 +172,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 168 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
| 169 |
|
| 170 |
if not knowledges and prompt_config.get("empty_response"):
|
| 171 |
-
|
|
|
|
| 172 |
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
| 173 |
|
| 174 |
kwargs["knowledge"] = "\n".join(knowledges)
|
|
@@ -214,16 +219,26 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|
| 214 |
return {"answer": answer, "reference": refs, "prompt": prompt}
|
| 215 |
|
| 216 |
if stream:
|
|
|
|
| 217 |
answer = ""
|
| 218 |
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
| 219 |
answer = ans
|
| 220 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
yield decorate_answer(answer)
|
| 222 |
else:
|
| 223 |
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
| 224 |
chat_logger.info("User: {}|Assistant: {}".format(
|
| 225 |
msg[-1]["content"], answer))
|
| 226 |
-
|
|
|
|
|
|
|
| 227 |
|
| 228 |
|
| 229 |
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
@@ -392,3 +407,12 @@ def rewrite(tenant_id, llm_id, question):
|
|
| 392 |
"""
|
| 393 |
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
|
| 394 |
return ans
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
# See the License for the specific language governing permissions and
|
| 14 |
# limitations under the License.
|
| 15 |
#
|
| 16 |
+
import binascii
|
| 17 |
import os
|
| 18 |
import json
|
| 19 |
import re
|
|
|
|
| 121 |
|
| 122 |
prompt_config = dialog.prompt_config
|
| 123 |
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
| 124 |
+
tts_mdl = None
|
| 125 |
+
if prompt_config.get("tts"):
|
| 126 |
+
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
| 127 |
# try to use sql if field mapping is good to go
|
| 128 |
if field_map:
|
| 129 |
chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
|
|
|
|
| 172 |
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
| 173 |
|
| 174 |
if not knowledges and prompt_config.get("empty_response"):
|
| 175 |
+
empty_res = prompt_config["empty_response"]
|
| 176 |
+
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
|
| 177 |
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
| 178 |
|
| 179 |
kwargs["knowledge"] = "\n".join(knowledges)
|
|
|
|
| 219 |
return {"answer": answer, "reference": refs, "prompt": prompt}
|
| 220 |
|
| 221 |
if stream:
|
| 222 |
+
last_ans = ""
|
| 223 |
answer = ""
|
| 224 |
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
| 225 |
answer = ans
|
| 226 |
+
delta_ans = ans[len(last_ans):]
|
| 227 |
+
if num_tokens_from_string(delta_ans) < 12:
|
| 228 |
+
continue
|
| 229 |
+
last_ans = answer
|
| 230 |
+
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
| 231 |
+
delta_ans = answer[len(last_ans):]
|
| 232 |
+
if delta_ans:
|
| 233 |
+
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
| 234 |
yield decorate_answer(answer)
|
| 235 |
else:
|
| 236 |
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
| 237 |
chat_logger.info("User: {}|Assistant: {}".format(
|
| 238 |
msg[-1]["content"], answer))
|
| 239 |
+
res = decorate_answer(answer)
|
| 240 |
+
res["audio_binary"] = tts(tts_mdl, answer)
|
| 241 |
+
yield res
|
| 242 |
|
| 243 |
|
| 244 |
def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
|
|
|
|
| 407 |
"""
|
| 408 |
ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
|
| 409 |
return ans
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def tts(tts_mdl, text):
    """Synthesize ``text`` to speech and return the audio as a hex string.

    Args:
        tts_mdl: Object exposing ``tts(text)``, an iterable of bytes-like
            audio chunks. A falsy value (e.g. ``None``) means TTS is not
            enabled for the dialog.
        text: The text to synthesize.

    Returns:
        The concatenated audio chunks encoded as a lowercase hexadecimal
        ``str``, or ``None`` when there is no model or no text.
    """
    # NOTE(review): the previous version began with a bare `return`, which
    # made everything below unreachable and forced every "audio_binary"
    # field to None. It is removed here so the helper does what its callers
    # expect; confirm that TTS was not being disabled on purpose.
    if not tts_mdl or not text:
        return None
    # Join once instead of repeated `+=`; also avoids shadowing builtin `bin`.
    audio = b"".join(tts_mdl.tts(text))
    # bytes.hex() yields the same text as binascii.hexlify(...).decode("utf-8").
    return audio.hex()
|