"""Gradio demo: ask Mistral questions about PDF documents (RAG).

PDF files uploaded through the UI are parsed, embedded with the Mistral
embedding model and stored in a persistent Chroma collection. Questions typed
in the chat box are answered by a LlamaIndex query engine built on top of
that collection.
"""

import json
import os
from pathlib import Path

import chromadb
import gradio as gr
from llama_index.core import (
    ServiceContext,
    StorageContext,
    VectorStoreIndex,
    download_loader,
)
from llama_index.embeddings.mistralai import MistralAIEmbedding
from llama_index.llms.mistralai import MistralAI
from llama_index.vector_stores.chroma import ChromaVectorStore

title = "Gaia Mistral 8x7b Chat RAG PDF Demo"
description = "Example of an assistant with Gradio, RAG from PDF documents and Mistral AI via its API"
placeholder = (
    "Vous pouvez me posez une question sur ce contexte, appuyer sur Entrée pour valider"
)
placeholder_url = "Extract text from this url"
llm_model = "open-mixtral-8x7b"
# The API key is read from the environment; it is None when the variable is unset.
env_api_key = os.environ.get("MISTRAL_API_KEY")

# Define LLMs
llm = MistralAI(api_key=env_api_key, model=llm_model)
embed_model = MistralAIEmbedding(model_name="mistral-embed", api_key=env_api_key)

# Create the on-disk Chroma client and (re)open the collection.
db = chromadb.PersistentClient(path="./chroma_db")
chroma_collection = db.get_or_create_collection("quickstart")

# Set up ChromaVectorStore and the contexts the index is built from.
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
service_context = ServiceContext.from_defaults(
    chunk_size=1024, llm=llm, embed_model=embed_model
)

PDFReader = download_loader("PDFReader")
loader = PDFReader()

# The index starts with no in-memory documents; it is backed by the Chroma
# collection through ``storage_context``.
index = VectorStoreIndex(
    [], service_context=service_context, storage_context=storage_context
)
query_engine = index.as_query_engine(similarity_top_k=5)


def get_documents_in_db():
    """Return a Markdown bullet list of the distinct file names stored in Chroma.

    The file name of each entry is read from the JSON payload stored under the
    ``_node_content`` metadata key. Entries whose metadata cannot be read that
    way are skipped (and counted in the log output) so that one unexpected
    entry does not break the whole listing.

    Returns:
        A Markdown string: a bold header followed by one `` - name`` line per
        distinct file name, in sorted order.
    """
    print("Fetching documents in DB")
    metadatas = chroma_collection.get(include=["metadatas"])["metadatas"] or []

    names = set()
    skipped = 0
    for item in metadatas:
        try:
            names.add(json.loads(item["_node_content"])["metadata"]["file_name"])
        except (KeyError, TypeError, json.JSONDecodeError):
            skipped += 1

    if skipped:
        print(f"Skipped {skipped} entries without a readable file name")

    # Sort so the list shown in the UI keeps a stable order between refreshes.
    docs = sorted(names)
    print(f"Found {len(docs)} documents")
    return "**List of files in db:**\n" + "".join(f" - {d}\n" for d in docs)


def empty_db():
    """Delete every entry of the Chroma collection and return the refreshed listing.

    Returns:
        The Markdown listing produced by ``get_documents_in_db`` after deletion.
    """
    ids = chroma_collection.get()["ids"]
    # Nothing to delete on an already-empty collection.
    # NOTE(review): some chromadb releases appear to reject an empty id list;
    # confirm against the pinned version.
    if ids:
        chroma_collection.delete(ids)
    return get_documents_in_db()


