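# NOTE: the lines "Spaces: / Running / Running" were page-header residue
# captured together with this source; kept only as a comment so the module parses.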
import copy
import importlib
import json
import logging
logger = logging.getLogger(__name__)


class Stage:
    """One stage of a pathway: a turn-based conversation between player and LLM.

    The stage is configured from ``stage_json``; the keys read here are:

    * ``"ui"`` -- optional UI description. Its ``"text"`` entries, if any, are
      exposed in ``self.ui["text"]`` as ``{"role": "assistant", ...}`` messages.
    * ``"max_interactions"`` / ``"interactions_for_hint"`` -- optional
      per-stage overrides of the pathway-level limits.
    * ``"prompt_file1"``, ``"prompt_file2"``, ... -- files concatenated, in
      order, into the system prompt (numbering stops at the first gap).
    * ``"structured_output_schema"`` (takes precedence) or ``"functions_list"``
      -- selects how the LLM is called. Either one also requires
      ``"functions_definition"``, a module name passed to
      ``importlib.import_module``.
    """

    def __init__(self, stage_json, ai_llm, pathway_parameters=None):
        """Build the stage from its JSON description.

        Args:
            stage_json: Parsed stage description (keys listed on the class).
            ai_llm: Object exposing ``talk_to_LLM`` and ``get_structured_output``.
            pathway_parameters: Optional mapping with default
                ``"max_interactions"`` / ``"interactions_for_hint"`` values;
                the same keys in ``stage_json`` take precedence.

        Raises:
            OSError: A prompt file or the functions-list file cannot be opened.
            json.JSONDecodeError: The functions-list file is not valid JSON.
            KeyError: A schema or functions list is configured without
                ``"functions_definition"``.
            ImportError: The ``functions_definition`` module cannot be imported.
        """
        self.stage_json = stage_json
        self.ai_llm = ai_llm

        # Deep copy so later edits to self.ui never leak back into stage_json.
        self.ui = copy.deepcopy(stage_json.get("ui", {}))
        # Expose the configured UI text as an LLM-style message list.
        self.ui["text"] = [
            {"role": "assistant", "content": msg}
            for msg in self.ui.get("text", [])
        ]

        # Parameters this stage wants to pass on to subsequent stages.
        self.stage_parameters = {}

        # Precedence: stage_json > pathway_parameters > built-in defaults.
        defaults = pathway_parameters or {}
        # Progress-tracking values. The dict is handed to the LLM helper and the
        # hook functions; nothing in this class sets "stage_complete" or
        # "next_stage", so presumably those collaborators do -- confirm.
        self.stage_key_values = {
            "stage_complete": False,
            "next_stage": None,
            "max_interactions": stage_json.get(
                "max_interactions", defaults.get("max_interactions", 10)
            ),
            "interactions_for_hint": stage_json.get(
                "interactions_for_hint", defaults.get("interactions_for_hint", 3)
            ),
            # Total number of interactions within the stage.
            "interaction_count": 0,
            # Interactions counted towards the next hint (different hints at
            # different phases of the stage).
            "interaction_count_for_hint": 0,
        }

        # System prompt = prompt_file1 + prompt_file2 + ... up to the first gap.
        prompt_parts = []
        file_no = 1
        while f"prompt_file{file_no}" in stage_json:
            prompt_path = stage_json[f"prompt_file{file_no}"]
            with open(prompt_path, "r", encoding="utf-8") as prompt_file:
                prompt_parts.append(prompt_file.read())
            file_no += 1
        self.stage_prompt = "".join(prompt_parts)
        self.llm_messages = [{"role": "system", "content": self.stage_prompt}]

        self.functions_list = None
        self.functions_module = None
        self.structured_output_schema = None
        if "structured_output_schema" in stage_json:
            self.structured_output_schema = stage_json["structured_output_schema"]
            self.functions_module = importlib.import_module(stage_json["functions_definition"])
        elif "functions_list" in stage_json:
            with open(stage_json["functions_list"], "r", encoding="utf-8") as functions_file:
                self.functions_list = json.load(functions_file)
            self.functions_module = importlib.import_module(stage_json["functions_definition"])

    def next_step(self, player_input):
        """Process one player message and advance the stage.

        Both interaction counters are incremented on every call. The LLM is
        only consulted while ``stage_complete`` is false. Once
        ``interaction_count`` reaches ``max_interactions`` and the functions
        module defines ``forced_stage_exit``, its text is appended to the reply
        and the stage reports completion. Otherwise, once the hint counter
        reaches ``interactions_for_hint``, the module's ``checker_function``
        (if any) is called.

        Returns:
            ``(stage_done, response_ui, checker_ui)``:

            * ``stage_done`` -- True on a forced exit, else
              ``stage_key_values["stage_complete"]``.
            * ``response_ui`` -- the dict returned by the LLM helper, or
              ``{"text": [""]}`` when the stage was already complete.
            * ``checker_ui`` -- the ``checker_function`` result, or None.
        """
        key_values = self.stage_key_values
        key_values["interaction_count"] += 1
        key_values["interaction_count_for_hint"] += 1
        checker_ui = None
        response_ui = {"text": [""]}
        # Tracks whether this turn's assistant reply was already logged in self.ui.
        reply_logged = False

        if not key_values["stage_complete"]:
            self.llm_messages.append({"role": "user", "content": player_input})
            self.ui["text"].append({"role": "user", "content": player_input})
            if self.structured_output_schema:
                response_ui = self.ai_llm.get_structured_output(
                    self.llm_messages, self.structured_output_schema,
                    self.functions_module, key_values)
            else:
                response_ui = self.ai_llm.talk_to_LLM(
                    self.llm_messages, self.functions_list,
                    self.functions_module, key_values)
            self.ui["text"].append({"role": "assistant", "content": response_ui["text"]})
            reply_logged = True

        module_name = self.functions_module.__name__ if self.functions_module else "None"

        if key_values["interaction_count"] >= key_values["max_interactions"]:
            forced_exit_function = getattr(self.functions_module, "forced_stage_exit", None)
            if forced_exit_function:
                forced_exit_msg = forced_exit_function(stage_key_values=key_values)
                # Copy before extending: if "text" is a list, an in-place += would
                # also mutate the history entry that references it.
                combined_text = copy.copy(response_ui["text"])
                combined_text += forced_exit_msg["text"]
                response_ui["text"] = combined_text
                exit_entry = {"role": "assistant", "content": combined_text}
                # One assistant entry per turn: replace this turn's reply rather
                # than logging the same reply a second time.
                if reply_logged:
                    self.ui["text"][-1] = exit_entry
                else:
                    self.ui["text"].append(exit_entry)
                return True, response_ui, checker_ui
            logger.warning("No forced exit function found in module %s", module_name)

        # NOTE(review): interaction_count_for_hint is never reset in this class;
        # presumably checker_function resets it via stage_key_values -- confirm.
        if (key_values["interaction_count_for_hint"] >= key_values["interactions_for_hint"]
                and not key_values["stage_complete"]):
            checker_function = getattr(self.functions_module, "checker_function", None)
            if checker_function:
                checker_ui = checker_function(stage_key_values=key_values)
            else:
                logger.warning("No checker function found in module %s", module_name)

        return key_values["stage_complete"], response_ui, checker_ui

    def get_next_stage(self):
        """Return ``stage_key_values["next_stage"]`` (None until something sets it)."""
        return self.stage_key_values["next_stage"]

    def get_stage_parameters(self):
        """Return the parameters this stage passes on to subsequent stages."""
        # Previously returned the bound method itself instead of the dict.
        return self.stage_parameters

    def get_initial_ui(self):
        """Return a deep copy of the stage's raw ``"ui"`` config, or None if absent."""
        return copy.deepcopy(self.stage_json.get("ui", None))