"""RAG demo: load Chinese text files, embed them with Gemini, index the chunks
in an in-memory Qdrant collection, and serve a Gradio chat UI that retrieves,
LLM-reranks, and answers questions over the corpus.

Requires the GOOGLE_API_KEY environment variable and network access.
"""

import os
import uuid
from typing import List

import gradio as gr
from langchain.document_loaders import TextLoader
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, PointStruct, VectorParams

# --- Model / client configuration -------------------------------------------
api_key = os.getenv("GOOGLE_API_KEY")
model_name = "gemini-1.5-flash"

embedding_model = GoogleGenerativeAIEmbeddings(
    model="models/embedding-001",
    google_api_key=api_key,
)
# Reuse the same key read above (the original re-read the env var here).
llm = ChatGoogleGenerativeAI(
    model=model_name,
    temperature=0.4,
    google_api_key=api_key,
)

# In-memory Qdrant instance: data lives only for the lifetime of the process.
client = QdrantClient(":memory:")
VECTOR_SIZE = 768  # dimensionality of models/embedding-001 vectors
COLLECTION_NAME = "gemini_demo_collection"

# Source corpus: UTF-8 plain-text files in the working directory.
file_paths = [
    './心靈迴路:技術的黎明與潛力.txt',
    './心靈迴路:社會的重塑與爭議之聲.txt',
    './永恆之夜:元網絡的覺醒與迴響者的行動.txt',
    './迴響者的哲學:超越數據的意識.txt',
    './迴響者的崛起:元網絡的低語.txt',
]


def load_and_split_documents(file_paths, chunk_size=500, chunk_overlap=50):
    """Load each UTF-8 text file and split it into overlapping chunks.

    Args:
        file_paths: paths of plain-text files to ingest.
        chunk_size: maximum characters per chunk.
        chunk_overlap: characters of overlap between adjacent chunks.

    Returns:
        A list of dicts with keys 'doc_id' (file basename), 'chunk_id'
        (0-based index within the file), and 'text' (the chunk content).
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    all_docs = []
    for path in file_paths:
        loader = TextLoader(path, encoding="utf-8")
        for doc in loader.load():
            for i, chunk in enumerate(splitter.split_text(doc.page_content)):
                all_docs.append({
                    'doc_id': os.path.basename(path),
                    'chunk_id': i,
                    'text': chunk,
                })
    return all_docs


def embed_chunks(docs):
    """Attach a Gemini embedding vector to each chunk dict (in place).

    Each dict gains an 'embedding' key; the same list is returned for
    convenience. One API call per chunk.
    """
    for doc in docs:
        doc['embedding'] = embedding_model.embed_query(doc['text'])
    return docs


def insert_to_qdrant(docs_with_embeddings):
    """Upsert all embedded chunks into the Qdrant collection in one batch.

    Point IDs are random UUIDs; the chunk text and provenance go in the
    payload so search results are self-describing.
    """
    points = [
        PointStruct(
            id=str(uuid.uuid4()),
            vector=doc["embedding"],
            payload={
                "text": doc["text"],
                "doc_id": doc["doc_id"],
                "chunk_id": doc["chunk_id"],
            },
        )
        for doc in docs_with_embeddings
    ]
    client.upsert(collection_name=COLLECTION_NAME, points=points)


def hybrid_search(query, top_k=5):
    """Embed the query and return the top_k nearest chunks from Qdrant.

    Returns Qdrant scored points (payload included). NOTE(review): despite
    the name, this is pure dense vector search — no sparse/keyword leg.
    """
    query_vector = embedding_model.embed_query(query)
    return client.search(
        collection_name=COLLECTION_NAME,
        query_vector=query_vector,
        limit=top_k,
        with_payload=True,
    )


def rerank_chunks_with_llm(query: str, results: List, llm, top_n=3):
    """Ask the LLM to rank the retrieved passages by relevance to the query.

    Args:
        query: the user question.
        results: Qdrant scored points (objects with a .payload dict) — the
            original annotated these as langchain Documents, which was wrong.
        llm: chat model used for ranking.
        top_n: how many passage numbers the LLM should return.

    Returns:
        The raw LLM response; .content is expected to be a comma-separated
        line of 1-based passage numbers (e.g. "2,3,1").
    """
    passages = "\n".join(
        f"{i+1}. {doc.payload['text']}" for i, doc in enumerate(results)
    )
    template = """
請根據下列使用者問題,從提供的段落中找出最相關的內容。

問題:{query}

段落列表:
{passages}

請依照與問題的相關性進行排序,回傳最相關的前 {top_n} 段落的編號。
只需回傳一行,以逗號分隔的數字,例如:2,3,1
"""
    prompt_template = PromptTemplate(
        input_variables=["query", "passages", "top_n"],
        template=template,
    )
    prompt = prompt_template.format(query=query, passages=passages, top_n=top_n)
    return llm.invoke(prompt)


def generate_answer_with_rag(query: str, reranked_docs: List, llm):
    """Generate the final answer from the reranked context chunks.

    Concatenates the chunk texts into a context block and asks the LLM to
    answer in natural language, admitting when the data is insufficient.
    """
    context = "\n\n".join(doc.payload["text"] for doc in reranked_docs)
    prompt = f"""
你是一個知識助手,請根據以下提供的資訊,回答使用者的問題。

問題:
{query}

相關資訊:
{context}

請以自然語言完整回答,若資料不足請誠實說明。
"""
    return llm.invoke(prompt)


def respond(message, history, system_message, max_tokens, temperature, top_p):
    """Gradio chat callback: retrieve → LLM-rerank → generate an answer.

    NOTE(review): system_message and top_p are accepted for UI parity but
    are not currently forwarded to the model (same as the original).
    Returns the answer text, or an error string prefixed with [錯誤].
    """
    try:
        llm.temperature = temperature
        llm.max_output_tokens = max_tokens

        search_results = hybrid_search(message)
        rerank_response = rerank_chunks_with_llm(
            message, search_results, llm, top_n=3
        )
        # Parse "2,3,1"-style output defensively: keep only digits that map
        # to valid result indices (the LLM may hallucinate numbers).
        reranked_indices = [
            int(i.strip()) - 1
            for i in rerank_response.content.split(",")
            if i.strip().isdigit()
        ]
        reranked_docs = [
            search_results[i]
            for i in reranked_indices
            if 0 <= i < len(search_results)
        ]
        # Fallback: if the LLM's ranking was unusable, answer from the raw
        # top retrieval hits instead of an empty context.
        if not reranked_docs:
            reranked_docs = search_results[:3]

        answer = generate_answer_with_rag(message, reranked_docs, llm)
        return answer.content
    except Exception as e:
        # Surface the failure in the chat window rather than crashing the UI.
        return f"[錯誤] {str(e)}"


chat_interface = gr.ChatInterface(
    fn=respond,
    title="Chatbot問答系統 - RAG Demo",
    additional_inputs=[
        gr.Textbox(value="你是個樂於助人的AI助手。", label="System Message"),
        gr.Slider(1, 2048, value=512, step=1, label="Max tokens"),
        gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-p"),
    ],
)


if __name__ == "__main__":
    # Single entry guard (the original had two, with ingestion code that
    # referenced __main__-only names at module level — a NameError on import).
    docs = load_and_split_documents(file_paths)
    docs_with_embeddings = embed_chunks(docs)

    client.recreate_collection(
        collection_name=COLLECTION_NAME,
        vectors_config=VectorParams(size=VECTOR_SIZE, distance=Distance.COSINE),
    )
    insert_to_qdrant(docs_with_embeddings)

    chat_interface.launch()