import os
import json
import time
import logging
import warnings

import gradio as gr

from typing import Any, Dict, List

warnings.filterwarnings("ignore")

from langchain_neo4j import Neo4jVector, Neo4jChatMessageHistory, Neo4jGraph
from langchain_neo4j import GraphCypherQAChain
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage, AIMessage
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


CHAT_VECTOR_MODE = "vector"
CHAT_GRAPH_MODE = "graph"
CHAT_FULLTEXT_MODE = "fulltext"
CHAT_HYBRID_MODE = "hybrid"
CHAT_COMPREHENSIVE_MODE = "comprehensive"


class Neo4jRAGSystem:
    def __init__(self, openai_api_key: str = None):
        self.api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        if not self.api_key:
            raise ValueError("OpenAI API key is required!")

        try:
            # Connection details fall back to the original hardcoded values;
            # prefer supplying them via environment variables.
            self.graph = Neo4jGraph(
                url=os.getenv("NEO4J_URI", "neo4j+s://418790fe.databases.neo4j.io"),
                username=os.getenv("NEO4J_USERNAME", "neo4j"),
                password=os.getenv("NEO4J_PASSWORD", "nL2O2TcOySTueabSbk4scu64DJxLYgybyWTQjBHtYnY")
            )
            logger.info("✅ Neo4j connection established")
        except Exception as e:
            logger.error(f"❌ Neo4j connection failed: {e}")
            raise
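
        # Configuration sketch (assumption: the NEO4J_* variable names above
        # are this script's convention, not something the library requires):
        #
        #   export NEO4J_URI="neo4j+s://<your-instance>.databases.neo4j.io"
        #   export NEO4J_USERNAME="neo4j"
        #   export NEO4J_PASSWORD="<your-password>"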

        try:
            self.llm = ChatOpenAI(
                model="gpt-3.5-turbo",
                temperature=0.1,
                openai_api_key=self.api_key
            )

            self.embeddings = OpenAIEmbeddings(
                openai_api_key=self.api_key
            )
            logger.info("✅ OpenAI models initialized")
        except Exception as e:
            logger.error(f"❌ OpenAI initialization failed: {e}")
            raise

        self.vector_store = None
        self.setup_vector_store()

        try:
            self.graph_qa = GraphCypherQAChain.from_llm(
                llm=self.llm,
                graph=self.graph,
                verbose=True
            )
            logger.info("✅ Graph QA chain initialized")
        except Exception as e:
            logger.error(f"❌ Graph QA initialization failed: {e}")
            self.graph_qa = None

    def setup_vector_store(self):
        """Set up the vector store, handling embedding-dimension mismatches"""
        try:
            result = self.graph.query("""
                SHOW INDEXES YIELD name, labelsOrTypes, properties, type
                WHERE type = "VECTOR"
                RETURN name, labelsOrTypes, properties
            """)

            vector_index_name = "vector"

            if result:
                logger.info(f"Found existing vector indexes: {result}")
                try:
                    self.graph.query(f"DROP INDEX {vector_index_name} IF EXISTS")
                    logger.info(f"Dropped existing vector index: {vector_index_name}")
                except Exception as e:
                    logger.warning(f"Could not drop index {vector_index_name}: {e}")

            self.vector_store = Neo4jVector.from_existing_graph(
                embedding=self.embeddings,
                graph=self.graph,
                node_label="Document",
                text_node_properties=["content", "text", "title"],
                embedding_node_property="embedding_openai",
                index_name="openai_vector_index",
            )
            logger.info("✅ Vector store initialized successfully")

        except Exception as e:
            logger.error(f"❌ Vector store initialization failed: {e}")
            self.vector_store = None
            logger.info("Continuing without vector search capabilities")

    def get_chat_mode_settings(self, mode: str) -> Dict[str, Any]:
        """Get settings for different chat modes"""
        modes = {
            CHAT_VECTOR_MODE: {
                "mode": "vector",
                "document_filter": True,
                "use_vector_search": True,
                "use_graph_search": False,
                "use_fulltext_search": False,
                "description": "Vector similarity search using embeddings"
            },
            CHAT_GRAPH_MODE: {
                "mode": "graph",
                "document_filter": False,
                "use_vector_search": False,
                "use_graph_search": True,
                "use_fulltext_search": False,
                "description": "Graph-based knowledge retrieval using relationships"
            },
            CHAT_FULLTEXT_MODE: {
                "mode": "fulltext",
                "document_filter": True,
                "use_vector_search": False,
                "use_graph_search": False,
                "use_fulltext_search": True,
                "description": "Full-text keyword search"
            },
            CHAT_HYBRID_MODE: {
                "mode": "hybrid",
                "document_filter": True,
                "use_vector_search": True,
                "use_graph_search": True,
                "use_fulltext_search": False,
                "description": "Combined vector and graph search"
            },
            CHAT_COMPREHENSIVE_MODE: {
                "mode": "comprehensive",
                "document_filter": True,
                "use_vector_search": True,
                "use_graph_search": True,
                "use_fulltext_search": True,
                "description": "Graph+Vector+Fulltext - complete search using all methods"
            }
        }
        return modes.get(mode, modes[CHAT_GRAPH_MODE])

    def create_neo4j_chat_message_history(self, session_id: str, write_access: bool = True):
        """Create Neo4j chat message history"""
        try:
            return Neo4jChatMessageHistory(
                graph=self.graph,
                session_id=session_id,
                window=10
            )
        except Exception as e:
            logger.error(f"Error creating chat history: {e}")
            return ChatMessageHistory()

    def process_graph_response(self, question: str, messages: List, history, session_id: str):
        """Process graph-based response"""
        start_time = time.time()

        try:
            if not self.graph_qa:
                return {
                    "session_id": session_id,
                    "message": "Graph QA is not available. Please check your Neo4j setup.",
                    "info": {
                        "sources": [],
                        "model": "gpt-3.5-turbo",
                        "response_time": 0,
                        "mode": "graph",
                    },
                    "user": "chatbot"
                }

            response = self.graph_qa.invoke({"query": question})

            ai_message = AIMessage(content=response["result"])
            history.add_message(ai_message)

            response_time = time.time() - start_time

            return {
                "session_id": session_id,
                "message": response["result"],
                "info": {
                    "sources": [],
                    "model": "gpt-3.5-turbo",
                    "nodedetails": [],
                    "total_tokens": 0,
                    "response_time": response_time,
                    "mode": "graph",
                    "entities": [],
                    "metric_details": [],
                    "cypher_query": response.get("intermediate_steps", [])
                },
                "user": "chatbot"
            }
        except Exception as e:
            logger.error(f"Error in graph response: {e}")
            return {
                "session_id": session_id,
                "message": f"I encountered an issue processing your question. Let me try a direct approach: {question}",
                "info": {
                    "sources": [],
                    "model": "gpt-3.5-turbo",
                    "response_time": 0,
                    "mode": "graph",
                },
                "user": "chatbot"
            }

    def fulltext_search(self, question: str, document_names: List[str] = None, limit: int = 5):
        """Perform fulltext search on Neo4j"""
        try:
            self.ensure_fulltext_index()

            if document_names:
                search_query = """
                    CALL db.index.fulltext.queryNodes('fulltext_content', $question)
                    YIELD node, score
                    WHERE any(doc IN $doc_names WHERE node.source CONTAINS doc)
                    RETURN node.content as content, node.text as text, node.title as title,
                           node.source as source, score, labels(node) as labels
                    ORDER BY score DESC
                    LIMIT $limit
                """
                params = {"question": question, "doc_names": document_names, "limit": limit}
            else:
                search_query = """
                    CALL db.index.fulltext.queryNodes('fulltext_content', $question)
                    YIELD node, score
                    RETURN node.content as content, node.text as text, node.title as title,
                           node.source as source, score, labels(node) as labels
                    ORDER BY score DESC
                    LIMIT $limit
                """
                params = {"question": question, "limit": limit}

            results = self.graph.query(search_query, params)

            context_parts = []
            sources = []

            for result in results:
                content = result.get("content") or result.get("text") or ""
                if content:
                    context_parts.append(content)
                    sources.append({
                        "source": result.get("source", "Fulltext Search"),
                        "title": result.get("title", "Unknown"),
                        "score": result.get("score", 0),
                        "search_type": "fulltext"
                    })

            return "\n\n".join(context_parts), sources

        except Exception as e:
            logger.warning(f"Fulltext search failed: {e}")
            return "", []

    def ensure_fulltext_index(self):
        """Ensure fulltext index exists"""
        try:
            check_query = """
                SHOW INDEXES YIELD name, type
                WHERE name = 'fulltext_content' AND type = 'FULLTEXT'
                RETURN count(*) as count
            """
            result = self.graph.query(check_query)

            if not result or result[0]["count"] == 0:
                create_index_query = """
                    CREATE FULLTEXT INDEX fulltext_content IF NOT EXISTS
                    FOR (n:Document|Chunk|Text)
                    ON EACH [n.content, n.text, n.title]
                """
                self.graph.query(create_index_query)
                logger.info("✅ Created fulltext index")
            else:
                logger.info("✅ Fulltext index already exists")

        except Exception as e:
            logger.warning(f"Could not create fulltext index: {e}")
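
    # Verification sketch (assumption: run manually in Neo4j Browser, with
    # 'machine learning' as a stand-in search term) to confirm the index
    # created above actually returns scored hits:
    #
    #   CALL db.index.fulltext.queryNodes('fulltext_content', 'machine learning')
    #   YIELD node, score
    #   RETURN node.title, score
    #   ORDER BY score DESC LIMIT 5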

    def vector_search(self, question: str, document_names: List[str] = None, k: int = 5):
        """Perform vector similarity search"""
        try:
            if not self.vector_store:
                return "", []

            search_kwargs = {"k": k}
            if document_names:
                search_kwargs["filter"] = {"document_name": {"$in": document_names}}

            retriever = self.vector_store.as_retriever(
                search_type="similarity",
                search_kwargs=search_kwargs
            )

            relevant_docs = retriever.invoke(question)

            context = "\n\n".join([doc.page_content for doc in relevant_docs])
            sources = [
                {
                    "source": doc.metadata.get("source", "Vector Search"),
                    "content": doc.page_content[:200] + "..." if len(doc.page_content) > 200 else doc.page_content,
                    "search_type": "vector"
                }
                for doc in relevant_docs
            ]

            return context, sources

        except Exception as e:
            logger.warning(f"Vector search failed: {e}")
            return "", []

    def graph_search(self, question: str, limit: int = 5):
        """Perform graph-based search using relationships"""
        try:
            search_query = """
                // Find nodes matching the question
                MATCH (n)
                WHERE n.content IS NOT NULL
                  AND (toLower(n.content) CONTAINS toLower($question)
                       OR toLower(n.text) CONTAINS toLower($question)
                       OR toLower(n.title) CONTAINS toLower($question))

                // Get connected nodes for additional context
                OPTIONAL MATCH (n)-[r]-(connected)
                WHERE connected.content IS NOT NULL OR connected.text IS NOT NULL

                WITH n, collect(DISTINCT connected) as connected_nodes

                RETURN n.content as content, n.text as text, n.title as title,
                       n.source as source, labels(n) as labels,
                       [node IN connected_nodes | {
                           content: coalesce(node.content, node.text),
                           title: node.title,
                           relationship: 'connected'
                       }][0..3] as connected_info
                LIMIT $limit
            """

            results = self.graph.query(search_query, {"question": question, "limit": limit})

            context_parts = []
            sources = []

            for result in results:
                content = result.get("content") or result.get("text") or ""
                if content:
                    context_parts.append(content)

                    connected_info = result.get("connected_info", [])
                    for conn in connected_info:
                        if conn and conn.get("content"):
                            context_parts.append(f"Related: {conn['content']}")

                    sources.append({
                        "source": result.get("source", "Graph Search"),
                        "title": result.get("title", "Unknown"),
                        "labels": result.get("labels", []),
                        "search_type": "graph",
                        "connected_nodes": len(connected_info)
                    })

            return "\n\n".join(context_parts), sources

        except Exception as e:
            logger.warning(f"Graph search failed: {e}")
            return "", []

    def comprehensive_search(self, question: str, document_names: List[str] = None):
        """Perform comprehensive search combining vector, graph, and fulltext"""
        all_context = []
        all_sources = []
        search_results = {"vector": 0, "graph": 0, "fulltext": 0}

        try:
            vector_context, vector_sources = self.vector_search(question, document_names, k=3)
            if vector_context:
                all_context.append(f"=== SEMANTIC SIMILARITY RESULTS ===\n{vector_context}")
                all_sources.extend(vector_sources)
                search_results["vector"] = len(vector_sources)
        except Exception as e:
            logger.warning(f"Vector search in comprehensive mode failed: {e}")

        try:
            graph_context, graph_sources = self.graph_search(question, limit=3)
            if graph_context:
                all_context.append(f"=== GRAPH RELATIONSHIP RESULTS ===\n{graph_context}")
                all_sources.extend(graph_sources)
                search_results["graph"] = len(graph_sources)
        except Exception as e:
            logger.warning(f"Graph search in comprehensive mode failed: {e}")

        try:
            fulltext_context, fulltext_sources = self.fulltext_search(question, document_names, limit=3)
            if fulltext_context:
                all_context.append(f"=== KEYWORD SEARCH RESULTS ===\n{fulltext_context}")
                all_sources.extend(fulltext_sources)
                search_results["fulltext"] = len(fulltext_sources)
        except Exception as e:
            logger.warning(f"Fulltext search in comprehensive mode failed: {e}")

        unique_sources = []
        seen_sources = set()
        for source in all_sources:
            source_key = f"{source.get('source', '')}-{source.get('title', '')}"
            if source_key not in seen_sources:
                seen_sources.add(source_key)
                unique_sources.append(source)

        final_context = "\n\n".join(all_context)

        search_summary = {
            "source": "Search Summary",
            "title": "Combined Search Results",
            "search_type": "comprehensive",
            "results_breakdown": search_results,
            "total_sources": len(unique_sources)
        }
        unique_sources.insert(0, search_summary)

        return final_context, unique_sources

    def create_retriever(self, document_names: List[str], mode_settings: Dict):
        """Create retriever based on mode and documents"""
        if mode_settings["use_vector_search"] and self.vector_store:
            try:
                search_kwargs = {"k": 5}
                if document_names and mode_settings["document_filter"]:
                    search_kwargs["filter"] = {"document_name": {"$in": document_names}}

                return self.vector_store.as_retriever(
                    search_type="similarity",
                    search_kwargs=search_kwargs
                )
            except Exception as e:
                logger.error(f"Vector retriever failed: {e}")
                return None
        return None

    def process_chat_response(self, messages: List, history, question: str,
                              document_names: List[str], chat_mode_settings: Dict, session_id: str):
        """Process chat response for vector/hybrid/fulltext/comprehensive modes"""
        start_time = time.time()

        try:
            context = ""
            sources = []
            search_method = "standard"

            if chat_mode_settings["mode"] == "comprehensive":
                context, sources = self.comprehensive_search(question, document_names)
                search_method = "comprehensive"
            elif chat_mode_settings["use_vector_search"] and chat_mode_settings["use_graph_search"]:
                vector_context, vector_sources = self.vector_search(question, document_names, k=3)
                graph_context, graph_sources = self.graph_search(question, limit=3)

                context_parts = []
                if vector_context:
                    context_parts.append(f"=== VECTOR SEARCH RESULTS ===\n{vector_context}")
                if graph_context:
                    context_parts.append(f"=== GRAPH SEARCH RESULTS ===\n{graph_context}")

                context = "\n\n".join(context_parts)
                sources = vector_sources + graph_sources
                search_method = "hybrid"
            elif chat_mode_settings["use_vector_search"]:
                context, sources = self.vector_search(question, document_names)
                search_method = "vector"
            elif chat_mode_settings["use_fulltext_search"]:
                context, sources = self.fulltext_search(question, document_names)
                search_method = "fulltext"
            elif chat_mode_settings["use_graph_search"]:
                context, sources = self.graph_search(question)
                search_method = "graph"

            if not context:
                context, fallback_sources = self.fallback_search(question, document_names)
                sources.extend(fallback_sources)
                search_method += "_with_fallback"

            if chat_mode_settings["mode"] == "comprehensive":
                system_message = """You are a highly capable AI assistant with access to comprehensive search results from multiple sources:

🔍 **SEARCH METHODS USED:**
- 🧠 **Vector Search**: Semantic similarity using embeddings
- 🕸️ **Graph Search**: Relationship-based knowledge traversal
- 📝 **Fulltext Search**: Keyword and phrase matching

The context below contains results from all these search methods. Use this comprehensive information to provide the most accurate and complete answer possible.

Context:
{context}

**Instructions:**
- Synthesize information from all search methods
- Prioritize accuracy and completeness
- Mention when information comes from relationships vs. direct content
- If conflicting information exists, note the discrepancies
- Cite sources when possible"""
            else:
                system_message = """You are a helpful AI assistant. Use the following context to answer the user's question.

Context:
{context}

If you cannot find the answer in the context, say so clearly. Always be accurate and helpful."""

            prompt_template = ChatPromptTemplate.from_messages([
                ("system", system_message),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{question}")
            ])

            chat_history = messages[:-1]

            chain = prompt_template | self.llm | StrOutputParser()

            if context:
                response = chain.invoke({
                    "context": context,
                    "chat_history": chat_history,
                    "question": question
                })
            else:
                no_context_prompt = ChatPromptTemplate.from_messages([
                    ("system", "You are a helpful AI assistant. The user is asking about information that may not be in the knowledge base. Answer based on your general knowledge while noting that you don't have specific context from their documents."),
                    MessagesPlaceholder(variable_name="chat_history"),
                    ("human", "{question}")
                ])
                no_context_chain = no_context_prompt | self.llm | StrOutputParser()
                response = no_context_chain.invoke({
                    "chat_history": chat_history,
                    "question": question
                })

            ai_message = AIMessage(content=response)
            history.add_message(ai_message)

            response_time = time.time() - start_time

            return {
                "session_id": session_id,
                "message": response,
                "info": {
                    "sources": sources,
                    "model": "gpt-3.5-turbo",
                    "nodedetails": [],
                    "total_tokens": 0,
                    "response_time": response_time,
                    "mode": chat_mode_settings["mode"],
                    "search_method": search_method,
                    "entities": [],
                    "metric_details": [],
                },
                "user": "chatbot"
            }

        except Exception as e:
            logger.error(f"Error in chat response: {e}")
            return {
                "session_id": session_id,
                "message": "I apologize, but I encountered an error processing your question. Please try again or rephrase your question.",
                "info": {
                    "sources": [],
                    "model": "gpt-3.5-turbo",
                    "nodedetails": [],
                    "total_tokens": 0,
                    "response_time": 0,
                    "mode": chat_mode_settings["mode"],
                    "search_method": "error",
                    "entities": [],
                    "metric_details": [],
                },
                "user": "chatbot"
            }

    def fallback_search(self, question: str, document_names: List[str] = None):
        """Fallback search using direct Neo4j queries"""
        try:
            search_query = """
                MATCH (n)
                WHERE n.content IS NOT NULL
                  AND (toLower(n.content) CONTAINS toLower($question)
                       OR toLower(n.text) CONTAINS toLower($question)
                       OR toLower(n.title) CONTAINS toLower($question))
                RETURN n.content as content, n.text as text, n.title as title,
                       n.source as source, labels(n) as labels
                LIMIT 5
            """

            results = self.graph.query(search_query, {"question": question})

            context_parts = []
            sources = []

            for result in results:
                content = result.get("content") or result.get("text") or ""
                if content:
                    context_parts.append(content)
                    sources.append({
                        "source": result.get("source", "Neo4j Graph"),
                        "title": result.get("title", "Unknown"),
                        "labels": result.get("labels", []),
                        "search_type": "fallback"
                    })

            return "\n\n".join(context_parts), sources

        except Exception as e:
            logger.error(f"Fallback search failed: {e}")
            return "", []

    def QA_RAG(self, question: str, document_names: str = "[]",
               session_id: str = None, mode: str = CHAT_GRAPH_MODE, write_access: bool = True):
        """Main QA RAG entry point"""
        if session_id is None:
            session_id = f"session_{int(time.time())}"

        logger.info(f"Chat Mode: {mode}")

        history = self.create_neo4j_chat_message_history(session_id, write_access)
        messages = history.messages

        user_question = HumanMessage(content=question)
        messages.append(user_question)
        # Persist the user turn too, so stored sessions contain both sides
        history.add_message(user_question)

        if mode == CHAT_GRAPH_MODE:
            result = self.process_graph_response(question, messages, history, session_id)
        else:
            chat_mode_settings = self.get_chat_mode_settings(mode=mode)

            try:
                document_names = list(map(str.strip, json.loads(document_names)))
            except (json.JSONDecodeError, TypeError):
                document_names = []

            if document_names and not chat_mode_settings["document_filter"]:
                result = {
                    "session_id": session_id,
                    "message": "Please deselect all documents in the table before using this chat mode",
                    "info": {
                        "sources": [],
                        "model": "gpt-3.5-turbo",
                        "nodedetails": [],
                        "total_tokens": 0,
                        "response_time": 0,
                        "mode": chat_mode_settings["mode"],
                        "entities": [],
                        "metric_details": [],
                    },
                    "user": "chatbot"
                }
            else:
                result = self.process_chat_response(
                    messages, history, question, document_names, chat_mode_settings, session_id
                )

        result["session_id"] = session_id
        return result

    def check_and_fix_vector_index(self):
        """Check and fix vector index dimension issues"""
        try:
            result = self.graph.query("""
                SHOW INDEXES YIELD name, labelsOrTypes, properties, type, options
                WHERE type = "VECTOR"
                RETURN name, labelsOrTypes, properties, options
            """)

            if result:
                logger.info("Existing vector indexes found:")
                for idx in result:
                    logger.info(f"  - {idx}")

                for idx in result:
                    try:
                        self.graph.query(f"DROP INDEX `{idx['name']}` IF EXISTS")
                        logger.info(f"Dropped index: {idx['name']}")
                    except Exception as e:
                        logger.warning(f"Could not drop index {idx['name']}: {e}")

            return True
        except Exception as e:
            logger.error(f"Error checking vector indexes: {e}")
            return False
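

# Programmatic-usage sketch (assumption: OPENAI_API_KEY is set and the Neo4j
# instance configured above is reachable), bypassing the Gradio UI entirely:
#
#   rag = Neo4jRAGSystem()
#   out = rag.QA_RAG("What topics are covered?", mode=CHAT_COMPREHENSIVE_MODE)
#   print(out["message"])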


def create_gradio_interface():
    """Create Gradio interface for the RAG system"""

    def initialize_system(api_key):
        """Initialize the RAG system with API key"""
        try:
            if not api_key.strip():
                return None, "❌ Please provide your OpenAI API Key"

            rag_system = Neo4jRAGSystem(openai_api_key=api_key.strip())
            return rag_system, "✅ System initialized successfully!"
        except Exception as e:
            logger.error(f"Initialization error: {e}")
            return None, f"❌ Initialization failed: {str(e)}"

    def query_rag(api_key, question, document_names, mode, session_id):
        """Wrapper function for Gradio"""
        if not question.strip():
            return "Please enter a question.", "", "Please provide a question"

        if not api_key.strip():
            return "", "", "❌ Please provide your OpenAI API Key first"

        try:
            # Note: the system is re-initialized on every query; a caller that
            # needs lower latency could cache the Neo4jRAGSystem instance.
            rag_system, init_message = initialize_system(api_key)
            if not rag_system:
                return "", "", init_message

            if not session_id.strip():
                session_id = f"session_{int(time.time())}"

            result = rag_system.QA_RAG(
                question=question,
                document_names=document_names,
                session_id=session_id,
                mode=mode
            )

            response_text = result["message"]
            info_text = f"""**Response Info:**
- Mode: {result['info']['mode']}
- Response Time: {result['info']['response_time']:.2f}s
- Sources Found: {len(result['info']['sources'])}
- Session ID: {result['session_id']}"""

            if result['info']['sources']:
                info_text += "\n\n**Sources & Search Methods:**"
                for i, source in enumerate(result['info']['sources'], 1):
                    source_name = source.get('source', 'Unknown')
                    source_title = source.get('title', '')
                    search_type = source.get('search_type', 'unknown')

                    info_text += f"\n{i}. **{source_name}**"
                    if source_title and source_title != 'Unknown':
                        info_text += f" - {source_title}"
                    if search_type != 'unknown':
                        info_text += f" `({search_type})`"

                    if source.get('results_breakdown'):
                        breakdown = source['results_breakdown']
                        info_text += f"\n   📊 Results: Vector({breakdown['vector']}), Graph({breakdown['graph']}), Fulltext({breakdown['fulltext']})"

            status = "✅ Query completed successfully"

            return response_text, info_text, status

        except Exception as e:
            logger.error(f"Error in query_rag: {e}")
            return f"Error: {str(e)}", "", f"❌ Error: {str(e)}"

    with gr.Blocks(title="Neo4j RAG Query System", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🤖 Neo4j RAG Query System")
        gr.Markdown("Ask questions and get answers from your Neo4j knowledge graph using OpenAI.")

        with gr.Row():
            api_key_input = gr.Textbox(
                label="OpenAI API Key",
                type="password",
                placeholder="sk-...",
                info="Your OpenAI API key (required for embeddings and LLM)"
            )

        with gr.Row():
            status_output = gr.Textbox(
                label="System Status",
                value="⏳ Please enter your OpenAI API key",
                interactive=False
            )

        with gr.Row():
            with gr.Column(scale=2):
                question_input = gr.Textbox(
                    label="Question",
                    placeholder="Enter your question here...",
                    lines=3
                )

                with gr.Row():
                    mode_dropdown = gr.Dropdown(
                        choices=[
                            CHAT_COMPREHENSIVE_MODE,
                            CHAT_GRAPH_MODE,
                            CHAT_VECTOR_MODE,
                            CHAT_HYBRID_MODE,
                            CHAT_FULLTEXT_MODE
                        ],
                        value=CHAT_COMPREHENSIVE_MODE,
                        label="Search Mode",
                        info="Comprehensive mode uses Graph+Vector+Fulltext for best results"
                    )

                    session_input = gr.Textbox(
                        label="Session ID",
                        placeholder="Leave empty for auto-generated",
                        value=""
                    )

                document_input = gr.Textbox(
                    label="Document Names (JSON array)",
                    placeholder='["document1.pdf", "document2.txt"]',
                    value="[]",
                    info="Leave as [] to search all documents"
                )

                submit_btn = gr.Button("Submit Query", variant="primary")
                clear_btn = gr.Button("Clear", variant="secondary")

            with gr.Column(scale=3):
                response_output = gr.Textbox(
                    label="Response",
                    lines=12,
                    interactive=False
                )

                info_output = gr.Markdown(
                    label="Query Information",
                    value=""
                )

        submit_btn.click(
            fn=query_rag,
            inputs=[api_key_input, question_input, document_input, mode_dropdown, session_input],
            outputs=[response_output, info_output, status_output]
        )

        clear_btn.click(
            fn=lambda: ("", "", "[]", "", "⏳ Ready for next query"),
            outputs=[question_input, response_output, document_input, session_input, status_output]
        )

        question_input.submit(
            fn=query_rag,
            inputs=[api_key_input, question_input, document_input, mode_dropdown, session_input],
            outputs=[response_output, info_output, status_output]
        )

        gr.Examples(
            examples=[
                ["What information do you have about machine learning?", "[]", "comprehensive"],
                ["Tell me about the documents in the database", "[]", "graph"],
                ["Search for specific keywords in the content", "[]", "fulltext"],
                ["What are the main topics covered?", "[]", "vector"],
                ["Find connections between different concepts", "[]", "hybrid"],
            ],
            inputs=[question_input, document_input, mode_dropdown],
        )

        gr.Markdown("""
## 🔍 **Search Mode Descriptions:**

- **🎯 Comprehensive**: Uses Graph+Vector+Fulltext search for the most complete results
- **🕸️ Graph**: Explores relationships and connections in your knowledge graph
- **🧠 Vector**: Semantic similarity search using AI embeddings
- **🔄 Hybrid**: Combines vector similarity with graph relationships
- **📝 Fulltext**: Traditional keyword and phrase search
""")

    return demo


def main():
    """Main function to run the application"""
    try:
        print("🚀 Starting Neo4j RAG System...")

        demo = create_gradio_interface()
        demo.launch(
            share=True,
            server_name="0.0.0.0",
            server_port=7860,
            debug=False,
            show_error=True
        )

    except Exception as e:
        print(f"❌ Error starting application: {e}")
        logger.error(f"Application startup error: {e}")


def QA_RAG_standalone(question: str, document_names: str = "[]",
                      session_id: str = None, mode: str = CHAT_GRAPH_MODE,
                      write_access: bool = True, openai_api_key: str = None):
    """Standalone QA RAG function for direct usage"""
    try:
        rag_system = Neo4jRAGSystem(openai_api_key=openai_api_key)
        return rag_system.QA_RAG(question, document_names, session_id, mode, write_access)
    except Exception as e:
        logger.error(f"Standalone query error: {e}")
        return {
            "session_id": session_id or f"session_{int(time.time())}",
            "message": f"Error: {str(e)}",
            "info": {"error": str(e)},
            "user": "chatbot"
        }
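
# Usage sketch for the standalone wrapper (assumption: OPENAI_API_KEY is
# exported, or passed explicitly via openai_api_key):
#
#   result = QA_RAG_standalone(
#       question="What documents are in the graph?",
#       mode=CHAT_COMPREHENSIVE_MODE,
#   )
#   print(result["message"])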


if __name__ == "__main__":
    main()


"""
# requirements.txt content:
langchain==0.1.0
langchain-neo4j==0.0.5
langchain-openai==0.0.8
langchain-community==0.0.13
gradio==4.15.0
python-dotenv==1.0.0
neo4j==5.16.0
openai==1.10.0
tiktoken==0.5.2
"""