KevinHuSh commited on
Commit
255441a
·
1 Parent(s): 2a78c1d

Be more specific in error messages (#1409)

Browse files

### What problem does this PR solve?

#918

### Type of change

- [x] Refactoring

api/apps/canvas_app.py CHANGED
@@ -95,14 +95,16 @@ def run():
95
  final_ans = {"reference": [], "content": ""}
96
  try:
97
  canvas = Canvas(cvs.dsl, current_user.id)
98
- print(canvas)
99
  if "message" in req:
100
  canvas.messages.append({"role": "user", "content": req["message"]})
101
  canvas.add_user_input(req["message"])
102
  answer = canvas.run(stream=stream)
 
103
  except Exception as e:
104
  return server_error_response(e)
105
 
 
 
106
  if stream:
107
  assert isinstance(answer, partial)
108
 
@@ -116,7 +118,7 @@ def run():
116
  yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
117
 
118
  canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
119
- if "reference" in final_ans:
120
  canvas.reference.append(final_ans["reference"])
121
  cvs.dsl = json.loads(str(canvas))
122
  UserCanvasService.update_by_id(req["id"], cvs.to_dict())
@@ -134,7 +136,7 @@ def run():
134
  return resp
135
 
136
  canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
137
- if "reference" in final_ans:
138
  canvas.reference.append(final_ans["reference"])
139
  cvs.dsl = json.loads(str(canvas))
140
  UserCanvasService.update_by_id(req["id"], cvs.to_dict())
 
95
  final_ans = {"reference": [], "content": ""}
96
  try:
97
  canvas = Canvas(cvs.dsl, current_user.id)
 
98
  if "message" in req:
99
  canvas.messages.append({"role": "user", "content": req["message"]})
100
  canvas.add_user_input(req["message"])
101
  answer = canvas.run(stream=stream)
102
+ print(canvas)
103
  except Exception as e:
104
  return server_error_response(e)
105
 
106
+ assert answer, "Nothing. Is it over?"
107
+
108
  if stream:
109
  assert isinstance(answer, partial)
110
 
 
118
  yield "data:" + json.dumps({"retcode": 0, "retmsg": "", "data": ans}, ensure_ascii=False) + "\n\n"
119
 
120
  canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
121
+ if final_ans.get("reference"):
122
  canvas.reference.append(final_ans["reference"])
123
  cvs.dsl = json.loads(str(canvas))
124
  UserCanvasService.update_by_id(req["id"], cvs.to_dict())
 
136
  return resp
137
 
138
  canvas.messages.append({"role": "assistant", "content": final_ans["content"]})
139
+ if final_ans.get("reference"):
140
  canvas.reference.append(final_ans["reference"])
141
  cvs.dsl = json.loads(str(canvas))
142
  UserCanvasService.update_by_id(req["id"], cvs.to_dict())
graph/canvas.py CHANGED
@@ -121,7 +121,6 @@ class Canvas(ABC):
121
  if desc["to"] not in cpn["downstream"]:
122
  cpn["downstream"].append(desc["to"])
123
 
124
-
125
  self.path = self.dsl["path"]
126
  self.history = self.dsl["history"]
127
  self.messages = self.dsl["messages"]
@@ -136,9 +135,21 @@ class Canvas(ABC):
136
  self.dsl["answer"] = self.answer
137
  self.dsl["reference"] = self.reference
138
  self.dsl["embed_id"] = self._embed_id
139
- dsl = deepcopy(self.dsl)
 
 
 
 
 
 
140
  for k, cpn in self.components.items():
141
- dsl["components"][k]["obj"] = json.loads(str(cpn["obj"]))
 
 
 
 
 
 
142
  return json.dumps(dsl, ensure_ascii=False)
143
 
144
  def reset(self):
@@ -161,6 +172,9 @@ class Canvas(ABC):
161
  except Exception as e:
162
  ans = ComponentBase.be_output(str(e))
163
  self.path[-1].append(cpn_id)
 
 
 
164
  self.history.append(("assistant", ans.to_dict("records")))
165
  return ans
166
 
@@ -190,6 +204,8 @@ class Canvas(ABC):
190
  cpn = self.get_component(cpn_id)
191
  if not cpn["downstream"]: break
192
 
 
 
193
  if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
194
  switch_out = cpn["obj"].output()[1].iloc[0, 0]
195
  assert switch_out in self.components, \
@@ -249,3 +265,27 @@ class Canvas(ABC):
249
 
250
  def get_embedding_model(self):
251
  return self._embed_id
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  if desc["to"] not in cpn["downstream"]:
122
  cpn["downstream"].append(desc["to"])
123
 
 
124
  self.path = self.dsl["path"]
125
  self.history = self.dsl["history"]
126
  self.messages = self.dsl["messages"]
 
135
  self.dsl["answer"] = self.answer
136
  self.dsl["reference"] = self.reference
137
  self.dsl["embed_id"] = self._embed_id
138
+ dsl = {
139
+ "components": {}
140
+ }
141
+ for k in self.dsl.keys():
142
+ if k in ["components"]:continue
143
+ dsl[k] = deepcopy(self.dsl[k])
144
+
145
  for k, cpn in self.components.items():
146
+ if k not in dsl["components"]:
147
+ dsl["components"][k] = {}
148
+ for c in cpn.keys():
149
+ if c == "obj":
150
+ dsl["components"][k][c] = json.loads(str(cpn["obj"]))
151
+ continue
152
+ dsl["components"][k][c] = deepcopy(cpn[c])
153
  return json.dumps(dsl, ensure_ascii=False)
154
 
155
  def reset(self):
 
172
  except Exception as e:
173
  ans = ComponentBase.be_output(str(e))
174
  self.path[-1].append(cpn_id)
175
+ if kwargs.get("stream"):
176
+ assert isinstance(ans, partial)
177
+ return ans
178
  self.history.append(("assistant", ans.to_dict("records")))
179
  return ans
180
 
 
204
  cpn = self.get_component(cpn_id)
205
  if not cpn["downstream"]: break
206
 
207
+ if self._find_loop(): raise OverflowError("Too much loops!")
208
+
209
  if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
210
  switch_out = cpn["obj"].output()[1].iloc[0, 0]
211
  assert switch_out in self.components, \
 
265
 
266
  def get_embedding_model(self):
267
  return self._embed_id
268
+
269
+ def _find_loop(self, max_loops=2):
270
+ path = self.path[-1][::-1]
271
+ if len(path) < 2: return False
272
+
273
+ for i in range(len(path)):
274
+ if path[i].lower().find("answer") >= 0:
275
+ path = path[:i]
276
+ break
277
+
278
+ if len(path) < 2: return False
279
+
280
+ for l in range(1, len(path) // 2):
281
+ pat = ",".join(path[0:l])
282
+ path_str = ",".join(path)
283
+ if len(pat) >= len(path_str): return False
284
+ path_str = path_str[len(pat):]
285
+ loop = max_loops
286
+ while path_str.find(pat) >= 0 and loop >= 0:
287
+ loop -= 1
288
+ path_str = path_str[len(pat):]
289
+ if loop < 0: return True
290
+
291
+ return False
graph/component/base.py CHANGED
@@ -19,7 +19,7 @@ import json
19
  import os
20
  from copy import deepcopy
21
  from functools import partial
22
- from typing import List, Dict
23
 
24
  import pandas as pd
25
 
@@ -246,7 +246,7 @@ class ComponentParamBase(ABC):
246
  def check_empty(param, descr):
247
  if not param:
248
  raise ValueError(
249
- descr + " {} not supported empty value."
250
  )
251
 
252
  @staticmethod
@@ -411,12 +411,23 @@ class ComponentBase(ABC):
411
  def _run(self, history, **kwargs):
412
  raise NotImplementedError()
413
 
414
- def output(self) -> pd.DataFrame:
415
  o = getattr(self._param, self._param.output_var_name)
416
  if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
417
  if not isinstance(o, list): o = [o]
418
  o = pd.DataFrame(o)
419
- return self._param.output_var_name, o
 
 
 
 
 
 
 
 
 
 
 
420
 
421
  def reset(self):
422
  setattr(self._param, self._param.output_var_name, None)
@@ -446,7 +457,7 @@ class ComponentBase(ABC):
446
  if self.component_name.lower().find("answer") >= 0:
447
  if self.get_component_name(u) in ["relevant"]: continue
448
 
449
- upstream_outs.append(self._canvas.get_component(u)["obj"].output()[1])
450
  break
451
 
452
  return pd.concat(upstream_outs, ignore_index=False)
 
19
  import os
20
  from copy import deepcopy
21
  from functools import partial
22
+ from typing import List, Dict, Tuple, Union
23
 
24
  import pandas as pd
25
 
 
246
  def check_empty(param, descr):
247
  if not param:
248
  raise ValueError(
249
+ descr + " does not support empty value."
250
  )
251
 
252
  @staticmethod
 
411
  def _run(self, history, **kwargs):
412
  raise NotImplementedError()
413
 
414
+ def output(self, allow_partial=True) -> Tuple[str, Union[pd.DataFrame, partial]]:
415
  o = getattr(self._param, self._param.output_var_name)
416
  if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
417
  if not isinstance(o, list): o = [o]
418
  o = pd.DataFrame(o)
419
+
420
+ if allow_partial or not isinstance(o, partial):
421
+ if not isinstance(o, partial) and not isinstance(o, pd.DataFrame):
422
+ return pd.DataFrame(o if isinstance(o, list) else [o])
423
+ return self._param.output_var_name, o
424
+
425
+ outs = None
426
+ for oo in o():
427
+ if not isinstance(oo, pd.DataFrame):
428
+ outs = pd.DataFrame(oo if isinstance(oo, list) else [oo])
429
+ else: outs = oo
430
+ return self._param.output_var_name, outs
431
 
432
  def reset(self):
433
  setattr(self._param, self._param.output_var_name, None)
 
457
  if self.component_name.lower().find("answer") >= 0:
458
  if self.get_component_name(u) in ["relevant"]: continue
459
 
460
+ else: upstream_outs.append(self._canvas.get_component(u)["obj"].output(allow_partial=False)[1])
461
  break
462
 
463
  return pd.concat(upstream_outs, ignore_index=False)
graph/component/categorize.py CHANGED
@@ -35,7 +35,10 @@ class CategorizeParam(GenerateParam):
35
 
36
  def check(self):
37
  super().check()
38
- self.check_empty(self.category_description, "Category examples")
 
 
 
39
 
40
  def get_prompt(self):
41
  cate_lines = []
 
35
 
36
  def check(self):
37
  super().check()
38
+ self.check_empty(self.category_description, "[Categorize] Category examples")
39
+ for k, v in self.category_description.items():
40
+ if not k: raise ValueError(f"[Categorize] Category name can not be empty!")
41
+ if not v["to"]: raise ValueError(f"[Categorize] 'To' of category {k} can not be empty!")
42
 
43
  def get_prompt(self):
44
  cate_lines = []
graph/component/generate.py CHANGED
@@ -33,31 +33,31 @@ class GenerateParam(ComponentParamBase):
33
  super().__init__()
34
  self.llm_id = ""
35
  self.prompt = ""
36
- self.max_tokens = 256
37
- self.temperature = 0.1
38
- self.top_p = 0.3
39
- self.presence_penalty = 0.4
40
- self.frequency_penalty = 0.7
41
  self.cite = True
42
- #self.parameters = []
43
 
44
  def check(self):
45
- self.check_decimal_float(self.temperature, "Temperature")
46
- self.check_decimal_float(self.presence_penalty, "Presence penalty")
47
- self.check_decimal_float(self.frequency_penalty, "Frequency penalty")
48
- self.check_positive_number(self.max_tokens, "Max tokens")
49
- self.check_decimal_float(self.top_p, "Top P")
50
- self.check_empty(self.llm_id, "LLM")
51
- #self.check_defined_type(self.parameters, "Parameters", ["list"])
52
 
53
  def gen_conf(self):
54
- return {
55
- "max_tokens": self.max_tokens,
56
- "temperature": self.temperature,
57
- "top_p": self.top_p,
58
- "presence_penalty": self.presence_penalty,
59
- "frequency_penalty": self.frequency_penalty,
60
- }
61
 
62
 
63
  class Generate(ComponentBase):
@@ -69,12 +69,15 @@ class Generate(ComponentBase):
69
 
70
  retrieval_res = self.get_input()
71
  input = "\n- ".join(retrieval_res["content"])
72
-
 
 
 
73
 
74
  kwargs["input"] = input
75
  for n, v in kwargs.items():
76
- #prompt = re.sub(r"\{%s\}"%n, re.escape(str(v)), prompt)
77
- prompt = re.sub(r"\{%s\}"%n, str(v), prompt)
78
 
79
  if kwargs.get("stream"):
80
  return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
@@ -82,23 +85,25 @@ class Generate(ComponentBase):
82
  if "empty_response" in retrieval_res.columns:
83
  return Generate.be_output(input)
84
 
85
- ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size), self._param.gen_conf())
 
