Kevin Hu commited on
Commit
f586a68
·
1 Parent(s): cf813df

fix bugs about multi input for generate (#1525)

Browse files

### What problem does this PR solve?



### Type of change

- [x] Bug Fix (non-breaking change which fixes an issue)

graph/canvas.py CHANGED
@@ -193,9 +193,13 @@ class Canvas(ABC):
193
  self.answer.append(c)
194
  else:
195
  if DEBUG: print("RUN: ", c)
 
 
 
 
196
  ans = cpn.run(self.history, **kwargs)
197
  self.path[-1].append(c)
198
- ran += 1
199
 
200
  prepare2run(self.components[self.path[-2][-1]]["downstream"])
201
  while 0 <= ran < len(self.path[-1]):
 
193
  self.answer.append(c)
194
  else:
195
  if DEBUG: print("RUN: ", c)
196
+ if cpn.component_name == "Generate":
197
+ cpids = cpn.get_dependent_components()
198
+ if any([c not in self.path[-1] for c in cpids]):
199
+ continue
200
  ans = cpn.run(self.history, **kwargs)
201
  self.path[-1].append(c)
202
+ ran += 1
203
 
204
  prepare2run(self.components[self.path[-2][-1]]["downstream"])
205
  while 0 <= ran < len(self.path[-1]):
graph/component/base.py CHANGED
@@ -445,6 +445,7 @@ class ComponentBase(ABC):
445
  if DEBUG: print(self.component_name, reversed_cpnts[::-1])
446
  for u in reversed_cpnts[::-1]:
447
  if self.get_component_name(u) in ["switch"]: continue
 
448
  if self.component_name.lower().find("switch") < 0 \
449
  and self.get_component_name(u) in ["relevant", "categorize"]:
450
  continue
 
445
  if DEBUG: print(self.component_name, reversed_cpnts[::-1])
446
  for u in reversed_cpnts[::-1]:
447
  if self.get_component_name(u) in ["switch"]: continue
448
+ if u not in self._canvas.get_component(self._id)["upstream"]: continue
449
  if self.component_name.lower().find("switch") < 0 \
450
  and self.get_component_name(u) in ["relevant", "categorize"]:
451
  continue
graph/component/generate.py CHANGED
@@ -63,6 +63,10 @@ class GenerateParam(ComponentParamBase):
63
  class Generate(ComponentBase):
64
  component_name = "Generate"
65
 
 
 
 
 
66
  def _run(self, history, **kwargs):
67
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
68
  prompt = self._param.prompt
 
63
  class Generate(ComponentBase):
64
  component_name = "Generate"
65
 
66
+ def get_dependent_components(self):
67
+ cpnts = [para["component_id"] for para in self._param.parameters]
68
+ return cpnts
69
+
70
  def _run(self, history, **kwargs):
71
  chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
72
  prompt = self._param.prompt