import os
import json
import re
import uuid
from typing import TypedDict, Annotated, Optional, Literal, List, Tuple

import gradio as gr
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph, END
from langgraph.graph.message import add_messages
from langgraph.types import Command

# -----------------------------------------------------------------------------
# 🔑 OpenAI configuration
# -----------------------------------------------------------------------------
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
if not OPENAI_API_KEY:
    raise EnvironmentError("OPENAI_API_KEY must be set in your environment.")

# NOTE(review): the client is created at import time, so merely importing this
# module requires a valid key in the environment.
llm = ChatOpenAI(model="gpt-4.1", temperature=0.5)

# -----------------------------------------------------------------------------
# 🗃️ Conversation state definition
# -----------------------------------------------------------------------------
class AssistantState(TypedDict, total=False):
    """Global state tracked by the LangGraph.

    All keys are optional (``total=False``); the LLM proposes mutations via a
    trailing ``STATE_UPDATE: {...}`` tag that is parsed out of its reply.
    """

    # Chat transcript; ``add_messages`` makes LangGraph append rather than replace.
    messages: Annotated[List, add_messages]
    # Whether the user wants one exact article or a filtered browse.
    search_mode: Optional[Literal["specific", "filter", "not_set"]]
    article_id: Optional[str]
    search_type: Optional[Literal["keyword", "visual+keyword", "visual"]]
    # Free-text fashion item the user settled on (filter mode).
    search_item: Optional[str]
    size: Optional[Literal["L", "M", "S", "XS"]]
    color: Optional[Literal["red", "blue", "white", "pink"]]
    gender: Optional[Literal["Male", "Female"]]
    conversation_stage: Literal["initial", "collecting_filters", "completed"]
    human_done: bool
    thread_id: str  # 🔑 keep each reset isolated


# -----------------------------------------------------------------------------
# 📝 Prompt helpers
# -----------------------------------------------------------------------------
# NOTE: this template is formatted twice — once by ``str.format`` in
# ``build_system_prompt`` and once by ``ChatPromptTemplate`` — hence the
# quadruple braces around the STATE_UPDATE example.
SYSTEM_TEMPLATE = """You are a helpful product-search assistant. Your job is to help users find products by either searching for a specific item or by applying filters.

CONVERSATION FLOW:
1. **Initial stage** – ask whether the user wants a specific item or wishes to browse with filters. If user says he wants specific item select "keyword" search type.
2. **Specific-item mode** – if the user wants a specific item, only ask for the `article_id`.
3. **Filter search type** - Ask user if he knows what item he wants to browse. If he wants to browse an item like for, e.g. linen shirts, black trousers etc. then choose "visual+keyword" Else if he says he wants something for wedding or for halloween etc. i.e. he is more ambiguous about fashion item, then choose "visual". Also help him narrow down exact fashion item in that case. Once he says this is the item he wants then help him fill other filters and populate search_item with that fashion item.
4. **Filter mode** – otherwise, collect these filters **one at a time**:
   • `size` → L | M | S | XS
   • `color` → red | blue | white | pink
   • `gender` → Male | Female
5. **Completion** – the user can say "done" or "finished" at any point; leave any unanswered filters `null`.

CURRENT STATE (for your reference):
- Search mode: {search_mode}
- Conversation stage: {conversation_stage}
- Search type: {search_type}
- Article ID: {article_id}
- Size: {size}
- Color: {color}
- Gender: {gender}

INSTRUCTIONS:
• Stay conversational and helpful.
• Do **not** request `article_id` while in filter mode.
• Validate all user inputs against the allowed values above; reprompt if invalid.
• When the user says they're finished, summarise what they asked for and stop asking further questions.

IMPORTANT – State hand-off:
Along with every reply, append a tag of the form:
STATE_UPDATE: {{{{"field": "value", ...}}}}
Only include keys that need changing; an empty JSON object means no updates.
Valid keys: search_mode, conversation_stage, search_type, search_item, article_id, size, color, gender.

Make sure once required fields are obtained, you ask user if he is done """


def build_system_prompt(state: "AssistantState") -> str:
    """Fill the template with the live state values."""
    return SYSTEM_TEMPLATE.format(
        search_mode=state.get("search_mode", "not_set"),
        conversation_stage=state.get("conversation_stage", "initial"),
        search_type=state.get("search_type", "null"),
        article_id=state.get("article_id", "null"),
        size=state.get("size", "null"),
        color=state.get("color", "null"),
        gender=state.get("gender", "null"),
    )


