vanhai123's picture
Update app.py
e19aff3 verified
raw
history blame
4.6 kB
# app.py
import gradio as gr
import torch
import re
from sentence_transformers import SentenceTransformer, util
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
# === Load model embedding
embedder = SentenceTransformer("keepitreal/vietnamese-sbert")
# === Thiết bị
device = torch.device("cpu")
print("✅ Using device:", device)
# === Load mô hình sinh phản hồi
model_name = "vanhai123/vietnamese-ecom-chatbot"
try:
tokenizer = AutoTokenizer.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(
"NlpHUST/gpt2-vietnamese", torch_dtype=torch.float32
).to(device)
base_model.resize_token_embeddings(len(tokenizer))
model = PeftModel.from_pretrained(base_model, model_name).to(device)
print("Model and tokenizer loaded successfully!")
except Exception as e:
print(f"Error loading model or tokenizer: {str(e)}")
raise
def load_qa_from_file(path="examples.txt"):
qa_pairs = []
try:
with open(path, "r", encoding="utf-8") as file:
content = file.read()
blocks = content.split("<|human|>")
for block in blocks:
if "Hỏi:" in block and "<|assistant|>" in block:
q = re.search(r"Hỏi:(.*)", block)
a = re.search(r"<\|assistant\|>(.*)", block, re.DOTALL)
if q and a:
question = q.group(1).strip()
answer_block = a.group(1).strip()
for line in answer_block.splitlines():
if "**Phản hồi**" in line:
answer = line.split("**Phản hồi**:")[-1].strip()
qa_pairs.append({"q": question, "a": answer})
break
except Exception as e:
print(f"Lỗi đọc file: {e}")
return qa_pairs
qa_data = load_qa_from_file("examples.txt")
questions = [qa["q"] for qa in qa_data]
embeddings = embedder.encode(questions, convert_to_tensor=True)
# === Prompt builder
def build_prompt(question):
try:
with open("examples.txt", "r", encoding="utf-8") as file:
example_block = "".join(file.readlines()[:30])
except:
example_block = "<|system|>Bạn là một trợ lý thương mại điện tử chuyên nghiệp tại Việt Nam."
return example_block + f"\n<|human|>Hỏi: {question}\n<|assistant|>"
# === Sinh phản hồi từ mô hình
def generate_with_model(question):
prompt = build_prompt(question)
inputs = tokenizer(prompt, return_tensors="pt").to(device)
input_len = inputs["input_ids"].shape[-1]
with torch.no_grad():
output = model.generate(
**inputs,
max_new_tokens=120,
temperature=0.6,
top_p=0.9,
do_sample=True,
repetition_penalty=1.15,
no_repeat_ngram_size=3,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
output_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
lines = [line.strip() for line in output_text.splitlines() if line.strip()]
for line in lines:
if "**Phản hồi**" in line:
return line.split("**Phản hồi**:")[-1].strip()
return None
def semantic_fallback(question):
query_embedding = embedder.encode(question, convert_to_tensor=True)
cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
top_idx = torch.argmax(cos_scores).item()
top_score = cos_scores[top_idx].item()
if top_score >= 0.75:
return qa_data[top_idx]["a"]
return "Vui lòng liên hệ CSKH để được hỗ trợ!"
def answer_question(user_question):
response = generate_with_model(user_question)
if response and len(response) > 30:
return response
return semantic_fallback(user_question)
# === Giao diện Gradio
interface = gr.Interface(
fn=answer_question,
inputs=gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn..."),
outputs="text",
title="Vietnamese E-commerce Chatbot",
description="Trợ lý AI thương mại điện tử: Trả lời từ mô hình ngôn ngữ hoặc tra cứu dữ liệu câu hỏi.",
examples=[
["Tôi muốn kiểm tra đơn hàng"],
["Có giảm giá khi mua số lượng lớn không?"],
["Tôi muốn trả hàng vì sản phẩm lỗi"],
["Tư vấn laptop cho dân văn phòng"]
]
)
interface.launch()