File size: 10,431 Bytes
a3ebd45 1b1a5b7 95863fc a3ebd45 6101699 6054f54 a3ebd45 255441a a3ebd45 255441a a3ebd45 255441a a3ebd45 255441a 0404a52 255441a a3ebd45 f586a68 d25ba26 f586a68 87d8c78 2cbcd9d ce5e150 6101699 87d8c78 0404a52 87d8c78 28e0efa 87d8c78 1b1a5b7 87d8c78 1b1a5b7 a004a85 1b1a5b7 a004a85 1b1a5b7 a3ebd45 70153b9 255441a 0404a52 d25ba26 58f507b d25ba26 58f507b d25ba26 6229599 f2439de 6229599 255441a a68b28a f2439de a68b28a 70153b9 d25ba26 70153b9 0404a52 a3ebd45 4db2410 a3ebd45 70153b9 b5e86a6 87d8c78 a3ebd45 fc6e68a a57190d a3ebd45 95863fc 0404a52 95863fc 0404a52 95863fc a3ebd45 a57190d a3ebd45 fc6e68a a3ebd45 95863fc 0404a52 95863fc 0404a52 a3ebd45 95863fc a3ebd45 87d8c78 a3ebd45 8db6538 a004a85 8db6538 a004a85 8db6538 a004a85 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 |
# Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from functools import partial
import pandas as pd
from api.db import LLMType
from import structure_answer
from import message_fit_in
from import LLMBundle
from api import settings
from agent.component.base import ComponentBase, ComponentParamBase
class GenerateParam(ComponentParamBase):
    """
    Define the Generate component parameters.

    Attributes set here:
        llm_id: identifier of the chat model to use (must be non-empty).
        prompt: prompt template; ``{name}`` placeholders are filled at run time.
        max_tokens, temperature, top_p, presence_penalty, frequency_penalty:
            sampling options; a value of 0 means "not set" and is left out of
            the generation config.
        cite: whether citations should be inserted into the answer.
        parameters: list of dicts describing the inputs wired into the prompt.
    """

    # Sampling options forwarded to the chat model, in the order they are
    # emitted by gen_conf().
    _GEN_CONF_KEYS = (
        "max_tokens",
        "temperature",
        "top_p",
        "presence_penalty",
        "frequency_penalty",
    )

    def __init__(self):
        # The Generate component reads attributes that are not assigned below
        # (e.g. `message_history_window_size`, `debug_inputs`); they are
        # presumably provided by the base initializer, so it must run first.
        super().__init__()
        self.llm_id = ""
        self.prompt = ""
        self.max_tokens = 0
        self.temperature = 0
        self.top_p = 0
        self.presence_penalty = 0
        self.frequency_penalty = 0
        self.cite = True
        self.parameters = []

    def check(self):
        """Validate the parameters with the checkers inherited from the base class."""
        self.check_decimal_float(self.temperature, "[Generate] Temperature")
        self.check_decimal_float(self.presence_penalty, "[Generate] Presence penalty")
        self.check_decimal_float(self.frequency_penalty, "[Generate] Frequency penalty")
        self.check_nonnegative_number(self.max_tokens, "[Generate] Max tokens")
        self.check_decimal_float(self.top_p, "[Generate] Top P")
        self.check_empty(self.llm_id, "[Generate] LLM")
        # self.check_defined_type(self.parameters, "Parameters", ["list"])

    def gen_conf(self):
        """Return the generation config as a dict.

        Only options with a value strictly greater than 0 are included, so
        unset options are not sent with the request.
        """
        conf = {}
        for key in self._GEN_CONF_KEYS:
            value = getattr(self, key)
            if value > 0:
                conf[key] = value
        return conf