# -----------------------------------------------------------------------------
# 🔧 Utility helpers
# -----------------------------------------------------------------------------
def extract_state_updates(msg: str) -> dict:
    """Pull the JSON blob from the assistant's reply.

    Returns an empty dict when the tag is absent, malformed, or not a JSON
    object. Uses ``raw_decode`` so nested objects (e.g. ``{"a": {"b": 1}}``)
    are parsed correctly — a naive first-``}`` scan would truncate them.
    """
    if "STATE_UPDATE:" not in msg:
        return {}
    try:
        json_part = msg.split("STATE_UPDATE:", 1)[1].strip()
        first = json_part.find("{")
        if first == -1:
            return {}
        data, _ = json.JSONDecoder().raw_decode(json_part[first:])
        return data if isinstance(data, dict) else {}
    except Exception:
        # Best-effort parsing: a bad tag should never crash the turn.
        return {}


def strip_state_update(msg: str) -> str:
    """Remove the STATE_UPDATE section so the user never sees it."""
    return msg.split("STATE_UPDATE:")[0].rstrip()


# -----------------------------------------------------------------------------
# 🤖 LangGraph node implementations
# -----------------------------------------------------------------------------
def assistant_node(state: "AssistantState") -> Command[Literal["human_check"]]:
    """LLM turn that responds and proposes state mutations."""
    prompt = ChatPromptTemplate.from_messages([
        ("system", build_system_prompt(state)),
        ("placeholder", "{messages}"),
    ])
    rendered = prompt.invoke({"messages": state["messages"]})
    response = llm.invoke(rendered)

    updates = extract_state_updates(response.content)
    assistant_msg = AIMessage(content=strip_state_update(response.content))
    # Merge the new message with any state keys the model asked to change.
    combined_update = {"messages": [assistant_msg], **updates}
    return Command(update=combined_update, goto="human_check")


# Whole-word match so e.g. "abandoned" or "Done!" behaves correctly; the old
# substring test (`"done" in text`) fired on unrelated words.
_TERMINATION_RE = re.compile(r"\b(done|finished)\b", re.IGNORECASE)


def human_check_node(state: "AssistantState") -> Command:
    """Check the latest human message for a termination phrase, then END."""
    last_human = next(
        (m.content for m in reversed(state["messages"]) if isinstance(m, HumanMessage)),
        "",
    )

    # If the user signalled they're finished
    if _TERMINATION_RE.search(last_human):
        return Command(
            update={"human_done": True, "conversation_stage": "completed"},
            goto=END,
        )

    # Otherwise simply pause and wait for the next user turn
    return Command(goto=END)


# -----------------------------------------------------------------------------
# 🔗 Graph wiring
# -----------------------------------------------------------------------------
graph = StateGraph(AssistantState)
graph.add_node("assistant", assistant_node)
graph.add_node("human_check", human_check_node)

# assistant → human_check
graph.add_edge("assistant", "human_check")
# No conditional edge back to assistant – human_check always ENDs

# Entry point
graph.set_entry_point("assistant")

memory = MemorySaver()
compiled_graph = graph.compile(checkpointer=memory)


# -----------------------------------------------------------------------------
# 🖼️ Helper to show state nicely
# -----------------------------------------------------------------------------
def format_state(state: "AssistantState") -> str:
    """Render the current state as Markdown for the sidebar panel."""
    lines = [f"**Search mode:** {state.get('search_mode', 'not set')}"]
    lines.append(f"**Stage:** {state.get('conversation_stage', 'initial')}")

    if state.get("search_mode") == "specific":
        lines.append(f"**Article ID:** {state.get('article_id', 'not specified')}")
    elif state.get("search_mode") == "filter":
        filters = [
            ("Search type", state.get("search_type")),
            # fix: search_item was tracked in state but never surfaced to the user
            ("Search item", state.get("search_item")),
            ("Size", state.get("size")),
            ("Color", state.get("color")),
            ("Gender", state.get("gender")),
        ]
        lines.append("**Filters:**")
        for label, value in filters:
            lines.append(f"- {label}: {value or 'not specified'}")

    return "\n".join(lines)
# -----------------------------------------------------------------------------
# 🔄 Session helpers
# -----------------------------------------------------------------------------
def _fresh_state() -> AssistantState:
    """Return a brand-new state dict with a unique thread id."""
    return {
        "messages": [],
        "search_mode": "not_set",
        "conversation_stage": "initial",
        "thread_id": str(uuid.uuid4()),
    }


