# NOTE: the original file began with "Spaces: / Running / Running" — a
# Hugging Face Spaces page-header artifact from extraction, not Python code.
import gradio as gr
from huggingface_hub import InferenceClient
from sentence_transformers import CrossEncoder, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np

from qa_vector_store import build_qa_vector_store, retrieve_and_rerank, generate_response_from_local_llm

# --- Model and vector-store initialization (runs once at import time) ---

# Multilingual embedding model used for vector search; the collection is
# named after the model so stores built with different embedders don't collide.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
collection_name = model_name.split("/")[-1]

# Cross-encoder used to re-rank the initial vector-search hits.
cross_encoder_model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")

# Small local instruction-tuned LLM for answer generation.
# trust_remote_code=True is required by some HF model repos' custom code.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True)
llm_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True)

# Build the QA vector store for this embedding model before serving requests.
build_qa_vector_store(model_name, collection_name)
def respond(message, history, system_message):
    """Answer a user message via a retrieve → rerank → generate RAG pipeline.

    Parameters
    ----------
    message : str
        The user's current question.
    history : list
        Chat history supplied by ``gr.ChatInterface`` (not used by the pipeline).
    system_message : str
        System prompt from the UI textbox (currently not forwarded to the LLM).

    Returns
    -------
    str
        The generated answer, or an error message string if any step fails.
    """
    try:
        # Vector search (top 20 candidates) followed by cross-encoder
        # re-ranking (keep top 5); hits scoring below 0.5 are discarded.
        reranked = retrieve_and_rerank(
            message,
            model_name,
            collection_name,
            cross_encoder_model,
            score_threshold=0.5,
            search_top_k=20,
            rerank_top_k=5,
        )
        # Each reranked entry is (passage, ...); keep only the passage text.
        final_passages = [r[0] for r in reranked]

        # Generate the final answer with the local LLM, grounded on the passages.
        return generate_response_from_local_llm(
            message, final_passages, tokenizer, llm_model, max_new_tokens=256
        )
    except Exception as e:
        # UI boundary: surface the error in the chat instead of crashing the app.
        return f"[錯誤] {str(e)}"
# Gradio chat UI: wires respond() into a chat window, with an editable
# system-message textbox passed through as the third argument of respond().
chat_interface = gr.ChatInterface(
    fn=respond,
    title="Chatbot問答系統",
    additional_inputs=[
        gr.Textbox(value="你是個幫忙立法委員許智傑回答問題的助手。", label="System Message"),
    ],
)

if __name__ == "__main__":
    chat_interface.launch()