class Generate(ComponentBase):
    """Canvas component that fills the prompt template and queries the chat model.

    The prompt placeholders are filled from the outputs of the components
    listed in ``self._param.parameters``. The answer is returned either as a
    DataFrame or, when streaming to a single downstream "answer" component,
    as a generator of partial answers.
    """

    component_name = "Generate"

    def get_dependent_components(self):
        """Return the ids of the components this component takes input from.

        Parameters without a ``component_id`` are ignored, as are ids that
        contain "answer" or "begin". For ids of the form ``cpn@key`` only the
        part before ``@`` is kept. The order of the returned list is
        unspecified because duplicates are removed through a set.
        """
        cpnts = {
            para["component_id"].split("@")[0]
            for para in self._param.parameters
            if para.get("component_id")
            and "answer" not in para["component_id"].lower()
            and "begin" not in para["component_id"].lower()
        }
        return list(cpnts)

    def set_cite(self, retrieval_res, answer):
        """Insert citations into *answer* and attach the referenced chunks.

        Args:
            retrieval_res: DataFrame of retrieved chunks; the code reads the
                columns ``vector``, ``content_ltks``, ``doc_id`` and
                ``docnm_kwd`` (and ``empty_response`` when present).
            answer: answer text produced by the chat model.

        Returns:
            Whatever ``structure_answer`` returns for the dict
            ``{"content": ..., "reference": ...}`` built here.
        """
        # dropna() returns a new frame, so the caller's DataFrame is not
        # touched by the column edits below.
        retrieval_res = retrieval_res.dropna(subset=["vector", "content_ltks"]).reset_index(drop=True)
        if "empty_response" in retrieval_res.columns:
            # Assign back instead of a chained `inplace=True`, which pandas
            # deprecates and which is a no-op under copy-on-write.
            retrieval_res["empty_response"] = retrieval_res["empty_response"].fillna("")

        embd_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.EMBEDDING,
                             self._canvas.get_embedding_model())
        # NOTE(review): the argument after `tkweight=0.7` is missing from the
        # available source text; `vtweight=0.3` (weights summing to 1) is the
        # presumed original value - confirm against insert_citations().
        answer, idx = settings.retrievaler.insert_citations(
            answer,
            retrieval_res["content_ltks"].tolist(),
            retrieval_res["vector"].tolist(),
            embd_mdl,
            tkweight=0.7,
            vtweight=0.3,
        )

        # One aggregate entry per cited document, in first-citation order.
        seen_doc_ids = set()
        recall_docs = []
        for i in idx:
            did = retrieval_res.loc[int(i), "doc_id"]
            if did in seen_doc_ids:
                continue
            seen_doc_ids.add(did)
            recall_docs.append({"doc_id": did, "doc_name": retrieval_res.loc[int(i), "docnm_kwd"]})

        # These two columns are dropped before the chunks are put in the reference.
        del retrieval_res["vector"]
        del retrieval_res["content_ltks"]

        reference = {
            "chunks": [ck.to_dict() for _, ck in retrieval_res.iterrows()],
            "doc_aggs": recall_docs,
        }

        lowered = answer.lower()
        if "invalid key" in lowered or "invalid api" in lowered:
            answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
        res = {"content": answer, "reference": reference}
        res = structure_answer(None, res, "", "")
        return res

    def get_input_elements(self):
        """Return the input descriptors: the fixed "user" entry plus the configured parameters."""
        elements = [{"key": "user", "name": "User"}]
        if self._param.parameters:
            elements.extend(self._param.parameters)
        return elements

    @staticmethod
    def _fill_prompt(prompt, values):
        """Replace every ``{name}`` placeholder in *prompt* with ``str(values[name])``."""
        for name, value in values.items():
            # Backslashes are replaced with spaces so that re.sub() cannot
            # interpret the value as containing escape sequences.
            prompt = re.sub(r"\{%s\}" % re.escape(name), str(value).replace("\\", " "), prompt)
        return prompt

    @staticmethod
    def _empty_response(retrieval_res):
        """Return the canned "nothing retrieved" answer, or None when content exists."""
        if "empty_response" not in retrieval_res.columns or "".join(retrieval_res["content"]):
            return None
        configured = "\n- ".join(retrieval_res["empty_response"])
        return {"content": configured or "Nothing found in knowledgebase!", "reference": []}

    def _build_messages(self, chat_mdl, prompt):
        """Return ``[system, *history]`` fitted into 97% of the model's max length.

        A blank user message is appended whenever the list would otherwise
        hold no message after the system prompt.
        """
        msg = self._canvas.get_history(self._param.message_history_window_size)
        if len(msg) < 1:
            msg.append({"role": "user", "content": ""})
        _, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
        if len(msg) < 2:
            msg.append({"role": "user", "content": ""})
        return msg

    def _gather_inputs(self, kwargs):
        """Resolve every configured parameter into *kwargs* (mutated in place).

        Also records what was resolved in ``self._param.inputs``.

        Returns:
            list of the output DataFrames of upstream "retrieval" components.

        Raises:
            AssertionError: a ``cpn@key`` reference names a key that the
                target component's ``query`` does not contain.
        """
        retrieval_frames = []
        for para in self._param.parameters:
            ref = para.get("component_id")
            if not ref:
                continue

            if "@" in ref:
                # "cpn_id@key": take the value from that component's query parameters.
                cpn_id, key = ref.split("@")
                for p in self._canvas.get_component(cpn_id)["obj"]._param.query:
                    if p["key"] == key:
                        kwargs[para["key"]] = p.get("value", "")
                        self._param.inputs.append(
                            {"component_id": ref, "content": kwargs[para["key"]]})
                        break
                else:
                    # Raised explicitly (rather than `assert False`) so the
                    # check is not stripped when Python runs with -O.
                    raise AssertionError(f"Can't find parameter '{key}' for {cpn_id}")
                continue

            cpn = self._canvas.get_component(ref)["obj"]
            if cpn.component_name.lower() == "answer":
                # An "answer" component contributes the latest history entry.
                hist = self._canvas.get_history(1)
                kwargs[para["key"]] = hist[0]["content"] if hist else ""
                continue

            _, out = cpn.output(allow_partial=False)
            if "content" not in out.columns:
                kwargs[para["key"]] = ""
            else:
                if cpn.component_name.lower() == "retrieval":
                    # NOTE(review): this append is not in the available source
                    # text; it is the only way the later pd.concat() can
                    # receive data - confirm.
                    retrieval_frames.append(out)
                kwargs[para["key"]] = " - " + "\n - ".join([str(o) for o in out["content"]])
            self._param.inputs.append({"component_id": ref, "content": kwargs[para["key"]]})
        return retrieval_frames

    def _run(self, history, **kwargs):
        """Build the prompt from upstream outputs and produce the answer.

        Returns:
            * a ``functools.partial`` over ``stream_output`` when
              ``kwargs["stream"]`` is truthy and the only downstream component
              is an "answer" component;
            * otherwise a DataFrame with the answer (with citations when
              enabled and the retrieval result has ``content_ltks`` and
              ``vector`` columns), or the result of ``Generate.be_output``.
        """
        chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
        prompt = self._param.prompt

        self._param.inputs = []
        retrieval_frames = self._gather_inputs(kwargs)
        if retrieval_frames:
            retrieval_res = pd.concat(retrieval_frames, ignore_index=True)
        else:
            retrieval_res = pd.DataFrame([])

        prompt = self._fill_prompt(prompt, kwargs)

        if not self._param.inputs and prompt.find("{input}") >= 0:
            retrieval_res = self.get_input()
            if "content" in retrieval_res:
                input_text = " - " + "\n - ".join(
                    [c for c in retrieval_res["content"] if isinstance(c, str)])
            else:
                input_text = ""
            # Plain string replacement: the text is inserted verbatim.
            # re.escape() is meant for patterns; used on a replacement string
            # it leaves literal backslashes in the prompt.
            prompt = prompt.replace("{input}", input_text)

        downstreams = self._canvas.get_component(self._id)["downstream"]
        if (kwargs.get("stream") and len(downstreams) == 1
                and self._canvas.get_component(downstreams[0])["obj"].component_name.lower() == "answer"):
            return partial(self.stream_output, chat_mdl, prompt, retrieval_res)

        empty = self._empty_response(retrieval_res)
        if empty is not None:
            return pd.DataFrame([empty])

        msg = self._build_messages(chat_mdl, prompt)
        ans = chat_mdl.chat(msg[0]["content"], msg[1:], self._param.gen_conf())

        if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
            res = self.set_cite(retrieval_res, ans)
            return pd.DataFrame([res])

        return Generate.be_output(ans)

    def stream_output(self, chat_mdl, prompt, retrieval_res):
        """Yield ``{"content", "reference"}`` dicts while the model streams.

        Each yielded ``content`` is the value produced by ``chat_streamly``
        for that step. When citations apply, one more item carrying the cited
        answer is yielded at the end.
        """
        res = self._empty_response(retrieval_res)
        if res is not None:
            yield res
            # NOTE(review): the set_output() calls in this method are not in
            # the available source text; they are presumed to record the
            # final result on the component - confirm against ComponentBase.
            self.set_output(res)
            # Nothing was retrieved: do not call the model.
            return

        msg = self._build_messages(chat_mdl, prompt)
        answer = ""
        for ans in chat_mdl.chat_streamly(msg[0]["content"], msg[1:], self._param.gen_conf()):
            res = {"content": ans, "reference": []}
            answer = ans
            yield res

        if self._param.cite and "content_ltks" in retrieval_res.columns and "vector" in retrieval_res.columns:
            res = self.set_cite(retrieval_res, answer)
            yield res

        self.set_output(Generate.be_output(res))

    def debug(self, **kwargs):
        """Run the prompt once with the values from ``self._param.debug_inputs``.

        Returns:
            A one-row DataFrame built from the chat model's answer.
        """
        chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
        prompt = self._param.prompt

        for para in self._param.debug_inputs:
            kwargs[para["key"]] = para.get("value", "")
        prompt = self._fill_prompt(prompt, kwargs)

        user_msg = [{"role": "user", "content": kwargs.get("user", "")}]
        ans = chat_mdl.chat(prompt, user_msg, self._param.gen_conf())
        return pd.DataFrame([ans])