Kevin Hu committed on
Commit
32dd133
·
1 Parent(s): 16e3fae

add stream chat with TTS (#2228)

Browse files

### What problem does this PR solve?



### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/conversation_app.py CHANGED
@@ -196,8 +196,8 @@ def tts():
196
  tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
197
  def stream_audio():
198
  try:
199
- for chunk in tts_mdl.tts(text):
200
- yield chunk
201
  except Exception as e:
202
  yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
203
  "data": {"answer": "**ERROR**: "+str(e)}},
 
196
  tts_mdl = LLMBundle(tenants[0]["tenant_id"], LLMType.TTS, tts_id)
197
  def stream_audio():
198
  try:
199
+ for chunk in tts_mdl.tts(text):
200
+ yield chunk
201
  except Exception as e:
202
  yield ("data:" + json.dumps({"retcode": 500, "retmsg": str(e),
203
  "data": {"answer": "**ERROR**: "+str(e)}},
api/db/services/dialog_service.py CHANGED
@@ -13,6 +13,7 @@
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
 
16
  import os
17
  import json
18
  import re
@@ -120,6 +121,9 @@ def chat(dialog, messages, stream=True, **kwargs):
120
 
121
  prompt_config = dialog.prompt_config
122
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
 
 
 
123
  # try to use sql if field mapping is good to go
124
  if field_map:
125
  chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
@@ -168,7 +172,8 @@ def chat(dialog, messages, stream=True, **kwargs):
168
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
169
 
170
  if not knowledges and prompt_config.get("empty_response"):
171
- yield {"answer": prompt_config["empty_response"], "reference": kbinfos}
 
172
  return {"answer": prompt_config["empty_response"], "reference": kbinfos}
173
 
174
  kwargs["knowledge"] = "\n".join(knowledges)
@@ -214,16 +219,26 @@ def chat(dialog, messages, stream=True, **kwargs):
214
  return {"answer": answer, "reference": refs, "prompt": prompt}
215
 
216
  if stream:
 
217
  answer = ""
218
  for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
219
  answer = ans
220
- yield {"answer": answer, "reference": {}}
 
 
 
 
 
 
 
221
  yield decorate_answer(answer)
222
  else:
223
  answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
224
  chat_logger.info("User: {}|Assistant: {}".format(
225
  msg[-1]["content"], answer))
226
- yield decorate_answer(answer)
 
 
227
 
228
 
229
  def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
@@ -392,3 +407,12 @@ def rewrite(tenant_id, llm_id, question):
392
  """
393
  ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
394
  return ans
 
 
 
 
 
 
 
 
 
 
13
  # See the License for the specific language governing permissions and
14
  # limitations under the License.
15
  #
16
+ import binascii
17
  import os
18
  import json
19
  import re
 
121
 
122
  prompt_config = dialog.prompt_config
123
  field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
124
+ tts_mdl = None
125
+ if prompt_config.get("tts"):
126
+ tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
127
  # try to use sql if field mapping is good to go
128
  if field_map:
129
  chat_logger.info("Use SQL to retrieval:{}".format(questions[-1]))
 
172
  "{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
173
 
174
  if not knowledges and prompt_config.get("empty_response"):
175
+ empty_res = prompt_config["empty_response"]
176
+ yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
177
  return {"answer": prompt_config["empty_response"], "reference": kbinfos}
178
 
179
  kwargs["knowledge"] = "\n".join(knowledges)
 
219
  return {"answer": answer, "reference": refs, "prompt": prompt}
220
 
221
  if stream:
222
+ last_ans = ""
223
  answer = ""
224
  for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
225
  answer = ans
226
+ delta_ans = ans[len(last_ans):]
227
+ if num_tokens_from_string(delta_ans) < 12:
228
+ continue
229
+ last_ans = answer
230
+ yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
231
+ delta_ans = answer[len(last_ans):]
232
+ if delta_ans:
233
+ yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
234
  yield decorate_answer(answer)
235
  else:
236
  answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
237
  chat_logger.info("User: {}|Assistant: {}".format(
238
  msg[-1]["content"], answer))
239
+ res = decorate_answer(answer)
240
+ res["audio_binary"] = tts(tts_mdl, answer)
241
+ yield res
242
 
243
 
244
  def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
 
407
  """
408
  ans = chat_mdl.chat(prompt, [{"role": "user", "content": question}], {"temperature": 0.8})
409
  return ans
410
+
411
+
412
def tts(tts_mdl, text):
    """Synthesize ``text`` to speech and return the audio as a hex string.

    Args:
        tts_mdl: Object exposing ``tts(text)`` that yields ``bytes`` chunks,
            or ``None`` when text-to-speech is not enabled for the dialog.
        text: The text to synthesize.

    Returns:
        The concatenated audio chunks, hex-encoded as a ``str``, or ``None``
        when there is no model or no text to synthesize.
    """
    # Nothing to synthesize: callers store this value directly in the
    # "audio_binary" field, so None means "no audio for this message".
    if not tts_mdl or not text:
        return None
    # Join once instead of growing a bytes object chunk by chunk.
    audio = b"".join(tts_mdl.tts(text))
    return binascii.hexlify(audio).decode("utf-8")