Spaces:
Runtime error
Runtime error
| from fastapi import FastAPI, HTTPException, Query as QueryParam | |
| from pydantic import BaseModel, Field | |
| from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
| from qdrant_client import QdrantClient | |
| from langchain.agents import Tool, AgentExecutor, create_openai_tools_agent | |
| from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
| from langchain.memory import ConversationBufferMemory | |
| from typing import Optional, List, Dict, Any | |
| import os | |
| import warnings | |
| import base64 | |
| import requests | |
| from dotenv import load_dotenv | |
| from datetime import datetime | |
| import json | |
| import uuid | |
| import redis | |
| # Pandas AI imports | |
| import re | |
| import urllib.parse | |
| import pandas as pd | |
| import dask.dataframe as dd | |
| from math import ceil | |
| import psycopg2 | |
| from pandasai import SmartDataframe | |
| from pandasai.llm.openai import OpenAI as PandasOpenAI | |
| from fastapi import FastAPI, Request | |
| from fastapi.responses import JSONResponse | |
| import json | |
| # Import your existing S3 connection details | |
| from retrive_secrects import * # CONNECTIONS_HOST, etc. | |
# Hide the Qdrant client version-incompatibility warning.
warnings.filterwarnings("ignore", message="Qdrant client version.*is incompatible.*")

# Load variables from a .env file before any of them are read below.
load_dotenv()

app = FastAPI(title="AI Agent with Redis Session Management and Pandas AI")

# --- OpenAI / Qdrant settings (read from the environment) ---
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
QDRANT_COLLECTION_NAME = os.environ.get("QDRANT_COLLECTION_NAME", "vatsav_test_1")
QDRANT_HOST = os.environ.get("QDRANT_HOST", "127.0.0.1")
QDRANT_PORT = int(os.environ.get("QDRANT_PORT", 6333))

# --- Redis settings: REDIS_URL, when set, is preferred over host/port ---
REDIS_URL = os.environ.get("REDIS_URL")
REDIS_HOST = os.environ.get("REDIS_HOST", "127.0.0.1")
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD")

# --- S3 bucket and folder layout ---
S3_Bucket_Name = "ingenspark-user-files"
S3_Raw_Files_Folder = "User-Uploaded-Raw-Files"
S3_Modified_Files_Folder = "Modified-Files/"
S3_Output_Files_Folder = "Output-Files/"
S3_Published_Results_Folder = "Published-Results/"
S3_Ingen_Customer_Output = "Ingen-Customer/"
Dominant_Segmentation_Output = "Dominant-Segmentation/"
Trend_Segmentation_Output = "Trend-Segmentation/"
Decile_Quartile_segmentation_Output = "Decile-Quartile-Segmentation/"
Combined_Segmentation_Output = "Combine-Segmentation/"
Custom_Segmentation_Output = "Custom-Segmentation/"
Customer_360_Output = "Customer-360/"
# Merged tables live underneath the modified-files folder.
Merge_file_folder = f"{S3_Modified_Files_Folder}IngenData-Merged-Tables/"
S3_Dev_Doc_Images_Folder = "Developers-Documentation-Images/"
# Temporary files share the raw-upload folder.
S3_Temporary_Files_Folder = S3_Raw_Files_Folder
S3_App_Specific_Data = "Application-Specific-Data/"
S3_Transformation_Tables_Folder = "Modified-Files/Modified-Tables/Transformation-Tables/"
cloud_front_url = "https://files.dev.ingenspark.com/"
| # Initialize Redis client | |
def get_redis_client():
    """Create a Redis client and verify the connection with a PING.

    ``REDIS_URL`` is used when it is set; otherwise the client connects to
    ``REDIS_HOST``/``REDIS_PORT`` using ``REDIS_PASSWORD``.

    Returns:
        A client created with ``decode_responses=True`` and 5 second
        connect/read timeouts, after a successful ``ping()``.

    Raises:
        HTTPException: status 500 when the client cannot be created or the
            PING fails (the original exception is chained as the cause).
    """
    # Options shared by both connection styles.
    common_options = {
        "decode_responses": True,
        "socket_connect_timeout": 5,
        "socket_timeout": 5,
    }
    try:
        if REDIS_URL:
            client = redis.from_url(REDIS_URL, **common_options)
            # A Redis URL may embed credentials, so only host[:port] is logged.
            parts = urllib.parse.urlsplit(REDIS_URL)
            host = parts.hostname or parts.path
            target = f"{host}:{parts.port}" if parts.port else f"{host}"
            label = "deployed"
        else:
            client = redis.StrictRedis(
                host=REDIS_HOST,
                port=REDIS_PORT,
                password=REDIS_PASSWORD,
                **common_options,
            )
            target = f"{REDIS_HOST}:{REDIS_PORT}"
            label = "local"

        # Fail fast if the server is unreachable.
        client.ping()
        print(f"β Connected to {label} Redis: {target}")
        return client
    except Exception as e:
        print(f"β Redis connection failed: {e}")
        raise HTTPException(status_code=500, detail=f"Redis connection failed: {str(e)}") from e