86
 
87
  if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
88
  ans, idx = retrievaler.insert_citations(ans,
89
- [ck["content_ltks"]
90
- for _, ck in retrieval_res.iterrows()],
91
- [ck["vector"]
92
- for _,ck in retrieval_res.iterrows()],
93
- LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, self._canvas.get_embedding_model()),
94
- tkweight=0.7,
95
- vtweight=0.3)
 
96
  del retrieval_res["vector"]
97
  retrieval_res = retrieval_res.to_dict("records")
98
  df = []
99
  for i in idx:
100
  df.append(retrieval_res[int(i)])
101
- r = re.search(r"^((.|[\r\n])*? ##%s\$\$)"%str(i), ans)
102
  assert r, f"{i} => {ans}"
103
  df[-1]["content"] = r.group(1)
104
  ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans)
@@ -116,20 +121,22 @@ class Generate(ComponentBase):
116
  return
117
 
118
  answer = ""
119
- for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size), self._param.gen_conf()):
 
120
  res = {"content": ans, "reference": []}
121
  answer = ans
122
  yield res
123
 
124
  if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
125
  answer, idx = retrievaler.insert_citations(answer,
126
- [ck["content_ltks"]
127
- for _, ck in retrieval_res.iterrows()],
128
- [ck["vector"]
129
- for _, ck in retrieval_res.iterrows()],
130
- LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING, self._canvas.get_embedding_model()),
131
- tkweight=0.7,
132
- vtweight=0.3)
 
