vJul / stage.py
deepaksj's picture
Upload 23 files
9eafbe3 verified
import json
import importlib
import copy
class Stage:
    """One stage of a conversational pathway driven by an LLM.

    The stage is configured by ``stage_json``. Keys read by this class:

    * ``ui`` (optional): initial UI payload. Its ``text`` entries are
      converted into assistant messages stored in ``self.ui["text"]``.
    * ``max_interactions`` / ``interactions_for_hint`` (optional): override
      the values taken from ``pathway_parameters`` (defaults 10 and 3).
    * ``prompt_file1``, ``prompt_file2``, ... (optional): files whose contents
      are concatenated, in order, into the system prompt. Reading stops at
      the first missing number.
    * ``structured_output_schema`` or ``functions_list`` (optional), each
      used together with ``functions_definition``, the name of the module
      imported with ``importlib.import_module``. That module may provide
      ``forced_stage_exit`` and ``checker_function`` hooks.
    """

    def __init__(self, stage_json, ai_llm, pathway_parameters=None):
        """Build the stage from its JSON description.

        Args:
            stage_json: mapping describing the stage (see class docstring).
            ai_llm: object used in ``next_step``. It must provide
                ``talk_to_LLM(messages, functions_list, functions_module,
                stage_key_values)`` and, when a structured output schema is
                configured, ``get_structured_output(messages, schema,
                functions_module, stage_key_values)``.
            pathway_parameters: optional mapping supplying pathway-wide
                ``max_interactions`` / ``interactions_for_hint`` defaults.

        Raises:
            OSError: if a prompt file or the functions-list file can't be read.
            ImportError: if ``functions_definition`` can't be imported.
        """
        self.stage_json = stage_json
        self.ai_llm = ai_llm
        self.ui = self._build_initial_ui(stage_json)
        # Any parameters that the stage wants to pass on to subsequent stages
        self.stage_parameters = {}
        # Key values that the stage (and its functions module) use to track progress
        self.stage_key_values = self._build_key_values(pathway_parameters)
        self.stage_prompt = self._read_prompt_files()
        self.llm_messages = [{"role": "system", "content": self.stage_prompt}]
        self.functions_list = None
        self.functions_module = None
        self.structured_output_schema = None
        self._load_functions()

    @staticmethod
    def _build_initial_ui(stage_json):
        """Return a deep copy of ``stage_json["ui"]`` with ``text`` as assistant messages."""
        ui = copy.deepcopy(stage_json["ui"]) if "ui" in stage_json else {}
        # Text messages are kept in LLM message-list form so the transcript
        # can be extended with user/assistant turns later.
        ui["text"] = [
            {"role": "assistant", "content": msg} for msg in ui.get("text", [])
        ]
        return ui

    def _build_key_values(self, pathway_parameters):
        """Return the progress dict; ``stage_json`` settings beat pathway ones."""
        pathway = pathway_parameters or {}
        return {
            "stage_complete": False,
            "next_stage": None,
            "max_interactions": self.stage_json.get(
                "max_interactions", pathway.get("max_interactions", 10)
            ),
            "interactions_for_hint": self.stage_json.get(
                "interactions_for_hint", pathway.get("interactions_for_hint", 3)
            ),
            # Total number of interactions within the stage
            "interaction_count": 0,
            # Interactions counted toward the next hint (different hints at
            # different phases of the stage). Not reset in this class;
            # presumably the functions module resets it — confirm.
            "interaction_count_for_hint": 0,
        }

    def _read_prompt_files(self):
        """Concatenate ``prompt_file1``, ``prompt_file2``, ... until one is missing."""
        parts = []
        file_no = 1
        while "prompt_file" + str(file_no) in self.stage_json:
            file_name = self.stage_json["prompt_file" + str(file_no)]
            with open(file_name, "r", encoding="utf-8") as prompt_file:
                parts.append(prompt_file.read())
            file_no += 1
        return "".join(parts)

    def _load_functions(self):
        """Load the structured-output schema or function list and import its module."""
        if "structured_output_schema" in self.stage_json:
            self.structured_output_schema = self.stage_json["structured_output_schema"]
            self.functions_module = importlib.import_module(self.stage_json["functions_definition"])
        elif "functions_list" in self.stage_json:
            with open(self.stage_json["functions_list"], "r", encoding="utf-8") as functions_file:
                self.functions_list = json.load(functions_file)
            self.functions_module = importlib.import_module(self.stage_json["functions_definition"])

    def _functions_module_name(self):
        """Name of the functions module for diagnostics, or the string ``"None"``."""
        return self.functions_module.__name__ if self.functions_module else "None"

    def _ask_llm(self):
        """Send the current message list to the LLM and return its UI response dict."""
        if self.structured_output_schema:
            return self.ai_llm.get_structured_output(
                self.llm_messages,
                self.structured_output_schema,
                self.functions_module,
                self.stage_key_values,
            )
        return self.ai_llm.talk_to_LLM(
            self.llm_messages,
            self.functions_list,
            self.functions_module,
            self.stage_key_values,
        )

    def next_step(self, player_input):
        """Process one player turn.

        The LLM is only consulted while the stage is not yet complete. When
        the interaction limit is reached, the module's ``forced_stage_exit``
        hook (if present) is called and its text is appended to the response.
        Once enough interactions have accumulated for a hint, and the stage
        is still incomplete, the module's ``checker_function`` hook (if
        present) is called.

        Args:
            player_input: the player's message for this turn.

        Returns:
            ``(stage_complete, response_ui, checker_ui)``. ``stage_complete``
            is ``True`` unconditionally after a forced exit. ``checker_ui`` is
            the checker hook's return value, or ``None``.
        """
        key_values = self.stage_key_values
        key_values["interaction_count"] += 1
        key_values["interaction_count_for_hint"] += 1
        checker_ui = None
        # Placeholder response used when the stage is already complete.
        response_ui = {"text": [""]}
        # Transcript entry holding this turn's LLM reply, if one was produced.
        reply_entry = None
        if not key_values["stage_complete"]:
            self.llm_messages.append({"role": "user", "content": player_input})
            self.ui["text"].append({"role": "user", "content": player_input})
            response_ui = self._ask_llm()
            reply_entry = {"role": "assistant", "content": response_ui["text"]}
            self.ui["text"].append(reply_entry)

        if key_values["interaction_count"] >= key_values["max_interactions"]:
            forced_exit_function = getattr(self.functions_module, "forced_stage_exit", None)
            if forced_exit_function:
                forced_exit_msg = forced_exit_function(stage_key_values=key_values)
                response_ui["text"] += forced_exit_msg["text"]
                if reply_entry is None:
                    self.ui["text"].append({"role": "assistant", "content": response_ui["text"]})
                else:
                    # Update this turn's reply instead of recording it twice.
                    reply_entry["content"] = response_ui["text"]
                return True, response_ui, checker_ui
            print("No forced exit function found in module", self._functions_module_name())

        hint_due = key_values["interaction_count_for_hint"] >= key_values["interactions_for_hint"]
        if hint_due and not key_values["stage_complete"]:
            checker_function = getattr(self.functions_module, "checker_function", None)
            if checker_function:
                checker_ui = checker_function(stage_key_values=key_values)
            else:
                print("No checker function found in module", self._functions_module_name())
        return key_values["stage_complete"], response_ui, checker_ui

    def get_next_stage(self):
        """Return the ``next_stage`` value recorded in the stage key values."""
        return self.stage_key_values["next_stage"]

    def get_stage_parameters(self):
        """Return the parameters this stage passes on to subsequent stages."""
        return self.stage_parameters

    def get_initial_ui(self):
        """Return a deep copy of the raw ``ui`` entry of ``stage_json``, or ``None``."""
        return copy.deepcopy(self.stage_json.get("ui", None))