import gradio as gr
import os

import numpy as np
import faiss
from mistralai import Mistral

# Client is built once at import time; MISTRAL_API_KEY must be set in the env.
api_key = os.getenv("MISTRAL_API_KEY")
client = Mistral(api_key=api_key)

# =============================================================================
# BASIC CHAT UI (Gradio Version)
# =============================================================================

def run_mistral_basic(message, history):
    """Basic chat function for Gradio ChatInterface.

    Args:
        message: The user's latest message (str).
        history: Prior turns supplied by Gradio (unused — each call is
            stateless and sends only the current message to the model).

    Returns:
        The model's reply text (str).
    """
    messages = [{"role": "user", "content": message}]
    chat_response = client.chat.complete(
        model="mistral-large-latest",
        messages=messages,
    )
    return chat_response.choices[0].message.content


# Create basic chat interface
basic_chat = gr.ChatInterface(
    fn=run_mistral_basic,
    title="Basic Mistral Chat",
    description="Chat with Mistral AI",
)

# =============================================================================
# RAG UI (Gradio Version)
# =============================================================================

# Module-level state shared between process_document() and rag_chat().
# NOTE(review): module globals are shared across all concurrent Gradio
# sessions — fine for a single-user demo, not for multi-user deployments.
processed_chunks = None  # list[str] of document chunks, or None
faiss_index = None       # faiss.IndexFlatL2 over chunk embeddings, or None


def get_text_embedding(input_text):
    """Return the Mistral embedding vector for a single text (list[float])."""
    embeddings_batch_response = client.embeddings.create(
        model="mistral-embed",
        inputs=[input_text],
    )
    return embeddings_batch_response.data[0].embedding


def process_document(file):
    """Read an uploaded text file, chunk it, and build a FAISS index.

    Args:
        file: Filepath string from gr.File(type="filepath"), or None when
            nothing has been uploaded yet.

    Returns:
        A human-readable status message for the UI (str).
    """
    global processed_chunks, faiss_index

    if file is None:
        return "Please upload a text file first."

    try:
        # gr.File with type="filepath" hands us a plain path string.
        # (Bug fix: the original did open(file.name, ...), which raises
        # AttributeError because str has no .name attribute.)
        with open(file, 'r', encoding='utf-8') as f:
            text = f.read()

        # Fixed-size character chunking; no overlap between chunks.
        chunk_size = 2048
        processed_chunks = [
            text[i:i + chunk_size] for i in range(0, len(text), chunk_size)
        ]

        # Embed every chunk and index the vectors with exact L2 search.
        # FAISS requires float32 input.
        text_embeddings = np.array(
            [get_text_embedding(chunk) for chunk in processed_chunks]
        )
        d = text_embeddings.shape[1]
        faiss_index = faiss.IndexFlatL2(d)
        faiss_index.add(text_embeddings.astype(np.float32))

        return (
            f"Document processed successfully! "
            f"Split into {len(processed_chunks)} chunks."
        )
    except Exception as e:
        # Surface the failure in the status box rather than crashing the UI.
        return f"Error processing document: {str(e)}"


def rag_chat(message, history):
    """Answer a question using retrieval-augmented generation.

    Embeds the question, retrieves the nearest chunks from the FAISS
    index, and asks the model to answer from that context only.

    Args:
        message: The user's question (str).
        history: Chat history from Gradio (unused).

    Returns:
        The model's answer, or an error/status message (str).
    """
    global processed_chunks, faiss_index

    if processed_chunks is None or faiss_index is None:
        return "Please upload and process a document first."

    try:
        prompt_template = """
Context information is below.
---------------------
{retrieved_chunk}
---------------------
Given the context information and not prior knowledge, answer the query.
Query: {question}
Answer:
"""

        # FAISS expects a 2-D float32 array of query vectors.
        question_embedding = np.array([get_text_embedding(message)])

        # Clamp k so we never request more neighbours than chunks exist.
        k = min(2, len(processed_chunks))
        D, I = faiss_index.search(question_embedding.astype(np.float32), k=k)
        retrieved_chunks = [processed_chunks[i] for i in I.tolist()[0]]

        # Join chunks as plain text (the original interpolated the list
        # object itself, leaking Python repr syntax into the prompt).
        prompt = prompt_template.format(
            retrieved_chunk="\n---------------------\n".join(retrieved_chunks),
            question=message,
        )

        messages = [{"role": "user", "content": prompt}]
        chat_response = client.chat.complete(
            model="mistral-large-latest",
            messages=messages,
        )
        return chat_response.choices[0].message.content
    except Exception as e:
        return f"Error generating response: {str(e)}"


# =============================================================================
# GRADIO INTERFACES
# =============================================================================

# RAG interface: file upload + process button + chat area.
with gr.Blocks(title="RAG Chat with Mistral") as rag_interface:
    gr.Markdown("# RAG Chat Interface")
    gr.Markdown("Upload a text file and chat with its content!")

    with gr.Row():
        file_upload = gr.File(
            label="Upload Text File",
            file_types=[".txt"],
            type="filepath",
        )
        process_btn = gr.Button("Process Document", variant="primary")

    process_status = gr.Textbox(
        label="Processing Status",
        interactive=False,
        placeholder="Upload a file and click 'Process Document'",
    )

    # Chat interface
    chatbot = gr.Chatbot(label="RAG Chat")
    msg = gr.Textbox(
        label="Your Message",
        placeholder="Ask questions about the uploaded document...",
        lines=2,
    )

    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear Chat")

    # Event handlers
    process_btn.click(
        process_document,
        inputs=[file_upload],
        outputs=[process_status],
    )

    def respond(message, chat_history):
        """Append the user turn, generate the bot reply, and clear the box."""
        if not message.strip():
            return "", chat_history

        # Add user message to history (pair format: [user, bot]).
        chat_history.append([message, None])
        # Get bot response
        bot_response = rag_chat(message, chat_history)
        # Fill in the bot half of the last turn.
        chat_history[-1][1] = bot_response
        return "", chat_history

    submit_btn.click(
        respond,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
    )
    msg.submit(
        respond,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
    )
    clear_btn.click(lambda: ([], ""), outputs=[chatbot, msg])


if __name__ == "__main__":
    rag_interface.launch(share=True)