133
  doc_ids = set([])
134
  recall_docs = []
135
  for i in idx:
@@ -152,5 +159,3 @@ class Generate(ComponentBase):
152
  yield res
153
 
154
  self.set_output(res)
155
-
156
-
 
33
  super().__init__()
34
  self.llm_id = ""
35
  self.prompt = ""
36
+ self.max_tokens = 0
37
+ self.temperature = 0
38
+ self.top_p = 0
39
+ self.presence_penalty = 0
40
+ self.frequency_penalty = 0
41
  self.cite = True
42
+ self.parameters = []
43
 
44
  def check(self):
45
+ self.check_decimal_float(self.temperature, "[Generate] Temperature")
46
+ self.check_decimal_float(self.presence_penalty, "[Generate] Presence penalty")
47
+ self.check_decimal_float(self.frequency_penalty, "[Generate] Frequency penalty")
48
+ self.check_nonnegative_number(self.max_tokens, "[Generate] Max tokens")
49
+ self.check_decimal_float(self.top_p, "[Generate] Top P")
50
+ self.check_empty(self.llm_id, "[Generate] LLM")
51
+ # self.check_defined_type(self.parameters, "Parameters", ["list"])
52
 
53
  def gen_conf(self):
54
+ conf = {}
55
+ if self.max_tokens > 0: conf["max_tokens"] = self.max_tokens
56
+ if self.temperature > 0: conf["temperature"] = self.temperature
57
+ if self.top_p > 0: conf["top_p"] = self.top_p
58
+ if self.presence_penalty > 0: conf["presence_penalty"] = self.presence_penalty
59
+ if self.frequency_penalty > 0: conf["frequency_penalty"] = self.frequency_penalty
60
+ return conf
61
 