def load_file(file):
    """Parse the uploaded PDF, insert its documents into the index, refresh the UI.

    Args:
        file: Path of the uploaded file as provided by the ``gr.File``
            component, or a falsy value when no file has been selected.

    Returns:
        A 3-tuple matching the click handler outputs: an update hiding the
        "loaded" message box, an update showing a status message, and the
        Markdown listing of the files currently in the database.
    """
    if not file:
        # The Encode button can be pressed before any upload.
        return (
            gr.Textbox(visible=False),
            gr.Textbox(value="No document loaded: upload a PDF first", visible=True),
            get_documents_in_db(),
        )

    # NOTE(review): the reader's ``file`` parameter looks to be typed as a path
    # object; pass a Path rather than Gradio's string value - confirm against
    # the installed reader version.
    documents = loader.load_data(file=Path(file))
    for doc in documents:
        index.insert(doc)

    return (
        gr.Textbox(visible=False),
        gr.Textbox(value="Document encoded ! You can ask questions", visible=True),
        get_documents_in_db(),
    )


def load_document(input_file):
    """Return an update showing the base name of the file that was just uploaded.

    Args:
        input_file: Value provided by the ``gr.File`` component; either an
            object exposing the path as ``.name`` or the path string itself.

    Returns:
        A visible ``gr.Textbox`` update containing ``Document loaded: <name>``.
    """
    path = getattr(input_file, "name", input_file)
    file_name = os.path.basename(path)
    return gr.Textbox(value=f"Document loaded: {file_name}", visible=True)


def respond(message, chat_history):
    """Answer ``message`` with the query engine and append the exchange to the chat.

    Args:
        message: Text submitted in the chat textbox.
        chat_history: Current chatbot value, a list of (user, assistant) pairs.

    Returns:
        The chat history including the new (message, answer) pair, or the
        unchanged history when the message is blank.
    """
    # NOTE(review): the history may be None after the Clear button - confirm.
    chat_history = chat_history or []
    if not message or not message.strip():
        # Do not spend a query on an empty submission.
        return chat_history
    response = query_engine.query(message)
    chat_history.append((message, str(response)))
    return chat_history


# NOTE(review): the nesting of the layout blocks below was reconstructed;
# confirm it matches the intended page layout.
with gr.Blocks() as demo:
    gr.Markdown(
        """
    # Welcome to Gaia Level 3 Demo

    Add a file before interacting with the Chat.
    This demo allows you to interact with a pdf file and then ask questions to Mistral APIs.
    Mistral will answer with the context extracted from your uploaded file.

    *The files will stay in the database unless there is 48h of inactivty or you re-build the space.*
    """
    )

    gr.Markdown(""" ### 1 / Extract data from PDF """)

    with gr.Row():
        with gr.Column():
            input_file = gr.File(
                label="Load a pdf",
                file_types=[".pdf"],
                file_count="single",
                type="filepath",
                interactive=True,
            )
            file_msg = gr.Textbox(
                label="Loaded documents:", container=False, visible=False
            )

            input_file.upload(
                fn=load_document,
                inputs=[
                    input_file,
                ],
                outputs=[file_msg],
                concurrency_limit=20,
            )

            help_msg = gr.Markdown(
                value="Once the document is loaded, press the Encode button below to add it to the db."
            )

            file_btn = gr.Button(value="Encode file ✅", interactive=True)
            btn_msg = gr.Textbox(container=False, visible=False)

            with gr.Row():
                # Passing the function itself (not its result) as the value.
                db_list = gr.Markdown(value=get_documents_in_db)
                delete_btn = gr.Button(value="Empty db 🗑️", interactive=True, scale=0)

            file_btn.click(
                load_file,
                inputs=[input_file],
                outputs=[file_msg, btn_msg, db_list],
                show_progress="full",
            )
            delete_btn.click(empty_db, outputs=[db_list], show_progress="minimal")

    gr.Markdown(""" ### 2 / Ask a question about this context """)

    chatbot = gr.Chatbot()
    msg = gr.Textbox(placeholder=placeholder)
    clear = gr.ClearButton([msg, chatbot])

    msg.submit(respond, [msg, chatbot], [chatbot])

demo.title = title

demo.launch()