from fastapi import FastAPI, HTTPException, Query as QueryParam, UploadFile, File, Request | |
from fastapi.responses import JSONResponse | |
from pydantic import BaseModel, Field | |
from langchain_openai import ChatOpenAI, OpenAIEmbeddings | |
from qdrant_client import QdrantClient | |
from qdrant_client.models import VectorParams, Distance, PointStruct, Filter, SearchRequest | |
from langchain.agents import Tool, AgentExecutor, create_openai_tools_agent | |
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder | |
from langchain.memory import ConversationBufferMemory | |
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, Docx2txtLoader, BSHTMLLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from typing import Optional, List, Dict, Any | |
import os | |
import warnings | |
import base64 | |
import requests | |
import tempfile | |
import uuid | |
import json | |
import redis | |
from dotenv import load_dotenv | |
from datetime import datetime | |
# Pandas AI imports | |
import re | |
import urllib.parse | |
import pandas as pd | |
import dask.dataframe as dd | |
from math import ceil | |
import psycopg2 | |
from pandasai import SmartDataframe | |
from pandasai.llm.openai import OpenAI as PandasOpenAI | |
# Import your existing S3 connection details | |
from retrive_secrects import * # CONNECTIONS_HOST, etc. | |
# Suppress warnings | |
warnings.filterwarnings("ignore", message="Qdrant client version.*is incompatible.*") | |
load_dotenv() | |
app = FastAPI(title="Combined AI Agent with Qdrant Collections and Redis Session Management") | |
# Environment variables | |
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") | |
if not OPENAI_API_KEY: | |
    raise ValueError("❌ OPENAI_API_KEY not set in environment variables")
QDRANT_COLLECTION_NAME = os.getenv("QDRANT_COLLECTION_NAME", "vatsav_test_1") | |
# Qdrant Configuration - Using cloud instance | |
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJhY2Nlc3MiOiJtIiwiZXhwIjoxNzY0MTQ5OTc3fQ.l_2R-Eyb_530887EGLUkawZQamhPGVklDMlaVs0bDqo")
QDRANT_URL = os.getenv("QDRANT_URL", "https://09476415-f871-4664-9c92-2f7f17c223ee.eu-central-1-0.aws.cloud.qdrant.io")
# Fallback to local Qdrant if needed | |
QDRANT_HOST = os.getenv("QDRANT_HOST", "127.0.0.1") | |
QDRANT_PORT = int(os.getenv("QDRANT_PORT", 6333)) | |
# Redis Configuration | |
REDIS_URL = os.getenv("REDIS_URL") | |
REDIS_HOST = os.getenv("REDIS_HOST", "127.0.0.1") | |
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379)) | |
REDIS_PASSWORD = os.getenv("REDIS_PASSWORD") | |
# S3 Constants (from your original code) | |
S3_Bucket_Name = 'ingenspark-user-files' | |
S3_Raw_Files_Folder = 'User-Uploaded-Raw-Files' | |
S3_Modified_Files_Folder = 'Modified-Files/' | |
S3_Output_Files_Folder = 'Output-Files/' | |
S3_Published_Results_Folder = 'Published-Results/' | |
S3_Ingen_Customer_Output = 'Ingen-Customer/' | |
Dominant_Segmentation_Output = 'Dominant-Segmentation/' | |
Trend_Segmentation_Output = 'Trend-Segmentation/' | |
Decile_Quartile_segmentation_Output = 'Decile-Quartile-Segmentation/' | |
Combined_Segmentation_Output = 'Combine-Segmentation/' | |
Custom_Segmentation_Output = 'Custom-Segmentation/' | |
Customer_360_Output = 'Customer-360/' | |
Merge_file_folder = S3_Modified_Files_Folder + 'IngenData-Merged-Tables/' | |
S3_Dev_Doc_Images_Folder = 'Developers-Documentation-Images/' | |
S3_Temporary_Files_Folder = S3_Raw_Files_Folder | |
S3_App_Specific_Data = 'Application-Specific-Data/' | |
S3_Transformation_Tables_Folder = 'Modified-Files/Modified-Tables/Transformation-Tables/' | |
cloud_front_url = "https://files.dev.ingenspark.com/" | |
# Initialize Qdrant client with fallback | |
def get_qdrant_client(): | |
"""Initialize Qdrant client with fallback to local instance""" | |
try: | |
# Try cloud instance first | |
client = QdrantClient( | |
url=QDRANT_URL, | |
api_key=QDRANT_API_KEY | |
) | |
# Test connection | |
collections = client.get_collections() | |
print(f"β Connected to cloud Qdrant: {QDRANT_URL}") | |
return client | |
except Exception as e: | |
print(f"β οΈ Cloud Qdrant failed: {e}, trying local...") | |
try: | |
# Fallback to local | |
client = QdrantClient(host=QDRANT_HOST, port=QDRANT_PORT) | |
client.get_collections() | |
print(f"β Connected to local Qdrant: {QDRANT_HOST}:{QDRANT_PORT}") | |
return client | |
except Exception as e2: | |
print(f"β Both Qdrant connections failed: {e2}") | |
raise HTTPException(status_code=500, detail=f"Qdrant connection failed: {str(e2)}") | |
# Initialize Redis client | |
def get_redis_client(): | |
"""Initialize Redis client with fallback to local Redis""" | |
try: | |
if REDIS_URL: | |
# Use deployed Redis URL | |
redis_client = redis.from_url( | |
REDIS_URL, | |
decode_responses=True, | |
socket_connect_timeout=5, | |
socket_timeout=5 | |
) | |
# Test connection | |
redis_client.ping() | |
print(f"β Connected to deployed Redis: {REDIS_URL}") | |
return redis_client | |
else: | |
# Use local Redis | |
redis_client = redis.StrictRedis( | |
host=REDIS_HOST, | |
port=REDIS_PORT, | |
password=REDIS_PASSWORD, | |
decode_responses=True, | |
socket_connect_timeout=5, | |
socket_timeout=5 | |
) | |
# Test connection | |
redis_client.ping() | |
print(f"β Connected to local Redis: {REDIS_HOST}:{REDIS_PORT}") | |
return redis_client | |
except Exception as e: | |
print(f"β Redis connection failed: {e}") | |
raise HTTPException(status_code=500, detail=f"Redis connection failed: {str(e)}") | |
# Initialize clients | |
qdrant_client = get_qdrant_client() | |
redis_client = get_redis_client() | |
# Initialize models | |
embedding_model = OpenAIEmbeddings( | |
model="text-embedding-3-large", | |
openai_api_key=OPENAI_API_KEY, | |
) | |
llm = ChatOpenAI(model="gpt-4o", temperature=0, openai_api_key=OPENAI_API_KEY) | |
# ------------------- MIDDLEWARE ------------------- | |
async def add_success_flag(request: Request, call_next): | |
response = await call_next(request) | |
# Only modify JSON responses | |
if "application/json" in response.headers.get("content-type", ""): | |
try: | |
body = b"".join([chunk async for chunk in response.body_iterator]) | |
data = json.loads(body.decode()) | |
# Add success flag | |
data["success"] = 200 <= response.status_code < 300 | |
# Build new JSONResponse (auto handles Content-Length) | |
response = JSONResponse( | |
content=data, | |
status_code=response.status_code, | |
headers={k: v for k, v in response.headers.items() if k.lower() != "content-length"}, | |
) | |
except Exception: | |
# fallback if response is not JSON parseable | |
pass | |
return response | |
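
# A minimal registration sketch for the middleware above, assuming it is not wired up
# elsewhere in this module; FastAPI accepts either the @app.middleware("http") decorator
# on the function or the explicit call below.
# app.middleware("http")(add_success_flag)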
# === PANDAS AI FUNCTIONS === | |
def read_parquet_file_from_s3(ufuid=None, columns_list=None, records_count=None, file_location=''): | |
""" | |
Reads a Parquet file from S3 using Dask and returns it as a Pandas DataFrame. | |
Parameters: | |
ufuid (int): Optional user_file_upload_id to fetch S3 path from DB. | |
columns_list (list/str): Columns to read. | |
records_count (int): Not used currently. | |
file_location (str): Direct file path in S3. | |
Returns: | |
pandas.DataFrame | |
""" | |
try: | |
# Connect to PostgreSQL | |
conn = psycopg2.connect( | |
host=CONNECTIONS_HOST, | |
database=CONNECTIONS_DB, | |
user=CONNECTIONS_USER, | |
password=CONNECTIONS_PASS | |
) | |
cursor = conn.cursor() | |
if ufuid is not None: | |
query = """SELECT file_name, table_names FROM public.user_file_upload WHERE user_file_upload_id = %s""" | |
cursor.execute(query, (ufuid,)) | |
file = cursor.fetchone() | |
if not file: | |
raise ValueError(f"No file found for ufuid: {ufuid}") | |
file_name, s3_file_path = file | |
else: | |
# Normalize input path | |
file_location = re.sub(r'\.parquet(?!$)', '', file_location) | |
s3_file_path = file_location if file_location.endswith('.parquet') else file_location + '.parquet' | |
# Extract relative S3 path | |
s3_file_path = urllib.parse.unquote(s3_file_path.split(f"{S3_Bucket_Name}/")[-1]) | |
if not s3_file_path.endswith('.parquet'): | |
s3_file_path += '.parquet' | |
# Parse columns if given as comma-separated string | |
if columns_list and not isinstance(columns_list, list): | |
columns_list = [col.strip(' "\'') for col in columns_list.split(',')] | |
print(f"\n{'!' * 100}\nReading from: s3://{S3_Bucket_Name}/{s3_file_path}\n") | |
# Read using Dask | |
ddf = dd.read_parquet( | |
f"s3://{S3_Bucket_Name}/{s3_file_path}", | |
engine="pyarrow", | |
columns=columns_list, | |
assume_missing=True | |
) | |
ddf = ddf.repartition(npartitions=8) # Optimize for processing | |
print("Reading Parquet file from S3 completed successfully.") | |
# Close database connection | |
cursor.close() | |
conn.close() | |
return ddf.compute() | |
except Exception as e: | |
print(f"β Error reading Parquet file: {e}") | |
return pd.DataFrame() # Return empty DataFrame on error | |
def pandas_agent(filepath: str, query: str) -> str: | |
""" | |
PandasAI agent that reads data from S3 and answers queries about the data. | |
Parameters: | |
filepath (str): S3 file path or ufuid | |
query (str): Natural language query about the data | |
Returns: | |
str: Answer from PandasAI | |
""" | |
try: | |
# Check if filepath is a number (ufuid) or a file path | |
if filepath.isdigit(): | |
# It's a ufuid | |
data = read_parquet_file_from_s3(ufuid=int(filepath)) | |
else: | |
# It's a file path | |
data = read_parquet_file_from_s3(file_location=filepath) | |
if data.empty: | |
return "β No data found or failed to load the file. Please check the file path or ufuid." | |
# Initialize PandasAI LLM | |
if not OPENAI_API_KEY: | |
return "β OPENAI_API_KEY is not set in environment variables." | |
pandas_llm = PandasOpenAI(api_token=OPENAI_API_KEY) | |
# Create SmartDataframe | |
sdf = SmartDataframe(data, config={"llm": pandas_llm}) | |
# Ask the question | |
print(f"π Processing query: {query}") | |
result = sdf.chat(query) | |
# Handle different types of results | |
        if isinstance(result, str):
            return f"📊 Analysis Result:\n{result}"
        elif isinstance(result, (pd.DataFrame, pd.Series)):
            return f"📊 Analysis Result:\n{result.to_string()}"
        else:
            return f"📊 Analysis Result:\n{str(result)}"
    except Exception as e:
        error_msg = f"❌ Error in pandas_agent: {str(e)}"
print(error_msg) | |
return error_msg | |
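
# Illustrative calls (file identifiers are placeholders, not real uploads):
# print(pandas_agent("123", "Show me summary statistics"))
# print(pandas_agent("User-Uploaded-Raw-Files/mydata", "What are the top 5 values?"))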
# === INPUT SCHEMAS === | |
# Add these new Pydantic models after the existing schemas | |
class DatasetInfoRequest(BaseModel): | |
userLoginId: int | |
orgId: int | |
project_id: int | |
auth_token: str | |
class DatasetInfoResponse(BaseModel): | |
project_id: int | |
dataset_info: Dict[str, Any] | |
ingestion_status: Optional[str] = None | |
# Chat and Session Schemas | |
class Query(BaseModel): | |
message: str | |
class ProjectRequest(BaseModel): | |
userLoginId: int | |
orgId: int | |
auth_token: str | |
class BotQuery(BaseModel): | |
userLoginId: int | |
orgId: int | |
auth_token: str | |
session_id: Optional[str] = None | |
message: str | |
class PandasAgentQuery(BaseModel): | |
filepath: str = Field(..., description="S3 file path or ufuid") | |
query: str = Field(..., description="Natural language query about the data") | |
class SessionResponse(BaseModel): | |
session_id: str | |
userLoginId: int | |
orgId: int | |
created_at: str | |
status: str | |
title: Optional[str] = "New Chat" | |
class MessageResponse(BaseModel): | |
message_id: str | |
session_id: str | |
role: str # "user" or "assistant" | |
message: str | |
timestamp: str | |
class ChatHistoryResponse(BaseModel): | |
session_id: str | |
messages: List[MessageResponse] | |
total_messages: int | |
class UpdateSessionTitleRequest(BaseModel): | |
new_title: str | |
# Qdrant Collection Schemas | |
class CollectionRequest(BaseModel): | |
name: str | |
vector_size: int | |
distance: str = "Cosine" # Cosine, Euclid, Dot | |
class UpdateCollectionRequest(BaseModel): | |
vector_size: int | None = None | |
distance: str | None = None | |
# === SESSION MANAGEMENT FUNCTIONS === | |
def should_ingest_data(user_login_id: int) -> bool:
    """Return True when the user has at most one session, i.e. dataset ingestion has not yet run for this user."""
    try:
        response = requests.get(f"http://127.0.0.1:8000/sessions?userLoginId={user_login_id}")
        if response.status_code == 200:
            data = response.json()
            return data.get("total_sessions", 0) <= 1
else: | |
print(f"Failed to fetch sessions: {response.status_code}") | |
return False | |
except Exception as e: | |
print(f"Error checking session count: {e}") | |
return False | |
#_________________________file_ingestion_services___________________________________ | |
def create_session(userLoginId: int, orgId: int, auth_token: str) -> dict: | |
"""Create a new chat session""" | |
session_id = str(uuid.uuid4()) | |
session_data = { | |
"session_id": session_id, | |
"userLoginId": userLoginId, | |
"orgId": orgId, | |
"auth_token": auth_token, | |
"created_at": datetime.now().isoformat(), | |
"status": "active", | |
"title": "New Chat" # Default title, will be updated after first message | |
} | |
# Store session in Redis with 24 hour TTL | |
redis_client.setex( | |
f"session:{session_id}", | |
86400, # 24 hours | |
json.dumps(session_data) | |
) | |
# Initialize empty chat history | |
redis_client.setex( | |
f"chat:{session_id}", | |
86400, # 24 hours | |
json.dumps([]) | |
) | |
# Initialize conversation memory | |
redis_client.setex( | |
f"memory:{session_id}", | |
86400, # 24 hours | |
json.dumps([]) | |
) | |
return session_data | |
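
# Illustrative use (credentials are placeholders): creates the session record plus empty
# chat-history and memory keys, each with a 24-hour TTL.
# session = create_session(25, 1, "raw-auth-token")
# print(session["session_id"], session["title"])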
def get_session(session_id: str) -> dict: | |
"""Get session data from Redis""" | |
session_data = redis_client.get(f"session:{session_id}") | |
if not session_data: | |
raise HTTPException(status_code=404, detail="Session not found or expired") | |
return json.loads(session_data) | |
def add_message_to_session(session_id: str, role: str, message: str) -> str: | |
"""Add message to session chat history""" | |
message_id = str(uuid.uuid4()) | |
message_data = { | |
"message_id": message_id, | |
"session_id": session_id, | |
"role": role, | |
"message": message, | |
"timestamp": datetime.now().isoformat() | |
} | |
# Get current chat history | |
chat_history = redis_client.get(f"chat:{session_id}") | |
if chat_history: | |
messages = json.loads(chat_history) | |
else: | |
messages = [] | |
# Add new message | |
messages.append(message_data) | |
# Update chat history in Redis with extended TTL | |
redis_client.setex( | |
f"chat:{session_id}", | |
86400, # 24 hours | |
json.dumps(messages) | |
) | |
return message_id | |
def get_session_memory(session_id: str) -> List[Dict]: | |
"""Get conversation memory for session""" | |
memory_data = redis_client.get(f"memory:{session_id}") | |
if memory_data: | |
return json.loads(memory_data) | |
return [] | |
def update_session_memory(session_id: str, messages: List[Dict]): | |
"""Update conversation memory for session""" | |
redis_client.setex( | |
f"memory:{session_id}", | |
86400, # 24 hours | |
json.dumps(messages) | |
) | |
def update_session_title(session_id: str): | |
"""Update session title after first message""" | |
try: | |
# Get session data | |
session_data = redis_client.get(f"session:{session_id}") | |
if not session_data: | |
return | |
session = json.loads(session_data) | |
# Only update if current title is "New Chat" | |
if session.get("title", "New Chat") == "New Chat": | |
new_title = generate_session_title(session_id) | |
session["title"] = new_title | |
# Update session in Redis | |
redis_client.setex( | |
f"session:{session_id}", | |
86400, # 24 hours | |
json.dumps(session) | |
) | |
except Exception as e: | |
print(f"Error updating session title: {e}") | |
pass # Don't fail the request if title update fails | |
def generate_session_title(session_id: str) -> str: | |
"""Generate and optionally save a title for the session based on chat history""" | |
try: | |
# Check session | |
session_data = redis_client.get(f"session:{session_id}") | |
if session_data: | |
session = json.loads(session_data) | |
if "user_title" in session: | |
# Don't override user-defined titles | |
return session["user_title"] | |
# Get chat history | |
chat_data = redis_client.get(f"chat:{session_id}") | |
if not chat_data: | |
return "New Chat" | |
messages = json.loads(chat_data) | |
if not messages: | |
return "New Chat" | |
first_user_message = next( | |
(msg["message"] for msg in messages if msg["role"] == "user"), None | |
) | |
if not first_user_message: | |
return "New Chat" | |
# Generate title with LLM | |
title_prompt = f"""Generate a short, descriptive title (maximum 6 words) for a chat conversation that starts with this message: | |
"{first_user_message[:200]}" | |
Return only the title, no quotes or additional text. The title should capture the main topic or intent of the conversation.""" | |
try: | |
response = llm.invoke(title_prompt) | |
title = response.content.strip().replace('"', '').replace("'", "") | |
if len(title) > 50: | |
title = title[:47] + "..." | |
except Exception as e: | |
print(f"Error generating title with LLM: {e}") | |
# Fallback title | |
words = first_user_message.split()[:4] | |
title = " ".join(words) + ("..." if len(words) >= 4 else "") | |
# Save to session | |
if session_data: | |
session["generated_title"] = title | |
if not session.get("user_title"): | |
session["title"] = title # Only if no user title | |
redis_client.setex(f"session:{session_id}", 86400, json.dumps(session)) | |
return title | |
except Exception as e: | |
print(f"Error in generate_session_title: {e}") | |
return "New Chat" | |
def get_user_sessions(userLoginId: int) -> List[dict]: | |
"""Get all sessions for a user with generated titles""" | |
sessions = [] | |
# Scan for all session keys | |
for key in redis_client.scan_iter(match="session:*"): | |
session_data = redis_client.get(key) | |
if session_data: | |
session = json.loads(session_data) | |
if session["userLoginId"] == userLoginId: | |
# Generate title based on chat history | |
session["title"] = generate_session_title(session["session_id"]) | |
sessions.append(session) | |
# Sort sessions by created_at (most recent first) | |
sessions.sort(key=lambda x: x["created_at"], reverse=True) | |
return sessions | |
def delete_session(session_id: str): | |
"""Delete session and associated data""" | |
# Delete session data | |
redis_client.delete(f"session:{session_id}") | |
# Delete chat history | |
redis_client.delete(f"chat:{session_id}") | |
# Delete memory | |
redis_client.delete(f"memory:{session_id}") | |
# === UTILITY FUNCTIONS === | |
def get_encoded_auth_token(user: int, token: str) -> str: | |
auth_string = f"{user}:{token}" | |
return base64.b64encode(auth_string.encode("utf-8")).decode("utf-8") | |
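
# Worked example: get_encoded_auth_token(25, "secret-token") base64-encodes "25:secret-token"
# to "MjU6c2VjcmV0LXRva2Vu", which is sent as the HTTP Basic authorization value below.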
def fetch_user_projects(userLoginId: int, orgId: int, auth_token: str): | |
url = "https://japidemo.dev.ingenspark.com/fetchUserProjects" | |
payload = { | |
"userLoginId": userLoginId, | |
"orgId": orgId | |
} | |
headers = { | |
'accept': 'application/json, text/plain, */*', | |
'authorization': f'Basic {auth_token}', | |
'content-type': 'application/json; charset=UTF-8' | |
} | |
try: | |
response = requests.post(url, headers=headers, json=payload) | |
response.raise_for_status() | |
return response.json() | |
except requests.exceptions.RequestException as e: | |
raise HTTPException(status_code=response.status_code if 'response' in locals() else 500, | |
detail=str(e)) | |
def format_project_response(data: dict) -> str: | |
my_projects = data.get("data", {}).get("Myprojects", []) | |
other_projects = data.get("data", {}).get("Otherprojects", []) | |
all_projects = [] | |
for project in my_projects: | |
all_projects.append({ | |
"type": "Your Project", | |
"projectNm": project["projectNm"], | |
"projectId": project["projectId"], | |
"created_dttm": project["created_dttm"].split('.')[0], | |
"description": project["description"], | |
"categoryName": project["categoryName"] | |
}) | |
for project in other_projects: | |
all_projects.append({ | |
"type": "Other Project", | |
"projectNm": project["projectNm"], | |
"projectId": project["projectId"], | |
"created_dttm": project["created_dttm"].split('.')[0], | |
"description": project["description"], | |
"categoryName": project["categoryName"] | |
}) | |
if not all_projects: | |
return "β No projects found." | |
# Build the formatted string | |
result = [f"β You have access to {len(all_projects)} project(s):\n"] | |
for i, project in enumerate(all_projects, 1): | |
result.append(f"{i}. Project Name: {project['projectNm']} ({project['type']})") | |
result.append(f" Project ID: {project['projectId']}") | |
result.append(f" Created On: {project['created_dttm']}") | |
result.append(f" Description: {project['description']}") | |
result.append(f" Category: {project['categoryName']}\n") | |
return "\n".join(result) | |
# === TOOL FUNCTIONS === | |
# def search_documents(query: str) -> str: | |
# """Search through ingested documents and get relevant information.""" | |
# try: | |
# # Generate embedding for the query | |
# query_vector = embedding_model.embed_query(query) | |
# # Search in Qdrant | |
# search_result = qdrant_client.search( | |
# collection_name=QDRANT_COLLECTION_NAME, | |
# query_vector=query_vector, | |
# limit=5, | |
# ) | |
# if not search_result: | |
# return "No relevant information found in the knowledge base." | |
# # Convert results to text content | |
# context_texts = [] | |
# sources = [] | |
# for hit in search_result: | |
# context_texts.append(hit.payload["text"]) | |
# sources.append(hit.payload.get("source", "unknown")) | |
# # Create a simple prompt for answering based on context | |
# context = "\n\n".join(context_texts) | |
# unique_sources = list(set(sources)) | |
# # Use the LLM directly to answer the message based on context | |
# prompt = f"""Based on the following context, answer the message: {query} | |
# Context: | |
# {context} | |
# Please provide a comprehensive answer based on the context above. If the context doesn't contain enough information to answer the message, say so clearly.""" | |
# response = llm.invoke(prompt) | |
# return f"{response.content}\n\nSources: {', '.join(unique_sources)}" | |
# except Exception as e: | |
# return f"Error searching documents: {str(e)}" | |
def search_documents(query: str) -> str: | |
    """Search the hosted knowledge-base service and return the raw response text."""
    collection_name = 9  # Collection id used in the hosted search endpoint's URL path
    top_k = 5  # Number of results to return
url = f"https://srivatsavdamaraju-accusaga-bot.hf.space/search/{collection_name}" | |
params = { | |
"query": query, | |
"top_k": top_k | |
} | |
headers = { | |
"accept": "application/json" | |
} | |
response = requests.get(url, params=params, headers=headers) | |
if response.status_code == 200: | |
return response.text # or response.json() if you want to work with structured data | |
else: | |
return f"Error {response.status_code}: {response.text}" | |
# Global variables to store auth context (for tool functions) | |
_current_user_id = None | |
_current_org_id = None | |
_current_auth_token = None | |
def get_user_projects(userLoginId: str) -> str: | |
"""Get list of projects for a user.""" | |
try: | |
# Use global auth context if available | |
if _current_auth_token and _current_user_id: | |
user_id = _current_user_id | |
org_id = _current_org_id or 1 | |
auth_token = _current_auth_token | |
else: | |
return "β Authentication token required. Please provide auth_token in your request." | |
# Encode auth token using the actual user ID and provided token | |
encoded_token = get_encoded_auth_token(user_id, auth_token) | |
# Fetch projects | |
data = fetch_user_projects(user_id, org_id, encoded_token) | |
# Format and return the project list | |
formatted = format_project_response(data) | |
return formatted | |
except ValueError: | |
return "β Invalid userLoginId format. Please provide a valid number." | |
except Exception as e: | |
return f"β Error fetching projects: {str(e)}" | |
# def pandas_data_analysis(query_with_filepath: str) -> str: | |
# """ | |
# Tool for data analysis using PandasAI. | |
# Input format: 'filepath|query' where filepath is S3 path or ufuid, and query is the analysis question. | |
# """ | |
# try: | |
# # Parse the input to extract filepath and query | |
# parts = query_with_filepath.split('|', 1) | |
# if len(parts) != 2: | |
# return "β Invalid input format. Please use: 'filepath|query' format." | |
# filepath, query = parts | |
# filepath = filepath.strip() | |
# query = query.strip() | |
# if not filepath or not query: | |
# return "β Both filepath and query are required." | |
# # Use the pandas_agent function | |
# result = pandas_agent(filepath, query) | |
# return result | |
# except Exception as e: | |
# return f"β Error in pandas data analysis: {str(e)}" | |
def get_dataset_info(userLoginId: int, orgId: int, project_id: int, user: str, token: str): | |
""" | |
Fetch dataset info from the API. | |
""" | |
# Encode auth token | |
auth_token = get_encoded_auth_token(user, token) | |
url = f"https://papidemo.dev.ingenspark.com/get_dataset_info?user_login_id={userLoginId}&project_id={project_id}" | |
headers = { | |
'accept': 'application/json, text/plain, */*', | |
'authorization': f'Basic {auth_token}', | |
'content-type': 'application/json; charset=utf-8', | |
'origin': 'https://demo-app.dev.ingenspark.com', | |
'referer': 'https://demo-app.dev.ingenspark.com/', | |
'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/139.0.0.0 Safari/537.36', | |
} | |
try: | |
response = requests.get(url, headers=headers) | |
response.raise_for_status() | |
return response.json() | |
except requests.exceptions.RequestException as e: | |
return {"error": str(e)} | |
except ValueError: | |
return {"error": "Invalid JSON response", "text": response.text} | |
def check_and_create_user_collection(userLoginId: int) -> bool: | |
""" | |
Check if a collection named `userLoginId` exists. | |
If not, create the collection. | |
Returns True if collection exists or created successfully, False otherwise. | |
""" | |
try: | |
# Get all collections | |
collections = qdrant_client.get_collections() | |
collection_names = [col.name for col in collections.collections] | |
collection_name = str(userLoginId) | |
if collection_name in collection_names: | |
print(f"Collection '{collection_name}' already exists") | |
return True | |
else: | |
print(f"Creating new collection for user {userLoginId}...") | |
# Create collection with standard parameters | |
qdrant_client.recreate_collection( | |
collection_name=collection_name, | |
vectors_config=VectorParams(size=3072, distance=Distance.COSINE), | |
) | |
print(f"Collection '{collection_name}' created successfully") | |
return True | |
except Exception as e: | |
print(f"Error managing collection for user {userLoginId}: {str(e)}") | |
return False | |
def ingest_datasets_to_collection(collection_name: str, datasets_data: Dict[str, Any]) -> bool: | |
""" | |
Ingest datasets information to a user's collection. | |
""" | |
try: | |
# Convert datasets data to a formatted text for ingestion | |
datasets_text = json.dumps(datasets_data, indent=2, ensure_ascii=False) | |
# Create a temporary file with the datasets information | |
with tempfile.NamedTemporaryFile(mode='w', suffix='.txt', delete=False, encoding='utf-8') as tmp_file: | |
tmp_file.write(f"Dataset Information Summary\n") | |
tmp_file.write("=" * 50 + "\n\n") | |
tmp_file.write(datasets_text) | |
tmp_file_path = tmp_file.name | |
try: | |
# Load the temporary file | |
loader = TextLoader(tmp_file_path, encoding='utf-8') | |
docs = loader.load() | |
# Split into chunks | |
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100) | |
chunks = splitter.split_documents(docs) | |
texts = [chunk.page_content for chunk in chunks] | |
# Generate embeddings | |
embeddings = embedding_model.embed_documents(texts) | |
# Create points for Qdrant | |
points = [ | |
PointStruct( | |
id=str(uuid.uuid4()), | |
vector=embeddings[i], | |
payload={ | |
"text": texts[i], | |
"source": "dataset_info", | |
"type": "dataset_summary" | |
}, | |
) | |
for i in range(len(texts)) | |
] | |
# Upsert to Qdrant | |
qdrant_client.upsert(collection_name=collection_name, points=points) | |
print(f"Successfully ingested dataset information to collection '{collection_name}'") | |
return True | |
finally: | |
# Clean up temporary file | |
os.unlink(tmp_file_path) | |
except Exception as e: | |
print(f"Error ingesting datasets to collection {collection_name}: {str(e)}") | |
return False | |
def fetch_and_ingest_user_datasets(userLoginId: int, orgId: int, auth_token: str) -> Dict[str, Any]: | |
""" | |
Fetch all user projects and their datasets, then ingest to user's collection. | |
""" | |
try: | |
# Step 1: Ensure user collection exists | |
collection_created = check_and_create_user_collection(userLoginId) | |
if not collection_created: | |
return { | |
"success": False, | |
"message": "Failed to create/verify user collection", | |
"datasets": {} | |
} | |
# Step 2: Fetch user projects | |
encoded_token = get_encoded_auth_token(userLoginId, auth_token) | |
projects_data = fetch_user_projects(userLoginId, orgId, encoded_token) | |
# Step 3: Extract project IDs | |
project_ids = [] | |
for proj in projects_data.get("data", {}).get("Myprojects", []): | |
project_ids.append(proj["projectId"]) | |
for proj in projects_data.get("data", {}).get("Otherprojects", []): | |
project_ids.append(proj["projectId"]) | |
# Step 4: Fetch dataset info for each project | |
all_datasets = {} | |
for project_id in project_ids: | |
dataset_info = get_dataset_info(userLoginId, orgId, project_id, userLoginId, auth_token) | |
all_datasets[str(project_id)] = dataset_info | |
# Step 5: Ingest datasets to user's collection | |
ingestion_success = ingest_datasets_to_collection(str(userLoginId), all_datasets) | |
return { | |
"success": True, | |
"collection_name": str(userLoginId), | |
"projects_found": len(project_ids), | |
"datasets": all_datasets, | |
"ingestion_success": ingestion_success, | |
"message": f"Successfully processed {len(project_ids)} projects and {'ingested' if ingestion_success else 'failed to ingest'} dataset information" | |
} | |
except Exception as e: | |
return { | |
"success": False, | |
"message": f"Error processing user datasets: {str(e)}", | |
"datasets": {} | |
} | |
def get_user_datasets(userLoginId_str: str) -> str: | |
""" | |
Tool to fetch user datasets and ingest them into user's collection. | |
This tool automatically manages collections and dataset ingestion. | |
""" | |
try: | |
# Use global auth context | |
if not _current_auth_token or not _current_user_id or not _current_org_id: | |
return "Authentication context required. Please provide auth_token in your request." | |
userLoginId = int(userLoginId_str) if userLoginId_str.isdigit() else _current_user_id | |
orgId = _current_org_id | |
auth_token = _current_auth_token | |
# Fetch and process datasets | |
result = fetch_and_ingest_user_datasets(userLoginId, orgId, auth_token) | |
if result["success"]: | |
datasets_count = len(result["datasets"]) | |
return f"""β Dataset Management Complete: | |
π Found {result['projects_found']} projects with dataset information | |
π Collection '{result['collection_name']}' ready | |
πΎ Ingestion Status: {'Success' if result['ingestion_success'] else 'Failed'} | |
Dataset Summary: | |
{json.dumps(result['datasets'], indent=2) if datasets_count > 0 else 'No datasets found'} | |
You can now search through your datasets using document search queries!""" | |
else: | |
return f"β Error: {result['message']}" | |
except ValueError: | |
return "β Invalid userLoginId format. Please provide a valid number." | |
except Exception as e: | |
return f"β Error managing user datasets: {str(e)}" | |
# Pandas DataFrame agent (LangChain experimental); the other imports this block repeated
# (os, re, urllib.parse, psycopg2, pandas, ChatOpenAI, retrive_secrects, S3_Bucket_Name)
# are already available from the top of the file.
from langchain_experimental.agents import create_pandas_dataframe_agent
# NOTE: the pandas-based read_parquet_file_from_s3 below overrides the Dask-based
# version defined earlier in this module.
def read_parquet_file_from_s3(file_location): | |
""" | |
Reads a Parquet file from S3 using pandas and returns it as a DataFrame. | |
Args: | |
file_location (str): S3-relative path to the Parquet file. | |
Returns: | |
pd.DataFrame | |
""" | |
# Normalize and clean path | |
file_location = re.sub(r'\.parquet(?!$)', '', file_location) | |
s3_file_path = file_location if file_location.endswith('.parquet') else file_location + '.parquet' | |
# Extract relative S3 path | |
s3_file_path = s3_file_path.split(f"{S3_Bucket_Name}/")[-1] | |
s3_file_path = urllib.parse.unquote(s3_file_path) | |
if not s3_file_path.endswith('.parquet'): | |
s3_file_path += '.parquet' | |
s3_url = f"s3://{S3_Bucket_Name}/{s3_file_path}" | |
print(f"\nπΉ Reading from S3: {s3_url}\n") | |
# Read Parquet file using pandas | |
df = pd.read_parquet(s3_url, engine='pyarrow') | |
return df | |
def pandas_ai(input_text: str, api_key: str = None, model: str = "gpt-4") -> str: | |
""" | |
Parses the input string to extract the S3 path and user query, | |
reads the data, and answers the query using LLM. | |
Args: | |
input_text (str): Input in the format "S3_path , natural language question" | |
api_key (str): OpenAI API key (or read from env) | |
model (str): OpenAI model to use (default: gpt-4) | |
Returns: | |
str: Answer from the LLM | |
""" | |
try: | |
# Split input into S3 path and question | |
parts = input_text.split(",", 1) | |
if len(parts) != 2: | |
raise ValueError("Input must be in the format: <S3_path>, <question>") | |
file_path = parts[0].strip() | |
user_query = parts[1].strip() | |
# Get OpenAI key | |
openai_key = api_key or os.getenv("OPENAI_API_KEY") | |
if not openai_key: | |
raise ValueError("OpenAI API key must be provided or set in environment variable 'OPENAI_API_KEY'.") | |
# Read DataFrame from S3 | |
df = read_parquet_file_from_s3(file_location=file_path) | |
# Initialize OpenAI LLM | |
llm = ChatOpenAI( | |
temperature=0, | |
model=model, | |
openai_api_key=openai_key | |
) | |
# Create LangChain agent | |
agent_executor = create_pandas_dataframe_agent( | |
llm=llm, | |
df=df, | |
agent_type="tool-calling", | |
verbose=False, | |
handle_parsing_errors=True, | |
include_df_in_prompt=True, | |
number_of_head_rows=5, | |
allow_dangerous_code=True | |
) | |
# Ask the question | |
result = agent_executor.invoke({"input": user_query}) | |
return result["output"] | |
except Exception as e: | |
return f"β Error: {str(e)}" | |
# =============== Example Usage =============== | |
# if __name__ == "__main__": | |
# user_input = input("Enter your input (format: <S3_Path>, <Question>):\n") | |
# answer = pandas_ai(user_input) | |
# print("\nπ Answer:\n", answer) | |
# === CREATE TOOLS === | |
dataset_management_tool = Tool( | |
name="manage_user_datasets", | |
description="""Use this tool to automatically fetch user datasets and set up their personal collection. | |
This tool will: | |
1. Create a user-specific collection if it doesn't exist | |
2. Fetch all user projects and their dataset information | |
3. Ingest the dataset information into the user's collection for searching | |
Perfect for when users want to: | |
- Set up their dataset collection | |
- Refresh their dataset information | |
- Prepare their datasets for searching and analysis | |
Input should be the userLoginId (e.g., '25') or leave empty to use current user. | |
Note: This tool requires authentication context to be set.""", | |
func=get_user_datasets | |
) | |
document_search_tool = Tool( | |
name="document_search", | |
description="""Use this tool to search through ingested documents and get relevant information from the knowledge base. | |
Perfect for answering messages about uploaded documents, manuals, or any content that was previously stored. | |
Input should be a search query or message about the documents.""", | |
func=search_documents | |
) | |
project_list_tool = Tool( | |
name="get_user_projects", | |
description="""Use this tool to get the list of projects for a user. | |
Perfect for when users ask about their projects, want to see available projects, or need project information. | |
Input should be the userLoginId (e.g., '25'). | |
Note: This tool requires authentication context to be set.""", | |
func=get_user_projects | |
) | |
pandas_analysis_tool = Tool( | |
name="pandas_data_analysis", | |
description="""Use this tool for data analysis on CSV/Parquet files using PandasAI. | |
Perfect for when users ask questions about data analysis, statistics, insights, or want to query their datasets. | |
Input format: 'filepath, query' where: | |
- filepath: S3 file path (e.g., 'User-Uploaded-Raw-Files/Data2004csv1754926601269756') or ufuid (e.g., '123') | |
- query: Natural language question about the data (e.g., 'What are the top 5 values?', 'Show me summary statistics') | |
Examples: | |
- 'User-Uploaded-Raw-Files/mydata.csv, What is this file about?' | |
- '123, Show me the first 5 rows' | |
- 'Modified-Files/processed_data, What are the most common values in column X?' | |
""", | |
func=pandas_ai | |
) | |
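
# The tools can also be exercised directly when debugging (placeholder input shown);
# note that the project and dataset tools additionally need the global auth context set.
# print(document_search_tool.func("summarise the uploaded documentation"))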
# === AGENT SETUP === | |
# def create_agent_with_session_memory(session_id: str): | |
# """Create agent with session memory from Redis""" | |
# # Get memory from Redis | |
# memory_messages = get_session_memory(session_id) | |
# agent_prompt = ChatPromptTemplate.from_messages([ | |
# ("system", """You are a helpful AI assistant with access to multiple tools and conversation memory: | |
# 1. **Document Search**: Search through uploaded documents and knowledge base | |
# 2. **Project Management**: Get list of user projects and project information | |
# 3. **Data Analysis**: Analyze CSV/Parquet files using PandasAI for insights, statistics, and queries | |
# Your capabilities: | |
# - Answer messages about documents using the document search tool | |
# - Help users find their projects and project information | |
# - Perform data analysis on uploaded datasets using natural language queries | |
# - Remember previous conversations in this session | |
# - Provide general assistance and information | |
# - Use appropriate tools based on user queries | |
# Guidelines: | |
# - Use the document search tool when users ask about specific content, documentation, or information that might be in uploaded files | |
# - Use the project tool when users ask about projects, want to see their projects, or need project-related information | |
# - Use the pandas analysis tool when users ask about data analysis, statistics, insights, or want to query datasets | |
# - For pandas analysis, you need both a filepath (S3 path or ufuid) and a query - ask for missing information if needed | |
# - Reference previous conversation context when relevant | |
# - Be clear about which tool you're using and what information you're providing | |
# - If you're unsure which tool to use, you can ask for clarification | |
# - Provide helpful, accurate, and well-formatted responses | |
# Remember: Always use the most appropriate tool based on the user's message and conversation context to provide the best possible answer."""), | |
# MessagesPlaceholder(variable_name="chat_history"), | |
# ("user", "{input}"), | |
# MessagesPlaceholder(variable_name="agent_scratchpad"), | |
# ]) | |
# # Create memory object | |
# memory = ConversationBufferMemory( | |
# memory_key="chat_history", | |
# return_messages=True | |
# ) | |
# # Load existing messages into memory | |
# for msg in memory_messages: | |
# if msg["role"] == "user": | |
# memory.chat_memory.add_user_message(msg["message"]) | |
# else: | |
# memory.chat_memory.add_ai_message(msg["message"]) | |
# # Create tools list | |
# tools = [document_search_tool, project_list_tool, pandas_analysis_tool] | |
# # Create the agent | |
# agent = create_openai_tools_agent(llm, tools, agent_prompt) | |
# # Create the agent executor with memory | |
# agent_executor = AgentExecutor( | |
# agent=agent, | |
# tools=tools, | |
# verbose=True, | |
# memory=memory | |
# ) | |
# return agent_executor, memory | |
#Update the create_agent_with_session_memory function to include the new tool | |
def create_agent_with_session_memory(session_id: str): | |
"""Create agent with session memory from Redis - Updated with dataset management""" | |
# Get memory from Redis | |
memory_messages = get_session_memory(session_id) | |
agent_prompt = ChatPromptTemplate.from_messages([ | |
("system", """You are a helpful AI assistant with access to multiple tools and conversation memory: | |
1. **Document Search**: Search through uploaded documents and user's dataset knowledge base | |
2. **Project Management**: Get list of user projects and project information | |
3. **Data Analysis**: Analyze CSV/Parquet files using PandasAI for insights, statistics, and queries | |
4. **Dataset Management**: Automatically fetch and organize user datasets into searchable collections | |
Your capabilities: | |
- Answer questions about documents using the document search tool | |
- Help users find their projects and project information | |
- Perform data analysis on uploaded datasets using natural language queries | |
- Automatically manage user datasets and make them searchable | |
- Remember previous conversations in this session | |
- Provide general assistance and information | |
Guidelines: | |
- Use the document search tool when users ask about specific content, documentation, or dataset information | |
- Use the project tool when users ask about projects, want to see their projects, or need project-related information | |
- Use the pandas analysis tool when users ask about data analysis, statistics, insights, or want to query specific datasets | |
- Use the dataset management tool when users want to set up their datasets for searching, or refresh their dataset collection | |
- For pandas analysis, you need both a filepath (S3 path or ufuid) and a query - ask for missing information if needed | |
- The dataset management tool automatically creates user collections and ingests their dataset information | |
- Reference previous conversation context when relevant | |
- Be clear about which tool you're using and what information you're providing | |
- If you're unsure which tool to use, you can ask for clarification | |
- Provide helpful, accurate, and well-formatted responses | |
Dataset Management Flow: | |
1. When users first interact or ask about their datasets, suggest using dataset management to set up their collection | |
2. After dataset management completes, users can search their datasets using document search | |
3. For specific data analysis, direct them to use pandas analysis with specific file paths | |
Remember: Always use the most appropriate tool based on the user's query and conversation context to provide the best possible answer."""), | |
MessagesPlaceholder(variable_name="chat_history"), | |
("user", "{input}"), | |
MessagesPlaceholder(variable_name="agent_scratchpad"), | |
]) | |
# Create memory object | |
memory = ConversationBufferMemory( | |
memory_key="chat_history", | |
return_messages=True | |
) | |
# Load existing messages into memory | |
for msg in memory_messages: | |
if msg["role"] == "user": | |
memory.chat_memory.add_user_message(msg["message"]) | |
else: | |
memory.chat_memory.add_ai_message(msg["message"]) | |
# Create tools list - Updated with dataset management tool | |
tools = [document_search_tool, project_list_tool, pandas_analysis_tool, dataset_management_tool] | |
# Create the agent | |
agent = create_openai_tools_agent(llm, tools, agent_prompt) | |
# Create the agent executor with memory | |
agent_executor = AgentExecutor( | |
agent=agent, | |
tools=tools, | |
verbose=True, | |
memory=memory | |
) | |
return agent_executor, memory | |
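
# Illustrative use (session id is a placeholder; assumes the session exists in Redis):
# executor, memory = create_agent_with_session_memory("123e4567-e89b-12d3-a456-426614174000")
# print(executor.invoke({"input": "List my projects"})["output"])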
# ------------------- COLLECTION CRUD ENDPOINTS ------------------- | |
def create_collection(req: CollectionRequest): | |
"""Create a new Qdrant collection""" | |
distance_map = { | |
"Cosine": Distance.COSINE, | |
"Euclid": Distance.EUCLID, | |
"Dot": Distance.DOT, | |
} | |
if req.distance not in distance_map: | |
raise HTTPException(status_code=400, detail="Invalid distance metric") | |
try: | |
qdrant_client.recreate_collection( | |
collection_name=req.name, | |
vectors_config=VectorParams(size=req.vector_size, distance=distance_map[req.distance]), | |
) | |
return {"message": f"β Collection '{req.name}' created successfully"} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) | |
def list_collections(): | |
"""List all Qdrant collections""" | |
try: | |
collections = qdrant_client.get_collections() | |
return collections.dict() | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) | |
def get_collection(name: str): | |
"""Get information about a specific collection""" | |
try: | |
collection_info = qdrant_client.get_collection(collection_name=name) | |
return collection_info.dict() | |
except Exception as e: | |
raise HTTPException(status_code=404, detail=f"Collection '{name}' not found: {str(e)}") | |
def update_collection(name: str, req: UpdateCollectionRequest): | |
"""Update a collection's configuration""" | |
distance_map = { | |
"Cosine": Distance.COSINE, | |
"Euclid": Distance.EUCLID, | |
"Dot": Distance.DOT, | |
} | |
try: | |
current = qdrant_client.get_collection(name) | |
vector_size = req.vector_size if req.vector_size else current.config.params.vectors.size | |
distance = distance_map[req.distance] if req.distance else current.config.params.vectors.distance | |
qdrant_client.recreate_collection( | |
collection_name=name, | |
vectors_config=VectorParams(size=vector_size, distance=distance), | |
) | |
return {"message": f"β»οΈ Collection '{name}' updated successfully"} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=str(e)) | |
def delete_collection(name: str): | |
"""Delete a collection""" | |
try: | |
qdrant_client.delete_collection(collection_name=name) | |
return {"message": f"ποΈ Collection '{name}' deleted successfully"} | |
except Exception as e: | |
raise HTTPException(status_code=404, detail=f"Collection '{name}' not found: {str(e)}") | |
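
# Sketch of exposing the CRUD helpers above as routes. The paths are assumptions for
# illustration (they mirror the /collections/ endpoint this file calls on the hosted
# service); the actual registration may live elsewhere.
# app.post("/collections/")(create_collection)
# app.get("/collections/")(list_collections)
# app.get("/collections/{name}")(get_collection)
# app.put("/collections/{name}")(update_collection)
# app.delete("/collections/{name}")(delete_collection)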
# ------------------- INGESTION ENDPOINTS ------------------- | |
async def ingest_file(collection_name: str, file: UploadFile = File(...)): | |
"""Ingest a file into a Qdrant collection""" | |
suffix = os.path.splitext(file.filename)[-1].lower() | |
with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp: | |
tmp.write(await file.read()) | |
tmp_path = tmp.name | |
try: | |
# Select loader based on file suffix | |
if suffix == ".pdf": | |
loader = PyPDFLoader(tmp_path) | |
elif suffix in [".txt", ".md"]: | |
loader = TextLoader(tmp_path) | |
elif suffix == ".csv": | |
loader = CSVLoader(file_path=tmp_path) | |
elif suffix == ".docx": | |
loader = Docx2txtLoader(tmp_path) | |
elif suffix == ".html": | |
loader = BSHTMLLoader(file_path=tmp_path) | |
else: | |
            raise HTTPException(status_code=400, detail=f"❌ Unsupported file type: {suffix}")
docs = loader.load() | |
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50) | |
chunks = splitter.split_documents(docs) | |
texts = [chunk.page_content for chunk in chunks] | |
# Embed documents synchronously (OpenAIEmbeddings is sync) | |
embeddings = embedding_model.embed_documents(texts) | |
# Verify embedding dimension matches collection config | |
collection_info = qdrant_client.get_collection(collection_name=collection_name) | |
expected_dim = collection_info.config.params.vectors.size | |
if len(embeddings[0]) != expected_dim: | |
raise HTTPException( | |
status_code=400, | |
detail=f"Embedding dimension mismatch: expected {expected_dim}, got {len(embeddings[0])}", | |
) | |
points = [ | |
PointStruct( | |
id=str(uuid.uuid4()), | |
vector=embeddings[i], | |
payload={"text": texts[i], "source": file.filename}, | |
) | |
for i in range(len(texts)) | |
] | |
qdrant_client.upsert(collection_name=collection_name, points=points) | |
except HTTPException as he: | |
raise he # re-raise HTTP exceptions directly | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Ingestion failed: {str(e)}") | |
finally: | |
os.remove(tmp_path) | |
return {"message": f"π '{file.filename}' ingested into '{collection_name}' successfully"} | |
def search_collection( | |
collection_name: str, | |
query: str = QueryParam(..., description="Your question or search query"), | |
top_k: int = 5 | |
): | |
"""Search within a specific collection""" | |
try: | |
# Generate embedding for the query | |
query_vector = embedding_model.embed_query(query) | |
# Perform similarity search in Qdrant | |
search_result = qdrant_client.search( | |
collection_name=collection_name, | |
query_vector=query_vector, | |
limit=top_k, | |
) | |
# Format results | |
results = [ | |
{ | |
"score": hit.score, | |
"payload": hit.payload, | |
} | |
for hit in search_result | |
] | |
return { | |
"query": query, | |
"collection": collection_name, | |
"results": results, | |
} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Search failed: {str(e)}") | |
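
# Route sketch for ingestion and search (paths mirror the /ingest/{collection} and
# /search/{collection} URLs this file calls on the hosted service; shown only as an
# illustration of how these handlers would be mounted).
# app.post("/ingest/{collection_name}")(ingest_file)
# app.get("/search/{collection_name}")(search_collection)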
# === SESSION MANAGEMENT ENDPOINTS === | |
def create_new_session(userLoginId: int, orgId: int, auth_token: str): | |
"""Create a new chat session""" | |
try: | |
session_data = create_session(userLoginId, orgId, auth_token) | |
return SessionResponse(**session_data) | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error creating session: {str(e)}") | |
def list_user_sessions(userLoginId: int): | |
"""List all sessions for a user""" | |
try: | |
sessions = get_user_sessions(userLoginId) | |
return { | |
"userLoginId": userLoginId, | |
"total_sessions": len(sessions), | |
"sessions": sessions | |
} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error fetching sessions: {str(e)}") | |
def delete_user_session(session_id: str): | |
"""Delete/close a session""" | |
try: | |
# Verify session exists | |
get_session(session_id) | |
# Delete session | |
delete_session(session_id) | |
return { | |
"message": f"Session {session_id} deleted successfully", | |
"session_id": session_id | |
} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error deleting session: {str(e)}") | |
def get_session_history(session_id: str, n: int = QueryParam(50, description="Number of recent messages to return")): | |
"""Get chat history for a session""" | |
try: | |
# Verify session exists | |
get_session(session_id) | |
# Get chat history | |
chat_data = redis_client.get(f"chat:{session_id}") | |
if not chat_data: | |
return ChatHistoryResponse( | |
session_id=session_id, | |
messages=[], | |
total_messages=0 | |
) | |
messages = json.loads(chat_data) | |
# Get the last n messages (or all if less than n) | |
recent_messages = messages[-n:] if len(messages) > n else messages | |
# Convert to MessageResponse objects | |
message_responses = [MessageResponse(**msg) for msg in recent_messages] | |
return ChatHistoryResponse( | |
session_id=session_id, | |
messages=message_responses, | |
total_messages=len(messages) | |
) | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error fetching chat history: {str(e)}") | |
def update_session_title_endpoint(session_id: str, request: UpdateSessionTitleRequest): | |
"""Update the user-defined title of an existing session""" | |
try: | |
session_data = redis_client.get(f"session:{session_id}") | |
if not session_data: | |
raise HTTPException(status_code=404, detail="Session not found or expired") | |
session = json.loads(session_data) | |
new_title = request.new_title.strip() | |
if not new_title: | |
raise HTTPException(status_code=400, detail="New title cannot be empty") | |
if len(new_title) > 100: | |
raise HTTPException(status_code=400, detail="Title cannot exceed 100 characters") | |
old_title = session.get("title", "New Chat") | |
session["user_title"] = new_title | |
session["title"] = new_title # Effective title = user-defined | |
session["last_updated"] = datetime.now().isoformat() | |
redis_client.setex(f"session:{session_id}", 86400, json.dumps(session)) | |
return { | |
"message": "Session title updated successfully", | |
"session_id": session_id, | |
"old_title": old_title, | |
"new_title": new_title | |
} | |
except HTTPException: | |
raise | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error updating session title: {str(e)}") | |
def refresh_session_title(session_id: str): | |
"""Manually refresh/regenerate session title""" | |
try: | |
# Verify session exists | |
session_data = get_session(session_id) | |
# Generate new title | |
new_title = generate_session_title(session_id) | |
# Update session | |
session_data["title"] = new_title | |
redis_client.setex( | |
f"session:{session_id}", | |
86400, # 24 hours | |
json.dumps(session_data) | |
) | |
return { | |
"session_id": session_id, | |
"new_title": new_title, | |
"message": "Session title updated successfully" | |
} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error updating session title: {str(e)}") | |
#_____________data ingestion___________________________
def get_dataset_info(userLoginId: int, orgId: int, project_id: int, user: str, token: str): | |
auth_token = get_encoded_auth_token(user, token) | |
url = f"https://papidemo.dev.ingenspark.com/get_dataset_info?user_login_id={userLoginId}&project_id={project_id}" | |
headers = { | |
'accept': 'application/json, text/plain, */*', | |
'authorization': f'Basic {auth_token}', | |
'content-type': 'application/json; charset=utf-8', | |
'origin': 'https://demo-app.dev.ingenspark.com', | |
'referer': 'https://demo-app.dev.ingenspark.com/', | |
'user-agent': 'Mozilla/5.0' | |
} | |
try: | |
response = requests.get(url, headers=headers) | |
response.raise_for_status() | |
return response.json() | |
except requests.exceptions.RequestException as e: | |
return {"error": str(e)} | |
except ValueError: | |
return {"error": "Invalid JSON response", "text": response.text} | |
def fetch_user_projects(userLoginId: int, orgId: int, auth_token: str): | |
url = "https://japidemo.dev.ingenspark.com/fetchUserProjects" | |
payload = {"userLoginId": userLoginId, "orgId": orgId} | |
headers = { | |
'accept': 'application/json, text/plain, */*', | |
'authorization': f'Basic {auth_token}', | |
'content-type': 'application/json; charset=UTF-8' | |
} | |
try: | |
response = requests.post(url, headers=headers, json=payload) | |
response.raise_for_status() | |
return response.json() | |
except requests.exceptions.RequestException as e: | |
raise HTTPException(status_code=response.status_code if 'response' in locals() else 500, detail=str(e)) | |
def format_project_response(data: dict) -> str: | |
my_projects = data.get("data", {}).get("Myprojects", []) | |
other_projects = data.get("data", {}).get("Otherprojects", []) | |
all_projects = [] | |
for project in my_projects + other_projects: | |
all_projects.append({ | |
"type": "Your Project" if project in my_projects else "Other Project", | |
"projectNm": project["projectNm"], | |
"projectId": project["projectId"], | |
"created_dttm": project["created_dttm"].split('.')[0], | |
"description": project["description"], | |
"categoryName": project["categoryName"] | |
}) | |
if not all_projects: | |
return "β No projects found." | |
result = [f"β You have access to {len(all_projects)} project(s):\n"] | |
for i, project in enumerate(all_projects, 1): | |
result.append(f"{i}. Project Name: {project['projectNm']} ({project['type']})") | |
result.append(f" Project ID: {project['projectId']}") | |
result.append(f" Created On: {project['created_dttm']}") | |
result.append(f" Description: {project['description']}") | |
result.append(f" Category: {project['categoryName']}\n") | |
return "\n".join(result) | |
def save_to_txt(data: dict, filename: str = "datasets_summary.txt"): | |
with open(filename, "w", encoding="utf-8") as f: | |
json.dump(data, f, indent=4, ensure_ascii=False) | |
print(f"β Dataset info saved to {filename}") | |
def check_and_create_collection(userLoginId: str, base_url="https://srivatsavdamaraju-accusaga-bot.hf.space") -> bool: | |
get_url = f"{base_url}/collections/" | |
headers = {'accept': 'application/json'} | |
try: | |
response = requests.get(get_url, headers=headers) | |
response.raise_for_status() | |
data = response.json() | |
collections = data if isinstance(data, list) else data.get("collections", []) | |
collection_names = [coll.get("name") for coll in collections if isinstance(coll, dict)] | |
if str(userLoginId) in collection_names: | |
print(f"Collection named '{userLoginId}' found.") | |
return True | |
else: | |
print("Collection not found. Creating a new one...") | |
post_data = { | |
"name": str(userLoginId), | |
"vector_size": 3072, | |
"distance": "Cosine" | |
} | |
post_response = requests.post(get_url, headers={ | |
'accept': 'application/json', | |
'Content-Type': 'application/json' | |
}, json=post_data) | |
post_response.raise_for_status() | |
print(f"β Collection created: {post_response.json()}") | |
return True | |
except requests.exceptions.RequestException as e: | |
print(f"Error calling collection API: {e}") | |
return False | |
def ingest_file_to_collection(collection_name: str, file_path: str, base_url="https://srivatsavdamaraju-accusaga-bot.hf.space") -> bool:
    """Upload a local file to the remote ingestion endpoint for the given collection; returns True on success."""
    url = f"{base_url}/ingest/{collection_name}"
    headers = {'accept': 'application/json'}
    try:
        with open(file_path, 'rb') as f:
            # Send only the file's basename rather than the full local path
            files = {'file': (os.path.basename(file_path), f, 'text/plain')}
response = requests.post(url, headers=headers, files=files) | |
response.raise_for_status() | |
print(f"β File '{file_path}' ingested into '{collection_name}'.") | |
print("Response:", response.json()) | |
return True | |
    except FileNotFoundError:
        print(f"File not found: {file_path}")
    except requests.exceptions.HTTPError as http_err:
        print(f"HTTP error: {http_err}")
        print("Response content:", response.text)
    except requests.exceptions.RequestException as e:
        print(f"Request exception: {e}")
return False | |
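# Standalone usage sketch mirroring the flow in chat_with_bot below (the id is a placeholder):
#   if check_and_create_collection("12345"):
#       ingest_file_to_collection("12345", "datasets_summary.txt")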
# === MAIN CHAT AND AGENT ENDPOINTS === | |
def chat_with_bot(query: BotQuery): | |
"""Main bot endpoint with session management and agent tools""" | |
try: | |
# Set global auth context for tools | |
global _current_user_id, _current_org_id, _current_auth_token | |
_current_user_id = query.userLoginId | |
_current_org_id = query.orgId | |
_current_auth_token = query.auth_token | |
session_id = query.session_id | |
# Create new session if not provided | |
if not session_id: | |
session_data = create_session(query.userLoginId, query.orgId, query.auth_token) | |
session_id = session_data["session_id"] | |
else: | |
# Verify existing session | |
get_session(session_id) | |
file_path = "datasets_summary.txt" # The file created earlier with dataset info | |
# Step 1: Check/create collection | |
success = check_and_create_collection(_current_user_id) | |
# Step 2: If collection ready, ingest the file | |
# Only ingest if user has <= 1 session | |
if success: | |
if should_ingest_data(_current_user_id): | |
print("User has 1 or fewer sessions. Ingesting data...") | |
ingest_file_to_collection(_current_user_id, file_path) | |
else: | |
print("User has more than 1 session. Skipping ingestion.") | |
else: | |
print("Could not create or find the collection. Aborting ingestion.") | |
# Add user message to session | |
user_message_id = add_message_to_session(session_id, "user", query.message) | |
# Create agent with session memory | |
agent_executor, memory = create_agent_with_session_memory(session_id) | |
# Use the agent to process the query | |
result = agent_executor.invoke({"input": query.message}) | |
# Add AI response to session | |
ai_message_id = add_message_to_session(session_id, "assistant", result["output"]) | |
# Update session memory in Redis | |
updated_messages = [] | |
for message in memory.chat_memory.messages: | |
if hasattr(message, 'content'): | |
role = "user" if message.__class__.__name__ == "HumanMessage" else "assistant" | |
updated_messages.append({ | |
"role": role, | |
"message": message.content, | |
"timestamp": datetime.now().isoformat() | |
}) | |
update_session_memory(session_id, updated_messages) | |
# Update session title after first user message | |
update_session_title(session_id) | |
# Clear auth context after use | |
_current_user_id = None | |
_current_org_id = None | |
_current_auth_token = None | |
return { | |
"session_id": session_id, | |
"user_message_id": user_message_id, | |
"ai_message_id": ai_message_id, | |
"message": query.message, | |
"answer": result["output"], | |
"userLoginId": query.userLoginId, | |
"agent_used": True | |
} | |
except Exception as e: | |
# Clear auth context on error | |
_current_user_id = None | |
_current_org_id = None | |
_current_auth_token = None | |
raise HTTPException(status_code=500, detail=f"Error processing chat: {str(e)}") | |
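# Illustrative request body for chat_with_bot (field names taken from the BotQuery attributes
# used above; the values are placeholders, and session_id may be null to start a new session):
#   {"userLoginId": 12345, "orgId": 1, "auth_token": "<api-token>",
#    "session_id": null, "message": "List my projects"}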
# === DIRECT TOOL ENDPOINTS === | |
def chat_documents_only(query: Query): | |
"""Direct document search without agent""" | |
try: | |
result = search_documents(query.message) | |
return { | |
"message": query.message, | |
"answer": result, | |
"tool_used": "document_search" | |
} | |
except Exception as e: | |
return { | |
"message": query.message, | |
"answer": f"An error occurred: {str(e)}", | |
"tool_used": "document_search" | |
} | |
def list_projects(request: ProjectRequest): | |
"""Direct project listing without agent""" | |
try: | |
# Use the provided auth token and userLoginId | |
encoded_token = get_encoded_auth_token(request.userLoginId, request.auth_token) | |
# Fetch projects | |
data = fetch_user_projects(request.userLoginId, request.orgId, encoded_token) | |
# Format and return the project list | |
formatted = format_project_response(data) | |
return { | |
"projects": formatted, | |
"tool_used": "project_list" | |
} | |
except Exception as e: | |
return { | |
"error": f"An error occurred: {str(e)}", | |
"tool_used": "project_list" | |
} | |
def chat_with_pandas_agent(request: PandasAgentQuery): | |
"""Direct pandas AI agent endpoint for data analysis""" | |
try: | |
result = pandas_agent(request.filepath, request.query) | |
return { | |
"filepath": request.filepath, | |
"query": request.query, | |
"answer": result, | |
"tool_used": "pandas_agent", | |
"timestamp": datetime.now().isoformat() | |
} | |
except Exception as e: | |
error_msg = f"An error occurred: {str(e)}" | |
return { | |
"filepath": request.filepath, | |
"query": request.query, | |
"answer": error_msg, | |
"tool_used": "pandas_agent", | |
"error": True, | |
"timestamp": datetime.now().isoformat() | |
} | |
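# Illustrative request body (field names taken from the PandasAgentQuery attributes used above;
# the path and question are placeholders):
#   {"filepath": "path/to/sales.csv", "query": "What is the average revenue per month?"}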
def delete_user_completely(user_login_id: int):
    """Delete a user's vector collection and all of their chat sessions via the remote service API."""
    BASE_URL = "https://srivatsavdamaraju-accusaga-bot.hf.space"
headers = { | |
"accept": "application/json" | |
} | |
# Step 1: Delete Collection | |
collection_url = f"{BASE_URL}/collections/{user_login_id}" | |
collection_response = requests.delete(collection_url, headers=headers) | |
if collection_response.status_code != 200: | |
raise HTTPException( | |
status_code=collection_response.status_code, | |
detail=f"Failed to delete collection. Response: {collection_response.text}" | |
) | |
# Step 2: Get Sessions | |
sessions_url = f"{BASE_URL}/sessions?userLoginId={user_login_id}" | |
sessions_response = requests.get(sessions_url, headers=headers) | |
if sessions_response.status_code != 200: | |
raise HTTPException( | |
status_code=sessions_response.status_code, | |
detail=f"Failed to fetch sessions. Response: {sessions_response.text}" | |
) | |
sessions_data = sessions_response.json() | |
sessions = sessions_data.get("sessions", []) | |
deleted_sessions = [] | |
failed_sessions = [] | |
# Step 3: Delete Each Session | |
for session in sessions: | |
session_id = session["session_id"] | |
delete_session_url = f"{BASE_URL}/sessions/{session_id}" | |
delete_session_response = requests.delete(delete_session_url, headers=headers) | |
if delete_session_response.status_code == 200: | |
deleted_sessions.append(session_id) | |
else: | |
failed_sessions.append({ | |
"session_id": session_id, | |
"status_code": delete_session_response.status_code, | |
"error": delete_session_response.text | |
}) | |
return { | |
"user_login_id": user_login_id, | |
"collection_deleted": True, | |
"deleted_sessions": deleted_sessions, | |
"failed_sessions": failed_sessions | |
} | |
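# Illustrative return value (ids are placeholders):
#   {"user_login_id": 12345, "collection_deleted": True,
#    "deleted_sessions": ["a1b2c3d4"], "failed_sessions": []}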
# === SYSTEM INFORMATION ENDPOINTS === | |
def redis_info(): | |
"""Get Redis connection information""" | |
try: | |
info = redis_client.info() | |
return { | |
"redis_connected": True, | |
"redis_version": info.get("redis_version"), | |
"used_memory": info.get("used_memory_human"), | |
"connected_clients": info.get("connected_clients"), | |
"total_keys": redis_client.dbsize() | |
} | |
except Exception as e: | |
return { | |
"redis_connected": False, | |
"error": str(e) | |
} | |
def qdrant_info(): | |
"""Get Qdrant connection information""" | |
try: | |
collections = qdrant_client.get_collections() | |
return { | |
"qdrant_connected": True, | |
"total_collections": len(collections.collections), | |
"collections": [col.name for col in collections.collections] | |
} | |
except Exception as e: | |
return { | |
"qdrant_connected": False, | |
"error": str(e) | |
} | |
def fetch_dataset_info_endpoint(request: DatasetInfoRequest): | |
"""Direct endpoint to fetch dataset info for a specific project""" | |
try: | |
dataset_info = get_dataset_info( | |
request.userLoginId, | |
request.orgId, | |
request.project_id, | |
request.userLoginId, | |
request.auth_token | |
) | |
return DatasetInfoResponse( | |
project_id=request.project_id, | |
dataset_info=dataset_info | |
) | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error fetching dataset info: {str(e)}") | |
def setup_user_datasets_endpoint(request: ProjectRequest): | |
"""Direct endpoint to set up user datasets and collection""" | |
try: | |
result = fetch_and_ingest_user_datasets( | |
request.userLoginId, | |
request.orgId, | |
request.auth_token | |
) | |
return { | |
"userLoginId": request.userLoginId, | |
"collection_name": str(request.userLoginId), | |
**result | |
} | |
except Exception as e: | |
raise HTTPException(status_code=500, detail=f"Error setting up user datasets: {str(e)}") | |
def health():
    """System health check for Redis, Qdrant, and the agent tool stack"""
    try:
        redis_client.ping()
        redis_status = "connected"
    except Exception:
        redis_status = "disconnected"
    try:
        qdrant_client.get_collections()
        qdrant_status = "connected"
    except Exception:
        qdrant_status = "disconnected"
return { | |
"status": "ok", | |
"tools": ["document_search", "project_list", "pandas_data_analysis", "dataset_management"], | |
"agent": "active", | |
"session_management": "enabled", | |
"dataset_management": "enabled", | |
"redis_status": redis_status, | |
"qdrant_status": qdrant_status, | |
"pandas_ai": "enabled", | |
"total_sessions": len(list(redis_client.scan_iter(match="session:*"))) if redis_status == "connected" else 0, | |
"collections_available": qdrant_status == "connected" | |
} | |
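# Route wiring sketch: the handlers above carry no decorators in this part of the file. If they
# are not registered elsewhere, explicit wiring could look like the commented lines below
# (the paths are illustrative assumptions, not confirmed URLs):
# app.post("/chat")(chat_with_bot)
# app.post("/projects")(list_projects)
# app.post("/pandas-agent")(chat_with_pandas_agent)
# app.get("/health")(health)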
if __name__ == "__main__":
    import uvicorn
    try:
        uvicorn.run(app)
    except KeyboardInterrupt:
        print("\nServer stopped gracefully")
    except Exception as e:
        print(f"Server error: {e}")
# bot10.py
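# To serve from the command line instead of the __main__ block (module name taken from the
# marker above; host/port are typical defaults, adjust for your deployment):
#   uvicorn bot10:app --host 0.0.0.0 --port 8000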