62
 
63
  class Generate(ComponentBase):
 
69
 
70
  retrieval_res = self.get_input()
71
  input = "\n- ".join(retrieval_res["content"])
72
+ for para in self._param.parameters:
73
+ cpn = self._canvas.get_component(para["component_id"])["obj"]
74
+ _, out = cpn.output(allow_partial=False)
75
+ kwargs[para["key"]] = "\n - ".join(out["content"])
76
 
77
  kwargs["input"] = input
78
  for n, v in kwargs.items():
79
+ # prompt = re.sub(r"\{%s\}"%n, re.escape(str(v)), prompt)
80
+ prompt = re.sub(r"\{%s\}" % n, str(v), prompt)
81
 
82
  if kwargs.get("stream"):
83
  return partial(self.stream_output, chat_mdl, prompt, retrieval_res)
 
85
  if "empty_response" in retrieval_res.columns:
86
  return Generate.be_output(input)
87
 
88
+ ans = chat_mdl.chat(prompt, self._canvas.get_history(self._param.message_history_window_size),
89
+ self._param.gen_conf())
90
 
91
  if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
92
  ans, idx = retrievaler.insert_citations(ans,
93
+ [ck["content_ltks"]
94
+ for _, ck in retrieval_res.iterrows()],
95
+ [ck["vector"]
96
+ for _, ck in retrieval_res.iterrows()],
97
+ LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
98
+ self._canvas.get_embedding_model()),
99
+ tkweight=0.7,
100
+ vtweight=0.3)
101
  del retrieval_res["vector"]
