Kevin Hu committed on
Commit
6d672a7
·
1 Parent(s): e10ed78

add prompt to message (#2099)

Browse files

### What problem does this PR solve?

#2098

### Type of change

- [x] New Feature (non-breaking change which adds functionality)

api/apps/conversation_app.py CHANGED
@@ -140,7 +140,8 @@ def completion():
140
  if not conv.reference:
141
  conv.reference.append(ans["reference"])
142
  else: conv.reference[-1] = ans["reference"]
143
- conv.message[-1] = {"role": "assistant", "content": ans["answer"], "id": message_id}
 
144
 
145
  def stream():
146
  nonlocal dia, msg, req, conv
 
140
  if not conv.reference:
141
  conv.reference.append(ans["reference"])
142
  else: conv.reference[-1] = ans["reference"]
143
+ conv.message[-1] = {"role": "assistant", "content": ans["answer"],
144
+ "id": message_id, "prompt": ans.get("prompt", "")}
145
 
146
  def stream():
147
  nonlocal dia, msg, req, conv
api/db/services/dialog_service.py CHANGED
@@ -179,6 +179,7 @@ def chat(dialog, messages, stream=True, **kwargs):
179
  for m in messages if m["role"] != "system"])
180
  used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
181
  assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
 
182
 
183
  if "max_tokens" in gen_conf:
184
  gen_conf["max_tokens"] = min(
@@ -186,7 +187,7 @@ def chat(dialog, messages, stream=True, **kwargs):
186
  max_tokens - used_token_count)
187
 
188
  def decorate_answer(answer):
189
- nonlocal prompt_config, knowledges, kwargs, kbinfos
190
  refs = []
191
  if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
192
  answer, idx = retr.insert_citations(answer,
@@ -210,17 +211,16 @@ def chat(dialog, messages, stream=True, **kwargs):
210
 
211
  if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
212
  answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
213
- return {"answer": answer, "reference": refs}
214
 
215
  if stream:
216
  answer = ""
217
- for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], gen_conf):
218
  answer = ans
219
- yield {"answer": answer, "reference": {}}
220
  yield decorate_answer(answer)
221
  else:
222
- answer = chat_mdl.chat(
223
- msg[0]["content"], msg[1:], gen_conf)
224
  chat_logger.info("User: {}|Assistant: {}".format(
225
  msg[-1]["content"], answer))
226
  yield decorate_answer(answer)
@@ -334,7 +334,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
334
  chat_logger.warning("SQL missing field: " + sql)
335
  return {
336
  "answer": "\n".join([clmns, line, rows]),
337
- "reference": {"chunks": [], "doc_aggs": []}
 
338
  }
339
 
340
  docid_idx = list(docid_idx)[0]
@@ -348,7 +349,8 @@ def use_sql(question, field_map, tenant_id, chat_mdl, quota=True):
348
  "answer": "\n".join([clmns, line, rows]),
349
  "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
350
  "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
351
- doc_aggs.items()]}
 
352
  }
353
 
354
 
 
179
  for m in messages if m["role"] != "system"])
180
  used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
181
  assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
182
+ prompt = msg[0]["content"]
183
 
184
  if "max_tokens" in gen_conf:
185
  gen_conf["max_tokens"] = min(
 
187
  max_tokens - used_token_count)
188
 
189
  def decorate_answer(answer):
190
+ nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt
191
  refs = []
192
  if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
193
  answer, idx = retr.insert_citations(answer,
 
211
 
212
  if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
213
  answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
214
+ return {"answer": answer, "reference": refs, "prompt": prompt}
215
 
216
  if stream:
217
  answer = ""
218
+ for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
219
  answer = ans
220
+ yield {"answer": answer, "reference": {}, "prompt": prompt}
221
  yield decorate_answer(answer)
222
  else:
223
+ answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
 
224
  chat_logger.info("User: {}|Assistant: {}".format(
225
  msg[-1]["content"], answer))
226
  yield decorate_answer(answer)
 
334
  chat_logger.warning("SQL missing field: " + sql)
335
  return {
336
  "answer": "\n".join([clmns, line, rows]),
337
+ "reference": {"chunks": [], "doc_aggs": []},
338
+ "prompt": sys_prompt
339
  }
340
 
341
  docid_idx = list(docid_idx)[0]
 
349
  "answer": "\n".join([clmns, line, rows]),
350
  "reference": {"chunks": [{"doc_id": r[docid_idx], "docnm_kwd": r[docnm_idx]} for r in tbl["rows"]],
351
  "doc_aggs": [{"doc_id": did, "doc_name": d["doc_name"], "count": d["count"]} for did, d in
352
+ doc_aggs.items()]},
353
+ "prompt": sys_prompt
354
  }
355
 
356