import gradio as gr
from huggingface_hub import InferenceClient
from sentence_transformers import CrossEncoder, util
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
from qa_vector_store import build_qa_vector_store, retrieve_and_rerank, generate_response_from_local_llm
# Initialize the embedding / rerank / generation models and the vector database.
model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
# Collection is named after the embedding model, e.g. "paraphrase-multilingual-MiniLM-L12-v2".
collection_name = model_name.split("/")[-1]
# Cross-encoder used to rerank retrieved passages (multilingual mMARCO model).
cross_encoder_model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
# Local causal LLM used for answer generation.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True)
llm_model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct", trust_remote_code=True)
# Alternative local LLM, kept for reference:
# tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
# llm_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)
# Build the QA vector store (import-time side effect; presumably indexes the QA corpus —
# see qa_vector_store.build_qa_vector_store for details).
build_qa_vector_store(model_name, collection_name)
def respond(message, history, system_message):
    """Answer a user question via retrieve-and-rerank plus a local LLM.

    Args:
        message: The user's question (query string).
        history: Chat history supplied by gr.ChatInterface; not used.
        system_message: System prompt from the UI; currently not used —
            TODO: thread it into the LLM prompt.

    Returns:
        The generated answer string, or an error message ("[錯誤] ...")
        if any step fails.
    """
    try:
        # Vector search (top 20) then cross-encoder rerank (keep top 5,
        # dropping candidates scoring below 0.5).
        reranked = retrieve_and_rerank(
            message,
            model_name,
            collection_name,
            cross_encoder_model,
            score_threshold=0.5,
            search_top_k=20,
            rerank_top_k=5,
        )
        # Each reranked entry's first element is the passage text; scores are dropped.
        final_passages = [r[0] for r in reranked]
        # Generate the final answer with the local LLM, grounded on the passages.
        answer = generate_response_from_local_llm(
            message, final_passages, tokenizer, llm_model, max_new_tokens=256
        )
        return answer
    except Exception as e:
        # UI boundary: surface the error in the chat window instead of crashing
        # the Gradio app.
        return f"[錯誤] {str(e)}"
# Gradio chat UI; the extra textbox value is passed to respond() as system_message.
chat_interface = gr.ChatInterface(
    fn=respond,
    title="Chatbot問答系統",
    additional_inputs=[
        gr.Textbox(value="你是個幫忙立法委員許智傑回答問題的助手。", label="System Message"),
    ],
)

if __name__ == "__main__":
    chat_interface.launch()