import gradio as gr
from huggingface_hub import InferenceClient
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel
from typing import List
from sentence_transformers import CrossEncoder, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np

from qa_vector_store import build_qa_vector_store, retrieve_and_rerank, generate_response_from_local_llm

# Create the FastAPI application
app = FastAPI()

# Initialize the embedding model, reranker, and local LLM
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
collection_name = model_name.split("/")[-1]
cross_encoder_model = "cross-encoder/mmarco-mMiniLMv2-L12-H384-v1"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True)
llm_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True)

# Build the QA vector store
build_qa_vector_store(model_name, collection_name)


# Request schema
class QueryInput(BaseModel):
    query: str
    top_k: int = 5


# Response schemas: the generated answer plus the reranked passages it is based on
class SearchResult(BaseModel):
    text: str
    score: float


class SearchResponse(BaseModel):
    answer: str
    sources: List[SearchResult]


# Search + rerank + generate API
@app.post("/search", response_model=SearchResponse)
def search(input: QueryInput):
    reranked = retrieve_and_rerank(
        input.query,
        model_name,
        collection_name,
        cross_encoder_model,
        score_threshold=0.5,
        search_top_k=20,
        rerank_top_k=input.top_k,
    )
    # Return a 404 error if no relevant passage clears the score threshold
    if not reranked:
        raise HTTPException(status_code=404, detail="No relevant answer found; try rephrasing the question or lowering the threshold.")

    # Each reranked item is a (passage, score) pair
    final_passages = [r[0] for r in reranked]

    # Generate the final answer with the local LLM
    answer = generate_response_from_local_llm(input.query, final_passages, tokenizer, llm_model, max_new_tokens=256)
    if not answer:
        raise HTTPException(status_code=500, detail="Failed to generate an answer; please check the input or model settings.")

    return SearchResponse(
        answer=answer,
        sources=[SearchResult(text=r[0], score=float(r[1])) for r in reranked],
    )


# Chat handler for the Gradio demo; it reuses the same retrieve -> rerank -> generate
# pipeline as the /search endpoint. The system message, temperature, and top-p inputs
# are accepted for UI compatibility, but generate_response_from_local_llm only exposes
# max_new_tokens, so only the "Max new tokens" slider is forwarded.
def respond(message, history, system_message, max_tokens, temperature, top_p):
    reranked = retrieve_and_rerank(
        message,
        model_name,
        collection_name,
        cross_encoder_model,
        score_threshold=0.5,
        search_top_k=20,
        rerank_top_k=5,
    )
    if not reranked:
        return "No relevant answer found; try rephrasing the question or lowering the threshold."
    final_passages = [r[0] for r in reranked]
    answer = generate_response_from_local_llm(message, final_passages, tokenizer, llm_model, max_new_tokens=max_tokens)
    return answer or "Failed to generate an answer; please check the input or model settings."


demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)

if __name__ == "__main__":
    demo.launch()
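
# --- Usage sketch ---
# The entry point above only launches the Gradio demo, so the FastAPI routes defined
# on `app` are never served on their own. One option (an assumption, not part of the
# original file) is to mount the demo onto the FastAPI app and serve both with uvicorn:
#
#     import uvicorn
#     app = gr.mount_gradio_app(app, demo, path="/chat")
#     uvicorn.run(app, host="0.0.0.0", port=7860)
#
# The /search endpoint can then be exercised with any HTTP client, for example:
#
#     import requests
#     resp = requests.post(
#         "http://localhost:7860/search",
#         json={"query": "your question here", "top_k": 5},
#     )
#     print(resp.json())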