|
|
|
import gradio as gr |
|
import torch |
|
import re |
|
from sentence_transformers import SentenceTransformer, util |
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
from peft import PeftModel |
|
|
|
|
|
embedder = SentenceTransformer("keepitreal/vietnamese-sbert") |
|
|
|
|
|
device = torch.device("cpu") |
|
print("✅ Using device:", device) |
|
|
|
|
|
model_name = "vanhai123/vietnamese-ecom-chatbot" |
|
try: |
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
base_model = AutoModelForCausalLM.from_pretrained( |
|
"NlpHUST/gpt2-vietnamese", torch_dtype=torch.float32 |
|
).to(device) |
|
base_model.resize_token_embeddings(len(tokenizer)) |
|
model = PeftModel.from_pretrained(base_model, model_name).to(device) |
|
print("Model and tokenizer loaded successfully!") |
|
except Exception as e: |
|
print(f"Error loading model or tokenizer: {str(e)}") |
|
raise |
|
def load_qa_from_file(path="examples.txt"): |
|
qa_pairs = [] |
|
try: |
|
with open(path, "r", encoding="utf-8") as file: |
|
content = file.read() |
|
blocks = content.split("<|human|>") |
|
for block in blocks: |
|
if "Hỏi:" in block and "<|assistant|>" in block: |
|
q = re.search(r"Hỏi:(.*)", block) |
|
a = re.search(r"<\|assistant\|>(.*)", block, re.DOTALL) |
|
if q and a: |
|
question = q.group(1).strip() |
|
answer_block = a.group(1).strip() |
|
for line in answer_block.splitlines(): |
|
if "**Phản hồi**" in line: |
|
answer = line.split("**Phản hồi**:")[-1].strip() |
|
qa_pairs.append({"q": question, "a": answer}) |
|
break |
|
except Exception as e: |
|
print(f"Lỗi đọc file: {e}") |
|
return qa_pairs |
|
|
|
qa_data = load_qa_from_file("examples.txt") |
|
questions = [qa["q"] for qa in qa_data] |
|
embeddings = embedder.encode(questions, convert_to_tensor=True) |
|
|
|
|
|
def build_prompt(question): |
|
try: |
|
with open("examples.txt", "r", encoding="utf-8") as file: |
|
example_block = "".join(file.readlines()[:30]) |
|
except: |
|
example_block = "<|system|>Bạn là một trợ lý thương mại điện tử chuyên nghiệp tại Việt Nam." |
|
return example_block + f"\n<|human|>Hỏi: {question}\n<|assistant|>" |
|
|
|
|
|
|
|
def generate_with_model(question): |
|
prompt = build_prompt(question) |
|
inputs = tokenizer(prompt, return_tensors="pt").to(device) |
|
input_len = inputs["input_ids"].shape[-1] |
|
|
|
with torch.no_grad(): |
|
output = model.generate( |
|
**inputs, |
|
max_new_tokens=120, |
|
temperature=0.6, |
|
top_p=0.9, |
|
do_sample=True, |
|
repetition_penalty=1.15, |
|
no_repeat_ngram_size=3, |
|
pad_token_id=tokenizer.pad_token_id, |
|
eos_token_id=tokenizer.eos_token_id, |
|
) |
|
|
|
output_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip() |
|
lines = [line.strip() for line in output_text.splitlines() if line.strip()] |
|
for line in lines: |
|
if "**Phản hồi**" in line: |
|
return line.split("**Phản hồi**:")[-1].strip() |
|
return None |
|
|
|
def semantic_fallback(question): |
|
query_embedding = embedder.encode(question, convert_to_tensor=True) |
|
cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0] |
|
top_idx = torch.argmax(cos_scores).item() |
|
top_score = cos_scores[top_idx].item() |
|
if top_score >= 0.75: |
|
return qa_data[top_idx]["a"] |
|
return "Vui lòng liên hệ CSKH để được hỗ trợ!" |
|
|
|
def answer_question(user_question): |
|
response = generate_with_model(user_question) |
|
if response and len(response) > 30: |
|
return response |
|
return semantic_fallback(user_question) |
|
|
|
|
|
interface = gr.Interface( |
|
fn=answer_question, |
|
inputs=gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn..."), |
|
outputs="text", |
|
title="Vietnamese E-commerce Chatbot", |
|
description="Trợ lý AI thương mại điện tử: Trả lời từ mô hình ngôn ngữ hoặc tra cứu dữ liệu câu hỏi.", |
|
examples=[ |
|
["Tôi muốn kiểm tra đơn hàng"], |
|
["Có giảm giá khi mua số lượng lớn không?"], |
|
["Tôi muốn trả hàng vì sản phẩm lỗi"], |
|
["Tư vấn laptop cho dân văn phòng"] |
|
] |
|
) |
|
|
|
interface.launch() |
|
|