import os
from flask import Flask, request, jsonify, render_template
import google.generativeai as genai
# LangChain Community has the updated vector stores
from langchain_community.vectorstores import FAISS
from langchain_google_genai import GoogleGenerativeAIEmbeddings, ChatGoogleGenerativeAI
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from dotenv import load_dotenv
import logging
import re
from custom_prompt import get_custom_prompt
import time
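# Note (assumption; custom_prompt.py is not shown here): get_custom_prompt() is
# expected to return a PromptTemplate whose input variables include "context" and
# "question", since it is passed to the ConversationalRetrievalChain's
# combine-docs step further below.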
# Configure logging
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load environment variables
load_dotenv()

app = Flask(__name__)

# Store app start time for uptime calculation
app_start_time = time.time()
# Initialize the environment - check multiple possible env var names
GOOGLE_API_KEY = (os.getenv("GOOGLE_API_KEY") or
                  os.getenv("GEMINI_API_KEY") or
                  os.getenv("GOOGLE_GEMINI_API_KEY"))

if not GOOGLE_API_KEY or GOOGLE_API_KEY == "your_api_key_here":
    logger.error("No valid GOOGLE_API_KEY found in environment variables")
    print("⚠️ Please set your Gemini API key in the environment variables")
    print("Supported env var names: GOOGLE_API_KEY, GEMINI_API_KEY, GOOGLE_GEMINI_API_KEY")
else:
    genai.configure(api_key=GOOGLE_API_KEY)
    logger.info("API key configured successfully")
# Global variables for the chain and memory
qa_chain = None
memory = None

def initialize_chatbot():
    global qa_chain, memory
    logger.info("Initializing chatbot...")

    # Initialize embeddings
    try:
        embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
        logger.info("Embeddings initialized")
    except Exception as e:
        logger.error(f"Error initializing embeddings: {str(e)}")
        return False
    # Load the vector store
    try:
        vector_store = FAISS.load_local("faiss_index", embeddings, allow_dangerous_deserialization=True)
        logger.info("Vector store loaded successfully!")
    except Exception as e:
        logger.error(f"Error loading vector store: {str(e)}")
        print(f"⚠️ Error loading vector store: {str(e)}")
        print("Make sure your 'faiss_index' folder is in the same directory as this script.")
        return False
    # Create memory
    memory = ConversationBufferMemory(
        memory_key="chat_history",
        return_messages=True,
        output_key="answer"
    )
    logger.info("Conversation memory initialized")

    # Initialize the language model
    try:
        llm = ChatGoogleGenerativeAI(
            model="gemini-2.0-flash",  # Updated to a newer recommended model
            temperature=0.2,
            top_p=0.85,
            google_api_key=GOOGLE_API_KEY
        )
        logger.info("Language model initialized")
    except Exception as e:
        logger.error(f"Error initializing language model: {str(e)}")
        return False
    # Create the conversation chain with the custom prompt
    try:
        retriever = vector_store.as_retriever(search_kwargs={"k": 3})
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm=llm,
            retriever=retriever,
            memory=memory,
            verbose=True,
            return_source_documents=False,
            combine_docs_chain_kwargs={"prompt": get_custom_prompt()},
        )
        logger.info("QA chain created successfully")
    except Exception as e:
        logger.error(f"Error creating QA chain: {str(e)}")
        return False

    return True
# Function to format links as HTML anchor tags
def format_links_as_html(text):
    # Detect markdown style links [text](url)
    markdown_pattern = r'\[(.*?)\]\((https?://[^\s\)]+)\)'
    text = re.sub(markdown_pattern, r'<a href="\2" target="_blank">\1</a>', text)

    # Handle URLs in square brackets [url]
    bracket_pattern = r'\[(https?://[^\s\]]+)\]'
    text = re.sub(bracket_pattern, r'<a href="\1" target="_blank">\1</a>', text)

    # Regular URL pattern. The previous pattern r'(https?://[^\s\])+)' was invalid;
    # the lookbehinds skip URLs that were already wrapped above, whether they sit
    # inside an href="..." attribute or appear as anchor text (which follows '">').
    url_pattern = r'(?<!href=")(?<!">)(https?://[^\s<]+)'

    # Replace remaining URLs with HTML anchor tags
    text = re.sub(url_pattern, r'<a href="\1" target="_blank">\1</a>', text)
    return text
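# Illustrative behavior (the example string is not from the original source):
#   format_links_as_html("See [docs](https://example.com) or https://example.org")
# returns
#   'See <a href="https://example.com" target="_blank">docs</a> or '
#   '<a href="https://example.org" target="_blank">https://example.org</a>'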
# Function to properly escape asterisks for markdown rendering
def escape_markdown(text):
    return re.sub(r'(?<!\*)\*(?!\*)', r'\\*', text)
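# Illustrative behavior: escape_markdown("use *stars* but keep **bold**") returns
# "use \*stars\* but keep **bold**"; single asterisks are escaped, while the
# double asterisks used for bold markers are left alone.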
# Function to format markdown and handle asterisks with proper line breaks
def format_markdown_with_breaks(text):
    text = text.replace('\\*', '*')
    text = re.sub(r'\*\*(.*?)\*\*', r'<strong>\1</strong>', text)

    lines = text.split('\n')
    formatted_lines = []
    for line in lines:
        stripped_line = line.strip()
        if stripped_line.startswith('* '):
            content = stripped_line[2:].strip()
            # Use a bullet point character for lists
            formatted_lines.append(f"<br>• {content}")
        elif stripped_line.startswith('*'):
            content = stripped_line[1:].strip()
            formatted_lines.append(f"<br>• {content}")
        else:
            formatted_lines.append(line)

    # Join the lines, but remove the initial <br> if it exists
    result = '\n'.join(formatted_lines)
    if result.startswith('<br>'):
        result = result[4:]
    return result
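# Illustrative behavior: format_markdown_with_breaks("**Hours:**\n* Mon 9-5\n* Tue 9-5")
# returns "<strong>Hours:</strong>\n<br>• Mon 9-5\n<br>• Tue 9-5"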
# Flask routes; the paths below are inferred from the handler names and docstrings
# (landing page, /health, /ping, /chat).
@app.route('/')
def home():
    return render_template('index.html')
@app.route('/health')
def health():
    """Standard health check endpoint for uptime monitoring"""
    try:
        current_time = time.time()
        uptime_seconds = current_time - app_start_time
        health_status = {
            "status": "healthy",
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime()),
            "uptime_seconds": round(uptime_seconds, 2),
            "chatbot_initialized": qa_chain is not None
        }
        return jsonify(health_status), 200
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        return jsonify({
            "status": "unhealthy",
            "error": str(e),
            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S UTC", time.gmtime())
        }), 500
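# Illustrative check (assuming the /health route above and the default local
# port 7860): curl http://localhost:7860/health returns JSON such as
# {"status": "healthy", "uptime_seconds": 12.3, "chatbot_initialized": false, ...}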
@app.route('/ping')
def ping():
    """Simple ping endpoint for basic uptime checks"""
    return "pong", 200
@app.route('/chat', methods=['POST'])
def chat():
    global qa_chain

    # Lazily initialize the chain on the first request
    if qa_chain is None:
        if not initialize_chatbot():
            return jsonify({"error": "Failed to initialize chatbot. Check server logs."}), 500

    # Tolerate missing or non-JSON bodies instead of raising
    data = request.get_json(silent=True) or {}
    user_message = data.get('message', '')
    if not user_message:
        return jsonify({"error": "No message provided"}), 400

    try:
        logger.info(f"Processing user query: {user_message}")
        # Use .invoke() instead of the deprecated __call__ method
        result = qa_chain.invoke({"question": user_message})
        answer = result.get("answer", "I'm sorry, I couldn't generate a response.")

        # Format the answer
        answer = escape_markdown(answer)
        answer = format_links_as_html(answer)
        answer = format_markdown_with_breaks(answer)

        logger.info("Query processed successfully")
        return jsonify({"answer": answer})
    except Exception as e:
        # Log the full traceback for better debugging
        logger.exception(f"Error processing request: {str(e)}")
        return jsonify({"error": f"An internal error occurred: {str(e)}"}), 500
if __name__ == '__main__':
    port = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=port, debug=False)