import logging |
import json |
from abc import ABC |
from copy import deepcopy |
from functools import partial |
from agent.component import component_class |
from agent.component.base import ComponentBase |
class Canvas(ABC):
    """Executable agent workflow built from a JSON DSL.

    The DSL maps component ids to entries holding an ``obj`` spec
    (``component_name`` + ``params``) and ``downstream`` / ``upstream`` id
    lists. ``load()`` replaces every ``obj`` spec with a live component
    object. Conversation state (``history``, ``messages``, ``reference``,
    ``path``, ``answer``) lives next to the components and is written back
    into the DSL by ``__str__``.

    Example DSL::

        dsl = {
            "components": {
                "begin": {
                    "obj": {
                        "component_name": "Begin",
                        "params": {},
                    },
                    "downstream": ["answer_0"],
                    "upstream": [],
                },
                "answer_0": {
                    "obj": {
                        "component_name": "Answer",
                        "params": {}
                    },
                    "downstream": ["retrieval_0"],
                    "upstream": ["begin", "generate_0"],
                },
                "retrieval_0": {
                    "obj": {
                        "component_name": "Retrieval",
                        "params": {}
                    },
                    "downstream": ["generate_0"],
                    "upstream": ["answer_0"],
                },
                "generate_0": {
                    "obj": {
                        "component_name": "Generate",
                        "params": {}
                    },
                    "downstream": ["answer_0"],
                    "upstream": ["retrieval_0"],
                }
            },
            "history": [],
            "messages": [],
            "reference": [],
            "path": [["begin"]],
            "answer": []
        }
    """

    def __init__(self, dsl: str, tenant_id=None):
        """Parse ``dsl`` and instantiate its components.

        Args:
            dsl: JSON-encoded workflow DSL. When falsy, a built-in DSL that
                contains only a ``begin`` component is used instead.
            tenant_id: Owning tenant id, returned by ``get_tenant_id()``.
        """
        self.path = []
        self.history = []
        self.messages = []
        self.answer = []
        self.reference = []
        self.components = {}
        # NOTE(review): load() asserts that an 'Answer' component exists and
        # this fallback DSL has none, so an empty ``dsl`` currently raises
        # AssertionError -- confirm whether a default Answer node is intended.
        self.dsl = json.loads(dsl) if dsl else {
            "components": {
                "begin": {
                    "obj": {
                        "component_name": "Begin",
                        "params": {
                            "prologue": "Hi there!"
                        }
                    },
                    "downstream": [],
                    "upstream": []
                }
            },
            "history": [],
            "messages": [],
            "reference": [],
            "path": [],
            "answer": []
        }
        self._tenant_id = tenant_id
        self._embed_id = ""
        self.load()

    def load(self):
        """Instantiate every component in ``self.dsl`` and restore runtime state.

        Raises:
            AssertionError: if the DSL lacks a 'Begin' or an 'Answer'
                component (note: asserts are skipped under ``python -O``).
        """
        self.components = self.dsl["components"]

        # Validate the component set before instantiating anything.
        cpn_nms = {cpn["obj"]["component_name"] for cpn in self.components.values()}
        assert "Begin" in cpn_nms, "There have to be an 'Begin' component."
        assert "Answer" in cpn_nms, "There have to be an 'Answer' component."

        for k, cpn in self.components.items():
            name = cpn["obj"]["component_name"]
            param = component_class(name + "Param")()
            param.update(cpn["obj"]["params"])
            param.check()
            # Replace the serialized spec with the live component object.
            cpn["obj"] = component_class(name)(self, k, param)
            if cpn["obj"].component_name == "Categorize":
                # Every category target must also be a downstream edge so the
                # runner can reach it.
                for desc in param.category_description.values():
                    if desc["to"] not in cpn["downstream"]:
                        cpn["downstream"].append(desc["to"])

        # These attributes alias the lists stored in the DSL dict.
        self.path = self.dsl["path"]
        self.history = self.dsl["history"]
        self.messages = self.dsl["messages"]
        self.answer = self.dsl["answer"]
        self.reference = self.dsl["reference"]
        self._embed_id = self.dsl.get("embed_id", "")

    def __str__(self):
        """Serialize the canvas (DSL plus runtime state) to a JSON string.

        Side effect: the current runtime state is first written back into
        ``self.dsl``. Each component object is serialized through its own
        ``str()``, which must therefore produce JSON.
        """
        self.dsl["path"] = self.path
        self.dsl["history"] = self.history
        self.dsl["messages"] = self.messages
        self.dsl["answer"] = self.answer
        self.dsl["reference"] = self.reference
        self.dsl["embed_id"] = self._embed_id

        # "components" is emitted first, then the remaining DSL keys in order.
        dsl = {"components": {}}
        for key, value in self.dsl.items():
            if key == "components":
                continue
            dsl[key] = deepcopy(value)

        for cid, cpn in self.components.items():
            dsl["components"][cid] = {
                field: json.loads(str(val)) if field == "obj" else deepcopy(val)
                for field, val in cpn.items()
            }
        return json.dumps(dsl, ensure_ascii=False)

    def reset(self):
        """Clear conversation state, reset every component and the embedding id."""
        self.path = []
        self.history = []
        self.messages = []
        self.answer = []
        self.reference = []
        for cpn in self.components.values():
            cpn["obj"].reset()
        self._embed_id = ""

    def get_compnent_name(self, cid):
        """Return the display name of component ``cid``.

        Names come from the DSL's ``graph.nodes`` entries. Returns "" when the
        id is not found or the DSL has no ``graph`` section (the built-in
        default DSL has none).
        """
        graph = self.dsl.get("graph") or {}
        for n in graph.get("nodes") or []:
            if cid == n["id"]:
                return n["data"]["name"]
        return ""

    def run(self, **kwargs):
        """Advance the workflow for one turn, yielding progress and output.

        If an Answer component is already queued in ``self.answer``, it is run
        and its output is yielded (its result is called and iterated when
        ``kwargs["stream"]`` is truthy). Otherwise a new segment is appended
        to ``self.path`` and components downstream of the previous segment's
        last component are executed, yielding
        ``{"content": ..., "running_status": True}`` progress messages, until
        Answer components are reached; the first queued one is then run.

        Raises:
            OverflowError: if ``_find_loop`` detects a repeating cycle.
            AssertionError: if a switch-like component outputs an unknown id.
            Exception: if no Answer component was reached.
        """
        if self.answer:
            cpn_id = self.answer[0]
            self.answer.pop(0)
            try:
                ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
            except Exception as e:
                ans = ComponentBase.be_output(str(e))
            self.path[-1].append(cpn_id)
            if kwargs.get("stream"):
                for an in ans():
                    yield an
            else:
                yield ans
            return

        if not self.path:
            self.components["begin"]["obj"].run(self.history, **kwargs)
            self.path.append(["begin"])

        # Start a fresh path segment for this turn.
        self.path.append([])

        ran = -1  # cursor into self.path[-1]
        waiting = []  # components deferred because dependencies haven't run
        without_dependent_checking = []

        def prepare2run(cpns):
            nonlocal ran, ans
            for c in cpns:
                # Don't immediately re-run the component that just executed.
                if self.path[-1] and c == self.path[-1][-1]:
                    continue
                cpn = self.components[c]["obj"]
                if cpn.component_name == "Answer":
                    # Answer components are queued, not executed here.
                    self.answer.append(c)
                else:
                    logging.debug(f"Canvas.prepare2run: {c}")
                    if c not in without_dependent_checking:
                        cpids = cpn.get_dependent_components()
                        if any([cc not in self.path[-1] for cc in cpids]):
                            if c not in waiting:
                                waiting.append(c)
                            continue
                    yield "*'{}'* is running...🕞".format(self.get_compnent_name(c))
                    try:
                        ans = cpn.run(self.history, **kwargs)
                    except Exception as e:
                        logging.exception(f"Canvas.run got exception: {e}")
                        self.path[-1].append(c)
                        ran += 1
                        raise e
                    self.path[-1].append(c)
            # Advance the cursor once per call.
            ran += 1

        for m in prepare2run(self.components[self.path[-2][-1]]["downstream"]):
            yield {"content": m, "running_status": True}

        while 0 <= ran < len(self.path[-1]):
            logging.debug(f"Canvas.run: {ran} {self.path}")
            cpn_id = self.path[-1][ran]
            cpn = self.get_component(cpn_id)
            if not cpn["downstream"]:
                break

            loop = self._find_loop()
            if loop:
                raise OverflowError(f"Too much loops: {loop}")

            # Branching components pick a single successor from their output.
            if cpn["obj"].component_name.lower() in ["switch", "categorize", "relevant"]:
                switch_out = cpn["obj"].output()[1].iloc[0, 0]
                assert switch_out in self.components, \
                    "{}'s output: {} not valid.".format(cpn_id, switch_out)
                for m in prepare2run([switch_out]):
                    yield {"content": m, "running_status": True}
                continue

            for m in prepare2run(cpn["downstream"]):
                yield {"content": m, "running_status": True}

            # Cursor reached the end with components still deferred: run them
            # without dependency checks, then undo this extra cursor advance.
            if ran >= len(self.path[-1]) and waiting:
                without_dependent_checking = waiting
                waiting = []
                for m in prepare2run(without_dependent_checking):
                    yield {"content": m, "running_status": True}
                ran -= 1

        if self.answer:
            cpn_id = self.answer[0]
            self.answer.pop(0)
            ans = self.components[cpn_id]["obj"].run(self.history, **kwargs)
            self.path[-1].append(cpn_id)
            if kwargs.get("stream"):
                assert isinstance(ans, partial)
                for an in ans():
                    yield an
            else:
                yield ans

        else:
            raise Exception("The dialog flow has no way to interact with you. Please add an 'Interact' component to the end of the flow.")

    def get_component(self, cpn_id):
        """Return the DSL entry (obj + edges) for ``cpn_id``; KeyError if unknown."""
        return self.components[cpn_id]

    def get_tenant_id(self):
        """Return the tenant id given to the constructor."""
        return self._tenant_id

    def get_history(self, window_size):
        """Return the last ``window_size`` history entries as role/content dicts.

        Entries whose payload is a non-empty list of dicts are flattened by
        joining each dict's ``content`` with newlines; anything else is
        converted with ``str()``. A ``window_size`` of 0 or less yields an
        empty list (previously 0 returned the whole history, because
        ``history[-0:]`` is the full list).
        """
        convs = []
        if window_size <= 0:
            return convs
        for role, obj in self.history[-window_size:]:
            if isinstance(obj, list) and obj and all(isinstance(o, dict) for o in obj):
                convs.append({"role": role, "content": "\n".join(str(s.get("content", "")) for s in obj)})
            else:
                convs.append({"role": role, "content": str(obj)})
        return convs

    def add_user_input(self, question):
        """Append a user turn to the conversation history."""
        self.history.append(("user", question))

    def set_embedding_model(self, embed_id):
        """Set the embedding model id (serialized as ``embed_id``)."""
        self._embed_id = embed_id

    def get_embedding_model(self):
        """Return the embedding model id ("" when unset)."""
        return self._embed_id

    def _find_loop(self, max_loops=6):
        """Detect a component sequence repeating at the end of the current path.

        Only the part of the current segment after the most recent component
        whose id contains "answer" is examined. Returns a description such as
        ``"a => b => a => b"`` when a prefix pattern repeats more than
        ``max_loops`` times, otherwise ``False``.
        """
        path = self.path[-1][::-1]
        if len(path) < 2:
            return False

        # Ignore everything before (in execution order) the latest answer step.
        for i, step in enumerate(path):
            if "answer" in step.lower():
                path = path[:i]
                break

        if len(path) < 2:
            return False

        for loc in range(2, len(path) // 2):
            pat = ",".join(path[0:loc])
            path_str = ",".join(path)
            if len(pat) >= len(path_str):
                return False
            loop = max_loops
            # Strip consecutive copies of the pattern off the front.
            while path_str.find(pat) == 0 and loop >= 0:
                loop -= 1
                if len(pat) + 1 >= len(path_str):
                    return False
                path_str = path_str[len(pat) + 1:]
            if loop < 0:
                pat = " => ".join([p.split(":")[0] for p in path[0:loc]])
                return pat + " => " + pat

        return False

    def get_prologue(self):
        """Return the ``prologue`` parameter of the 'begin' component."""
        return self.components["begin"]["obj"]._param.prologue

    def set_global_param(self, **kwargs):
        """Set ``value`` on every 'begin' query entry whose ``key`` matches a kwarg.

        Keywords with no matching query entry are ignored.
        """
        for k, v in kwargs.items():
            for q in self.components["begin"]["obj"]._param.query:
                if k != q["key"]:
                    continue
                q["value"] = v

    def get_preset_param(self):
        """Return the 'begin' component's ``query`` parameter list."""
        return self.components["begin"]["obj"]._param.query

    def get_component_input_elements(self, cpnnm):
        """Return ``get_input_elements()`` of the component with id ``cpnnm``."""
        return self.components[cpnnm]["obj"].get_input_elements()