102
  retrieval_res = retrieval_res.to_dict("records")
103
  df = []
104
  for i in idx:
105
  df.append(retrieval_res[int(i)])
106
+ r = re.search(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), ans)
107
  assert r, f"{i} => {ans}"
108
  df[-1]["content"] = r.group(1)
109
  ans = re.sub(r"^((.|[\r\n])*? ##%s\$\$)" % str(i), "", ans)
 
121
  return
122
 
123
  answer = ""
124
+ for ans in chat_mdl.chat_streamly(prompt, self._canvas.get_history(self._param.message_history_window_size),
125
+ self._param.gen_conf()):
126
  res = {"content": ans, "reference": []}
127
  answer = ans
128
  yield res
129
 
130
  if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
131
  answer, idx = retrievaler.insert_citations(answer,
132
+ [ck["content_ltks"]
133
+ for _, ck in retrieval_res.iterrows()],
134
+ [ck["vector"]
135
+ for _, ck in retrieval_res.iterrows()],
136
+ LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
137
+ self._canvas.get_embedding_model()),
138
+ tkweight=0.7,
139
+ vtweight=0.3)
140
  doc_ids = set([])
141
  recall_docs = []
142
  for i in idx:
 
159
  yield res
160
 
161
  self.set_output(res)
 
 
graph/component/message.py CHANGED
@@ -32,7 +32,7 @@ class MessageParam(ComponentParamBase):
32
  self.messages = []
33
 
34
  def check(self):
35
- self.check_empty(self.messages, "Message")
36
  return True
37
 
38
 
 
32
  self.messages = []
33
 
34
  def check(self):
35
+ self.check_empty(self.messages, "[Message]")
36
  return True
37
 
38
 
graph/component/relevant.py CHANGED
@@ -33,6 +33,8 @@ class RelevantParam(GenerateParam):
33
 
34
  def check(self):
35
  super().check()
 
 
36
 
37
  def get_prompt(self):
38
  self.prompt = """
 
33
 
34
  def check(self):
35
  super().check()
36
+ self.check_empty(self.yes, "[Relevant] 'Yes'")
37
+ self.check_empty(self.no, "[Relevant] 'No'")
38
 
39
  def get_prompt(self):
40
  self.prompt = """
graph/component/retrieval.py CHANGED
@@ -40,10 +40,10 @@ class RetrievalParam(ComponentParamBase):
40
  self.empty_response = ""
41
 
42
  def check(self):
43
- self.check_decimal_float(self.similarity_threshold, "Similarity threshold")
44
- self.check_decimal_float(self.keywords_similarity_weight, "Keywords similarity weight")
45
- self.check_positive_number(self.top_n, "Top N")
46
- self.check_empty(self.kb_ids, "Knowledge bases")
47
 
48
 
49
  class Retrieval(ComponentBase, ABC):
 
40
  self.empty_response = ""
41
 
42
  def check(self):
43
+ self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
44
+ self.check_decimal_float(self.keywords_similarity_weight, "[Retrieval] Keywords similarity weight")
45
+ self.check_positive_number(self.top_n, "[Retrieval] Top N")
46
+ self.check_empty(self.kb_ids, "[Retrieval] Knowledge bases")
47
 
48
 
49
  class Retrieval(ComponentBase, ABC):
graph/component/switch.py CHANGED
@@ -44,8 +44,10 @@ class SwitchParam(ComponentParamBase):
44
  self.default = ""
45
 
46
  def check(self):
47
- self.check_empty(self.conditions, "Switch conditions")
48
- self.check_empty(self.default, "Default path")
 
 
49
 
50
  def operators(self, field, op, value):
51
  if op == "gt":
 
44
  self.default = ""
45
 
46
  def check(self):
47
+ self.check_empty(self.conditions, "[Switch] conditions")
48
+ self.check_empty(self.default, "[Switch] Default path")
49
+ for cond in self.conditions:
50
+ if not cond["to"]: raise ValueError(f"[Switch] 'To' can not be empty!")
51
 
52
  def operators(self, field, op, value):
53
  if op == "gt":