| # Initialize Redis client | |
# Shared service clients, created once when the module is imported.
redis_client = get_redis_client()

# Embedding model used to vectorise search queries.
embedding_model = OpenAIEmbeddings(
    openai_api_key=OPENAI_API_KEY,
    model="text-embedding-3-large",
)

# Vector store client and the chat model used by the agent and helpers.
qdrant_client = QdrantClient(port=QDRANT_PORT, host=QDRANT_HOST)
llm = ChatOpenAI(openai_api_key=OPENAI_API_KEY, temperature=0, model="gpt-4o")
| # === PANDAS AI FUNCTIONS === | |
def read_parquet_file_from_s3(ufuid=None, columns_list=None, records_count=None, file_location=''):
    """Read a Parquet file from S3 with Dask and return a pandas DataFrame.

    Parameters:
        ufuid (int): Optional ``user_file_upload_id``. When given, the S3 path
            is taken from the ``table_names`` column of
            ``public.user_file_upload``.
        columns_list (list/str): Columns to read; a comma-separated string is
            split into a list.
        records_count (int): Not used currently.
        file_location (str): Direct file path in S3, used when ``ufuid`` is
            ``None``.

    Returns:
        pandas.DataFrame: The file contents, or an empty DataFrame when any
        step fails (the error is printed, not raised).
    """
    try:
        if ufuid is not None:
            # The database is only needed for the ufuid lookup, so the
            # connection is opened here and always released, even on error.
            conn = psycopg2.connect(
                host=CONNECTIONS_HOST,
                database=CONNECTIONS_DB,
                user=CONNECTIONS_USER,
                password=CONNECTIONS_PASS
            )
            try:
                cursor = conn.cursor()
                try:
                    query = """SELECT file_name, table_names FROM public.user_file_upload WHERE user_file_upload_id = %s"""
                    cursor.execute(query, (ufuid,))
                    file = cursor.fetchone()
                finally:
                    cursor.close()
            finally:
                conn.close()
            if not file:
                raise ValueError(f"No file found for ufuid: {ufuid}")
            _file_name, s3_file_path = file
        else:
            # Normalize input path: drop any ".parquet" that is not at the
            # end, then make sure the path ends with exactly one suffix.
            file_location = re.sub(r'\.parquet(?!$)', '', file_location)
            s3_file_path = file_location if file_location.endswith('.parquet') else file_location + '.parquet'

        # Keep only the part after the bucket name and undo URL-encoding.
        s3_file_path = urllib.parse.unquote(s3_file_path.split(f"{S3_Bucket_Name}/")[-1])
        if not s3_file_path.endswith('.parquet'):
            s3_file_path += '.parquet'

        # Parse columns if given as comma-separated string
        if columns_list and not isinstance(columns_list, list):
            columns_list = [col.strip(' "\'') for col in columns_list.split(',')]

        print(f"\n{'!' * 100}\nReading from: s3://{S3_Bucket_Name}/{s3_file_path}\n")

        # Read using Dask
        ddf = dd.read_parquet(
            f"s3://{S3_Bucket_Name}/{s3_file_path}",
            engine="pyarrow",
            columns=columns_list,
            assume_missing=True
        )
        ddf = ddf.repartition(npartitions=8)
        print("Reading Parquet file from S3 completed successfully.")
        return ddf.compute()
    except Exception as e:
        print(f"β Error reading Parquet file: {e}")
        return pd.DataFrame()  # Callers test ``.empty`` to detect failure
def pandas_agent(filepath: str, query: str) -> str:
    """Load a dataset from S3 and answer a natural-language question about it.

    Parameters:
        filepath (str): S3 file path, or a ufuid given as an all-digit string.
        query (str): Natural language query about the data.

    Returns:
        str: The PandasAI answer prefixed with an "Analysis Result" header, or
        an error message when loading or analysis fails.
    """
    try:
        # An all-digit identifier is treated as a ufuid, anything else as a path.
        if filepath.isdigit():
            frame = read_parquet_file_from_s3(ufuid=int(filepath))
        else:
            frame = read_parquet_file_from_s3(file_location=filepath)

        if frame.empty:
            return "β No data found or failed to load the file. Please check the file path or ufuid."

        if not OPENAI_API_KEY:
            return "β OPENAI_API_KEY is not set in environment variables."

        smart_frame = SmartDataframe(
            frame,
            config={"llm": PandasOpenAI(api_token=OPENAI_API_KEY)},
        )

        print(f"π Processing query: {query}")
        answer = smart_frame.chat(query)

        # Tabular answers are rendered in full; everything else is stringified.
        if isinstance(answer, str):
            rendered = answer
        elif isinstance(answer, (pd.DataFrame, pd.Series)):
            rendered = answer.to_string()
        else:
            rendered = str(answer)
        return f"π Analysis Result:\n{rendered}"
    except Exception as e:
        error_msg = f"β Error in pandas_agent: {str(e)}"
        print(error_msg)
        return error_msg