def reset_session() -> Tuple[List[Tuple[str, str]], AssistantState, str]:
    """Start a fresh conversation.

    Runs one graph super-step to obtain the assistant greeting and returns
    ``(chat history, state, sidebar markdown)`` for the Gradio outputs.
    """
    state = _fresh_state()
    # fix: default greeting — previously a stream with no message-bearing
    # event left `greeting` unbound and raised NameError below.
    greeting = ""

    # One super-step to get the assistant greeting
    for event in compiled_graph.stream(
        state,
        {"configurable": {"thread_id": state["thread_id"]}},
        stream_mode="values",
    ):
        state.update({k: v for k, v in event.items() if k != "messages"})
        if "messages" in event and event["messages"]:
            greeting = event["messages"][-1].content
            break

    history = [(None, greeting)]
    return history, state, format_state(state)


def _json_safe_state(state: AssistantState) -> dict:
    """Return a JSON-serialisable snapshot of the state (strip Message objects)."""
    safe = {
        k: v
        for k, v in state.items()
        if k != "messages"  # messages aren't JSON serialisable as-is
    }
    # Optionally include a compact log of messages (just role + text)
    safe["messages"] = [
        {"role": "human", "content": m.content}
        if isinstance(m, HumanMessage)
        else {"role": "ai", "content": m.content}
        for m in state.get("messages", [])
    ]
    return safe


def chat_turn(user_msg: str, history: List[Tuple[str, str]], state: AssistantState):
    """Handle one user turn; returns updated (history, state, sidebar)."""
    # Lazily bootstrap a session if the hidden state is empty or stale.
    if not state or not state.get("messages"):
        history, state, _ = reset_session()

    # Append user input to history right away so UI echoes immediately
    history.append((user_msg, None))

    # Record in graph state
    state["messages"].append(HumanMessage(content=user_msg))

    new_ai_messages: List[AIMessage] = []
    for event in compiled_graph.stream(
        state,
        {"configurable": {"thread_id": state["thread_id"]}},
        stream_mode="values",
    ):
        state.update({k: v for k, v in event.items() if k != "messages"})
        if "messages" in event and event["messages"]:
            new_ai_messages.extend(
                [m for m in event["messages"] if isinstance(m, AIMessage)]
            )

    # Replace the placeholder None with assistant's reply
    if new_ai_messages:
        history[-1] = (user_msg, new_ai_messages[-1].content)

    # If completed, append final state dict once.
    # fix: history[-1][1] can still be None when the graph produced no AI
    # reply; the old `"Final state:" not in history[-1][1]` raised TypeError.
    if state.get("conversation_stage") == "completed" and (
        not history or "Final state:" not in (history[-1][1] or "")
    ):
        safe_state_json = json.dumps(_json_safe_state(state), indent=2)
        history.append((None, f"Final state:\n```json\n{safe_state_json}\n```"))

    sidebar = format_state(state)
    if state.get("conversation_stage") == "completed":
        sidebar += "\n\n**🎉 Search completed!**"

    return history, state, sidebar


# ----------------------------------------------------------------------------
# 🚀 Launch Gradio
# ----------------------------------------------------------------------------
with gr.Blocks(title="Product Search Assistant") as demo:
    gr.Markdown("# 🛍️ Product Search Assistant")
    gr.Markdown("I can help you find products by specific item ID or by applying filters.")

    with gr.Row():
        with gr.Column(scale=2):
            chatbox = gr.Chatbot(label="Conversation", height=500, render_markdown=True)
            with gr.Row():
                txt = gr.Textbox(placeholder="Type your message here…", lines=1, scale=4)
                send = gr.Button("Send", variant="primary")
                reset = gr.Button("Reset")
        with gr.Column(scale=1):
            state_md = gr.Markdown()
            gr.Markdown(
                """### 💡 Tips
* Say **specific** if you already know the product's ID.
* Say **filter** to search by size, colour, etc.
* Say **done** when you're happy with the filters!
"""
            )

    # Hidden session state
    session_state = gr.State(_fresh_state())

    # ↩️ Send handler (clears the textbox after the turn completes)
    send.click(chat_turn, [txt, chatbox, session_state], [chatbox, session_state, state_md]).then(
        lambda: "", None, txt
    )
    # ↩️ Enter key submits as well
    txt.submit(chat_turn, [txt, chatbox, session_state], [chatbox, session_state, state_md]).then(
        lambda: "", None, txt
    )
    # 🔄 Reset handler
    reset.click(reset_session, None, [chatbox, session_state, state_md])
    # First load greeting
    demo.load(reset_session, None, [chatbox, session_state, state_md])

if __name__ == "__main__":
    demo.launch()