| # === INPUT SCHEMAS === | |
class Query(BaseModel):
    """Request body carrying a single free-text message."""

    # Text passed to the document search.
    message: str
class ProjectRequest(BaseModel):
    """Request body for the direct project-listing endpoint."""

    userLoginId: int
    orgId: int
    # Raw token; it is combined with userLoginId and base64-encoded before use.
    auth_token: str
class BotQuery(BaseModel):
    """Request body for the main chat endpoint."""

    userLoginId: int
    orgId: int
    auth_token: str
    # When omitted, the chat endpoint creates a new session.
    session_id: Optional[str] = None
    message: str
class PandasAgentQuery(BaseModel):
    """Request body for the direct PandasAI analysis endpoint."""

    filepath: str = Field(..., description="S3 file path or ufuid")
    query: str = Field(..., description="Natural language query about the data")
class SessionResponse(BaseModel):
    """Session metadata returned when a session is created."""

    session_id: str
    userLoginId: int
    orgId: int
    # ISO-8601 timestamp string.
    created_at: str
    status: str
    title: Optional[str] = "New Chat"
class MessageResponse(BaseModel):
    """A single stored chat message."""

    message_id: str
    session_id: str
    # "user" or "assistant"
    role: str
    message: str
    # ISO-8601 timestamp string.
    timestamp: str
class ChatHistoryResponse(BaseModel):
    """Chat history page for one session."""

    session_id: str
    # The returned (possibly truncated) messages.
    messages: List[MessageResponse]
    # Number of messages stored for the session, not just those returned.
    total_messages: int
| # === SESSION MANAGEMENT FUNCTIONS === | |
@app.middleware("http")
async def add_success_flag(request: Request, call_next):
    """Add a boolean ``success`` field to JSON object responses.

    ``success`` is ``True`` for 2xx status codes and ``False`` otherwise.
    Non-JSON responses are returned untouched. JSON responses whose payload
    is not an object (or cannot be parsed) are re-sent with their original
    bytes.
    """
    # Imported here so the block does not depend on a new module-level import.
    from fastapi.responses import Response

    response = await call_next(request)

    # Only modify JSON responses
    if "application/json" not in response.headers.get("content-type", ""):
        return response

    # Reading the body drains the iterator, so from here on the original
    # response object can no longer be returned as-is; the bytes are kept
    # and a fresh response is always built.
    body = b"".join([chunk async for chunk in response.body_iterator])
    passthrough_headers = {
        k: v for k, v in response.headers.items() if k.lower() != "content-length"
    }

    try:
        data = json.loads(body.decode())
    except (UnicodeDecodeError, ValueError):
        data = None

    if isinstance(data, dict):
        data["success"] = 200 <= response.status_code < 300
        # JSONResponse recomputes Content-Length for the enlarged payload.
        return JSONResponse(
            content=data,
            status_code=response.status_code,
            headers=passthrough_headers,
        )

    # Lists, scalars or unparseable bodies: forward the original bytes.
    return Response(
        content=body,
        status_code=response.status_code,
        headers=passthrough_headers,
    )
def create_session(userLoginId: int, orgId: int, auth_token: str) -> dict:
    """Create a chat session and store it in Redis.

    Three keys are written, each with a 24 hour TTL: ``session:<id>`` (the
    session record), ``chat:<id>`` (empty history) and ``memory:<id>`` (empty
    conversation memory).

    Returns:
        dict: The stored session record.
    """
    ttl_seconds = 86400  # 24 hours
    new_id = str(uuid.uuid4())
    record = {
        "session_id": new_id,
        "userLoginId": userLoginId,
        "orgId": orgId,
        "auth_token": auth_token,
        "created_at": datetime.now().isoformat(),
        "status": "active",
        "title": "New Chat",  # Default title, replaced after the first message
    }

    # Session record first, then empty chat history and empty memory.
    initial_values = (
        ("session", record),
        ("chat", []),
        ("memory", []),
    )
    for prefix, value in initial_values:
        redis_client.setex(f"{prefix}:{new_id}", ttl_seconds, json.dumps(value))

    return record
def get_session(session_id: str) -> dict:
    """Return the stored session record.

    Raises:
        HTTPException: status 404 when no ``session:<id>`` key exists.
    """
    raw_record = redis_client.get(f"session:{session_id}")
    if raw_record:
        return json.loads(raw_record)
    raise HTTPException(status_code=404, detail="Session not found or expired")
def add_message_to_session(session_id: str, role: str, message: str) -> str:
    """Append a message to the session's chat history and return its id.

    The whole history list is read, extended and written back to
    ``chat:<id>`` with a fresh 24 hour TTL.
    """
    chat_key = f"chat:{session_id}"
    new_message_id = str(uuid.uuid4())
    entry = {
        "message_id": new_message_id,
        "session_id": session_id,
        "role": role,
        "message": message,
        "timestamp": datetime.now().isoformat(),
    }

    # A missing key is treated as an empty history.
    stored = redis_client.get(chat_key)
    history = json.loads(stored) if stored else []
    history.append(entry)

    redis_client.setex(chat_key, 86400, json.dumps(history))  # 24 hours
    return new_message_id
def get_session_memory(session_id: str) -> List[Dict]:
    """Return the stored conversation memory, or an empty list if none exists."""
    raw_memory = redis_client.get(f"memory:{session_id}")
    return json.loads(raw_memory) if raw_memory else []
def update_session_memory(session_id: str, messages: List[Dict]):
    """Overwrite the session's conversation memory with a fresh 24 hour TTL."""
    memory_key = f"memory:{session_id}"
    serialized = json.dumps(messages)
    redis_client.setex(memory_key, 86400, serialized)  # 24 hours
def update_session_title(session_id: str):
    """Replace the default "New Chat" title with a generated one.

    Does nothing when the session is missing or already has another title.
    Errors are printed and swallowed so a title failure never fails the
    request.
    """
    session_key = f"session:{session_id}"
    try:
        raw_record = redis_client.get(session_key)
        if not raw_record:
            return

        record = json.loads(raw_record)
        # Keep any title that has already been set.
        if record.get("title", "New Chat") != "New Chat":
            return

        record["title"] = generate_session_title(session_id)
        redis_client.setex(session_key, 86400, json.dumps(record))  # 24 hours
    except Exception as e:
        print(f"Error updating session title: {e}")
def generate_session_title(session_id: str) -> str:
    """Build a short title from the first user message of a session.

    Returns "New Chat" when there is no history or no user message. The LLM
    is asked for the title; if that call fails, the first four words of the
    message are used instead.
    """
    default_title = "New Chat"
    try:
        raw_history = redis_client.get(f"chat:{session_id}")
        if not raw_history:
            return default_title

        history = json.loads(raw_history)
        if not history:
            return default_title

        # First message written by the user, if any.
        opening = next(
            (entry["message"] for entry in history if entry["role"] == "user"),
            None,
        )
        if not opening:
            return default_title

        title_prompt = f"""Generate a short, descriptive title (maximum 6 words) for a chat conversation that starts with this message:
"{opening[:200]}"
Return only the title, no quotes or additional text. The title should capture the main topic or intent of the conversation."""

        try:
            reply = llm.invoke(title_prompt)
            # Strip quotes and cap the length at 50 characters.
            cleaned = reply.content.strip().replace('"', '').replace("'", "")
            if len(cleaned) > 50:
                cleaned = cleaned[:47] + "..."
            return cleaned if cleaned else default_title
        except Exception as e:
            print(f"Error generating title: {e}")
            # Fallback: use first few words of the message
            leading_words = opening.split()[:4]
            suffix = "..." if len(leading_words) >= 4 else ""
            return " ".join(leading_words) + suffix
    except Exception as e:
        print(f"Error in generate_session_title: {e}")
        return default_title
def get_user_sessions(userLoginId: int) -> List[dict]:
    """Return all sessions of a user, most recently created first.

    Every ``session:*`` key is scanned and filtered by ``userLoginId``. The
    title stored on the session is reused; a title is generated only while
    the session still carries the default "New Chat" title, so listing does
    not trigger one LLM call per session. The stored ``auth_token`` is
    removed from each returned record so the listing does not hand out
    credentials.
    """
    sessions = []
    for key in redis_client.scan_iter(match="session:*"):
        raw_record = redis_client.get(key)
        if not raw_record:
            # Key expired between SCAN and GET.
            continue

        session = json.loads(raw_record)
        # .get() so a malformed record is skipped instead of failing the listing.
        if session.get("userLoginId") != userLoginId:
            continue

        if session.get("title", "New Chat") == "New Chat":
            session["title"] = generate_session_title(session["session_id"])

        # Never expose the caller's stored token through the listing.
        session.pop("auth_token", None)
        sessions.append(session)

    # ISO-8601 strings sort chronologically; newest first.
    sessions.sort(key=lambda x: x.get("created_at", ""), reverse=True)
    return sessions
def delete_session(session_id: str):
    """Remove the session record, its chat history and its memory from Redis."""
    for prefix in ("session", "chat", "memory"):
        redis_client.delete(f"{prefix}:{session_id}")
| # === UTILITY FUNCTIONS === | |
def get_encoded_auth_token(user: int, token: str) -> str:
    """Return the base64 encoding of ``"<user>:<token>"`` as text."""
    credentials = "{}:{}".format(user, token)
    encoded_bytes = base64.b64encode(credentials.encode("utf-8"))
    return encoded_bytes.decode("utf-8")
def fetch_user_projects(userLoginId: int, orgId: int, auth_token: str):
    """POST to the ``fetchUserProjects`` API and return the decoded JSON.

    Parameters:
        userLoginId (int): User id sent in the request body.
        orgId (int): Organisation id sent in the request body.
        auth_token (str): Already-encoded value placed after ``Basic`` in the
            ``authorization`` header.

    Raises:
        HTTPException: with the upstream status code when the API answered
            with an error status, or 500 when no usable response is attached
            to the failure (connection error, timeout, undecodable body).
    """
    url = "https://japidemo.dev.ingenspark.com/fetchUserProjects"
    payload = {
        "userLoginId": userLoginId,
        "orgId": orgId
    }
    headers = {
        'accept': 'application/json, text/plain, */*',
        'authorization': f'Basic {auth_token}',
        'content-type': 'application/json; charset=UTF-8'
    }
    try:
        # Without a timeout a stalled upstream would block the worker forever.
        response = requests.post(url, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        return response.json()
    except requests.exceptions.RequestException as e:
        # Compare with None explicitly: a requests.Response is falsy for
        # 4xx/5xx statuses, which are exactly the ones of interest here.
        upstream = getattr(e, "response", None)
        status_code = upstream.status_code if upstream is not None else 500
        raise HTTPException(status_code=status_code, detail=str(e)) from e
def format_project_response(data: dict) -> str:
    """Render the project API payload as a numbered, human-readable list.

    Projects under ``data["Myprojects"]`` are listed first as "Your Project",
    followed by ``data["Otherprojects"]`` as "Other Project". Fractional
    seconds are cut from ``created_dttm``.
    """
    labelled_groups = (
        ("Your Project", data.get("data", {}).get("Myprojects", [])),
        ("Other Project", data.get("data", {}).get("Otherprojects", [])),
    )
    rows = [
        {
            "type": label,
            "projectNm": item["projectNm"],
            "projectId": item["projectId"],
            "created_dttm": item["created_dttm"].split('.')[0],
            "description": item["description"],
            "categoryName": item["categoryName"],
        }
        for label, group in labelled_groups
        for item in group
    ]

    if not rows:
        return "β No projects found."

    # Header line followed by a five-line entry per project.
    lines = [f"β You have access to {len(rows)} project(s):\n"]
    for position, row in enumerate(rows, 1):
        lines.extend([
            f"{position}. Project Name: {row['projectNm']} ({row['type']})",
            f" Project ID: {row['projectId']}",
            f" Created On: {row['created_dttm']}",
            f" Description: {row['description']}",
            f" Category: {row['categoryName']}\n",
        ])
    return "\n".join(lines)
| # === TOOL FUNCTIONS === | |
def search_documents(query: str) -> str:
    """Answer a question from the Qdrant knowledge base.

    The query is embedded, the five nearest points are fetched from
    ``QDRANT_COLLECTION_NAME``, and the LLM is asked to answer using their
    ``text`` payloads as context.

    Returns:
        str: The LLM answer followed by a ``Sources:`` line, a fixed message
        when nothing matched, or an error string (exceptions are not raised).
    """
    try:
        # Generate embedding for the query
        query_vector = embedding_model.embed_query(query)

        # Search in Qdrant
        search_result = qdrant_client.search(
            collection_name=QDRANT_COLLECTION_NAME,
            query_vector=query_vector,
            limit=5,
        )
        if not search_result:
            return "No relevant information found in the knowledge base."

        context_texts = [hit.payload["text"] for hit in search_result]
        sources = [hit.payload.get("source", "unknown") for hit in search_result]
        context = "\n\n".join(context_texts)

        # De-duplicate while keeping the order of the search hits; a plain
        # set() would make the order of the Sources line vary between runs.
        unique_sources = list(dict.fromkeys(sources))

        # Use the LLM directly to answer the message based on context
        prompt = f"""Based on the following context, answer the message: {query}
Context:
{context}
Please provide a comprehensive answer based on the context above. If the context doesn't contain enough information to answer the message, say so clearly."""
        response = llm.invoke(prompt)
        return f"{response.content}\n\nSources: {', '.join(unique_sources)}"
    except Exception as e:
        return f"Error searching documents: {str(e)}"
| # Global variables to store auth context (for tool functions) | |
# Module-level auth context read by the project tool. The chat endpoint sets
# these before running the agent and resets them to None afterwards.
# NOTE(review): these are shared by the whole process, so overlapping chat
# requests could overwrite each other's values - confirm how the app is served.
_current_user_id = _current_org_id = _current_auth_token = None
def get_user_projects(userLoginId: str) -> str:
    """Return the formatted project list for the current auth context.

    The ``userLoginId`` argument is accepted for the tool interface but the
    user id, org id and token are taken from the module-level auth context.
    Errors are returned as text rather than raised.
    """
    try:
        # Both a token and a user id must be present in the global context.
        if not (_current_auth_token and _current_user_id):
            return "β Authentication token required. Please provide auth_token in your request."

        # Snapshot the context so the same values are used for every call below.
        active_user = _current_user_id
        active_org = _current_org_id or 1
        active_token = _current_auth_token

        encoded = get_encoded_auth_token(active_user, active_token)
        projects_payload = fetch_user_projects(active_user, active_org, encoded)
        return format_project_response(projects_payload)
    except ValueError:
        return "β Invalid userLoginId format. Please provide a valid number."
    except Exception as e:
        return f"β Error fetching projects: {str(e)}"
def pandas_data_analysis(query_with_filepath: str) -> str:
    """Tool entry point for PandasAI analysis.

    Input format: 'filepath|query' where filepath is an S3 path or ufuid and
    query is the analysis question. Only the first ``|`` separates the two
    parts. Malformed input and errors are reported as text.
    """
    try:
        pieces = query_with_filepath.split('|', 1)
        if len(pieces) != 2:
            return "β Invalid input format. Please use: 'filepath|query' format."

        path, question = (piece.strip() for piece in pieces)
        if not (path and question):
            return "β Both filepath and query are required."

        return pandas_agent(path, question)
    except Exception as e:
        return f"β Error in pandas data analysis: {str(e)}"
| # === CREATE TOOLS === | |
# Knowledge-base lookup backed by search_documents().
document_search_tool = Tool(
    func=search_documents,
    name="document_search",
    description="""Use this tool to search through ingested documents and get relevant information from the knowledge base.
Perfect for answering messages about uploaded documents, manuals, or any content that was previously stored.
Input should be a search query or message about the documents.""",
)

# Project listing backed by get_user_projects(); relies on the auth context.
project_list_tool = Tool(
    func=get_user_projects,
    name="get_user_projects",
    description="""Use this tool to get the list of projects for a user.
Perfect for when users ask about their projects, want to see available projects, or need project information.
Input should be the userLoginId (e.g., '25').
Note: This tool requires authentication context to be set.""",
)

# Dataset analysis backed by pandas_data_analysis(); input is 'filepath|query'.
pandas_analysis_tool = Tool(
    func=pandas_data_analysis,
    name="pandas_data_analysis",
    description="""Use this tool for data analysis on CSV/Parquet files using PandasAI.
Perfect for when users ask questions about data analysis, statistics, insights, or want to query their datasets.
Input format: 'filepath|query' where:
- filepath: S3 file path (e.g., 'User-Uploaded-Raw-Files/Data2004csv1754926601269756') or ufuid (e.g., '123')
- query: Natural language question about the data (e.g., 'What are the top 5 values?', 'Show me summary statistics')
Examples:
- 'User-Uploaded-Raw-Files/mydata.csv|What is this file about?'
- '123|Show me the first 5 rows'
- 'Modified-Files/processed_data|What are the most common values in column X?'
""",
)
| # === AGENT SETUP === | |
def create_agent_with_session_memory(session_id: str):
    """Build an agent executor preloaded with the session's stored memory.

    Returns:
        tuple: ``(agent_executor, memory)``; the memory object is returned so
        the caller can read the conversation back after the run.
    """
    # Get memory from Redis
    stored_messages = get_session_memory(session_id)

    system_prompt = """You are a helpful AI assistant with access to multiple tools and conversation memory:
1. **Document Search**: Search through uploaded documents and knowledge base
2. **Project Management**: Get list of user projects and project information
3. **Data Analysis**: Analyze CSV/Parquet files using PandasAI for insights, statistics, and queries
Your capabilities:
- Answer messages about documents using the document search tool
- Help users find their projects and project information
- Perform data analysis on uploaded datasets using natural language queries
- Remember previous conversations in this session
- Provide general assistance and information
- Use appropriate tools based on user queries
Guidelines:
- Use the document search tool when users ask about specific content, documentation, or information that might be in uploaded files
- Use the project tool when users ask about projects, want to see their projects, or need project-related information
- Use the pandas analysis tool when users ask about data analysis, statistics, insights, or want to query datasets
- For pandas analysis, you need both a filepath (S3 path or ufuid) and a query - ask for missing information if needed
- Reference previous conversation context when relevant
- Be clear about which tool you're using and what information you're providing
- If you're unsure which tool to use, you can ask for clarification
- Provide helpful, accurate, and well-formatted responses
Remember: Always use the most appropriate tool based on the user's message and conversation context to provide the best possible answer."""

    agent_prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        ("user", "{input}"),
        MessagesPlaceholder(variable_name="agent_scratchpad"),
    ])

    # Memory exposed to the prompt under the "chat_history" placeholder.
    memory = ConversationBufferMemory(
        return_messages=True,
        memory_key="chat_history",
    )

    # Replay stored turns: "user" entries as human messages, the rest as AI.
    for entry in stored_messages:
        if entry["role"] == "user":
            memory.chat_memory.add_user_message(entry["message"])
        else:
            memory.chat_memory.add_ai_message(entry["message"])

    available_tools = [document_search_tool, project_list_tool, pandas_analysis_tool]
    agent = create_openai_tools_agent(llm, available_tools, agent_prompt)
    executor = AgentExecutor(
        agent=agent,
        tools=available_tools,
        verbose=True,
        memory=memory,
    )
    return executor, memory
| # === API ENDPOINTS === | |
@app.post("/sessions", response_model=SessionResponse)
def create_new_session(userLoginId: int, orgId: int, auth_token: str):
    """Create a chat session and return its metadata.

    Any failure is reported as HTTP 500.
    """
    try:
        created = create_session(userLoginId, orgId, auth_token)
        return SessionResponse(**created)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating session: {str(e)}")
@app.get("/sessions")
def list_user_sessions(userLoginId: int):
    """Return every session belonging to ``userLoginId`` with a count.

    Any failure is reported as HTTP 500.
    """
    try:
        found = get_user_sessions(userLoginId)
        return dict(
            userLoginId=userLoginId,
            total_sessions=len(found),
            sessions=found,
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching sessions: {str(e)}")
@app.delete("/sessions/{session_id}")
def delete_user_session(session_id: str):
    """Delete a session together with its chat history and memory.

    Raises:
        HTTPException: 404 when the session does not exist or has expired,
            500 for any other failure.
    """
    try:
        # Verify session exists (raises 404 otherwise)
        get_session(session_id)
        delete_session(session_id)
        return {
            "message": f"Session {session_id} deleted successfully",
            "session_id": session_id
        }
    except HTTPException:
        # Keep the original status (e.g. 404) instead of masking it as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error deleting session: {str(e)}") from e
@app.post("/bot")
def chat_with_bot(query: BotQuery):
    """Main bot endpoint with session management.

    Creates a session when ``session_id`` is omitted, stores the user
    message, runs the agent with the session's memory, stores the answer,
    persists the updated memory and refreshes the session title.

    Raises:
        HTTPException: 404 when the given session does not exist or has
            expired, 500 for any other failure.
    """
    global _current_user_id, _current_org_id, _current_auth_token
    try:
        # Set global auth context for tools
        _current_user_id = query.userLoginId
        _current_org_id = query.orgId
        _current_auth_token = query.auth_token

        session_id = query.session_id
        if not session_id:
            # Create new session if not provided
            session_data = create_session(query.userLoginId, query.orgId, query.auth_token)
            session_id = session_data["session_id"]
        else:
            # Verify existing session (raises 404 otherwise)
            get_session(session_id)

        user_message_id = add_message_to_session(session_id, "user", query.message)

        agent_executor, memory = create_agent_with_session_memory(session_id)
        result = agent_executor.invoke({"input": query.message})

        ai_message_id = add_message_to_session(session_id, "assistant", result["output"])

        # Persist the conversation as the memory object now holds it.
        updated_messages = []
        for message in memory.chat_memory.messages:
            if hasattr(message, 'content'):
                role = "user" if message.__class__.__name__ == "HumanMessage" else "assistant"
                updated_messages.append({
                    "role": role,
                    "message": message.content,
                    "timestamp": datetime.now().isoformat()
                })
        update_session_memory(session_id, updated_messages)

        # Update session title after first user message
        update_session_title(session_id)

        return {
            "session_id": session_id,
            "user_message_id": user_message_id,
            "ai_message_id": ai_message_id,
            "message": query.message,
            "answer": result["output"],
            "userLoginId": query.userLoginId,
            "agent_used": True
        }
    except HTTPException:
        # Keep the original status (e.g. 404) instead of masking it as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}") from e
    finally:
        # Always clear the auth context, on success and on every error path.
        _current_user_id = None
        _current_org_id = None
        _current_auth_token = None
@app.get("/sessions/{session_id}/history", response_model=ChatHistoryResponse)
def get_session_history(session_id: str, n: int = QueryParam(50, description="Number of recent messages to return")):
    """Return the last ``n`` messages of a session.

    ``total_messages`` always reports the full stored count. A non-positive
    ``n`` yields an empty message list.

    Raises:
        HTTPException: 404 when the session does not exist or has expired,
            500 for any other failure.
    """
    try:
        # Verify session exists (raises 404 otherwise)
        get_session(session_id)

        chat_data = redis_client.get(f"chat:{session_id}")
        if not chat_data:
            return ChatHistoryResponse(
                session_id=session_id,
                messages=[],
                total_messages=0
            )

        messages = json.loads(chat_data)
        # messages[-n:] already returns everything when n exceeds the length;
        # the guard covers n <= 0, where the slice would otherwise return the
        # whole list (n == 0) or drop the oldest entries (n < 0).
        recent_messages = messages[-n:] if n > 0 else []
        message_responses = [MessageResponse(**msg) for msg in recent_messages]
        return ChatHistoryResponse(
            session_id=session_id,
            messages=message_responses,
            total_messages=len(messages)
        )
    except HTTPException:
        # Keep the original status (e.g. 404) instead of masking it as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching chat history: {str(e)}") from e
@app.post("/chat-documents")
def chat_documents_only(query: Query):
    """Answer a message with the document search only, bypassing the agent.

    Errors are returned inside the ``answer`` field rather than raised.
    """
    try:
        answer = search_documents(query.message)
    except Exception as e:
        answer = f"An error occurred: {str(e)}"
    return {
        "message": query.message,
        "answer": answer,
        "tool_used": "document_search",
    }
@app.post("/list-projects")
def list_projects(request: ProjectRequest):
    """List a user's projects directly, bypassing the agent.

    Errors are returned under an ``error`` key rather than raised.
    """
    tool_name = "project_list"
    try:
        # Encode the caller's id and token, fetch, then format.
        encoded = get_encoded_auth_token(request.userLoginId, request.auth_token)
        raw_projects = fetch_user_projects(request.userLoginId, request.orgId, encoded)
        return {
            "projects": format_project_response(raw_projects),
            "tool_used": tool_name,
        }
    except Exception as e:
        return {
            "error": f"An error occurred: {str(e)}",
            "tool_used": tool_name,
        }
@app.post("/chat-with-pandas-agent")
def chat_with_pandas_agent(request: PandasAgentQuery):
    """Run the PandasAI agent directly on a file and question.

    Errors are returned inside the ``answer`` field with ``error: True``.
    """
    try:
        answer = pandas_agent(request.filepath, request.query)
        return {
            "filepath": request.filepath,
            "query": request.query,
            "answer": answer,
            "tool_used": "pandas_agent",
            "timestamp": datetime.now().isoformat(),
        }
    except Exception as e:
        return {
            "filepath": request.filepath,
            "query": request.query,
            "answer": f"An error occurred: {str(e)}",
            "tool_used": "pandas_agent",
            "error": True,
            "timestamp": datetime.now().isoformat(),
        }
@app.put("/sessions/{session_id}/title")
def refresh_session_title(session_id: str):
    """Regenerate a session's title and store it with a fresh 24 hour TTL.

    Raises:
        HTTPException: 404 when the session does not exist or has expired,
            500 for any other failure.
    """
    try:
        # Verify session exists (raises 404 otherwise)
        session_data = get_session(session_id)

        new_title = generate_session_title(session_id)
        session_data["title"] = new_title
        redis_client.setex(
            f"session:{session_id}",
            86400,  # 24 hours
            json.dumps(session_data)
        )
        return {
            "session_id": session_id,
            "new_title": new_title,
            "message": "Session title updated successfully"
        }
    except HTTPException:
        # Keep the original status (e.g. 404) instead of masking it as 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error updating session title: {str(e)}") from e
@app.get("/redis-info")
def redis_info():
    """Report basic Redis server statistics, or the error if unreachable."""
    try:
        server_info = redis_client.info()
        report = {
            "redis_connected": True,
            "redis_version": server_info.get("redis_version"),
            "used_memory": server_info.get("used_memory_human"),
            "connected_clients": server_info.get("connected_clients"),
            "total_keys": redis_client.dbsize(),
        }
    except Exception as e:
        report = {
            "redis_connected": False,
            "error": str(e),
        }
    return report
@app.get("/health")
def health():
    """Report service health, including Redis connectivity.

    When Redis cannot be reached the endpoint still answers:
    ``redis_status`` is "disconnected" and ``total_sessions`` is 0.
    """
    try:
        redis_client.ping()
        # Count inside the try so a Redis outage cannot fail the health check;
        # the generator avoids building a list of every key just to count it.
        total_sessions = sum(1 for _ in redis_client.scan_iter(match="session:*"))
        redis_status = "connected"
    except Exception:
        redis_status = "disconnected"
        total_sessions = 0
    return {
        "status": "ok",
        "tools": ["document_search", "project_list", "pandas_data_analysis"],
        "agent": "active",
        "session_management": "enabled",
        "redis_status": redis_status,
        "pandas_ai": "enabled",
        "total_sessions": total_sessions
    }