Update app.py
Browse files
app.py
CHANGED
@@ -11,7 +11,7 @@ embedder = SentenceTransformer("keepitreal/vietnamese-sbert")
|
|
11 |
|
12 |
# === Thiết bị
|
13 |
device = torch.device("cpu")
|
14 |
-
print("Using device:", device)
|
15 |
|
16 |
# === Load mô hình sinh phản hồi
|
17 |
model_name = "vanhai123/vietnamese-ecom-chatbot"
|
@@ -22,9 +22,9 @@ try:
|
|
22 |
).to(device)
|
23 |
base_model.resize_token_embeddings(len(tokenizer))
|
24 |
model = PeftModel.from_pretrained(base_model, model_name).to(device)
|
25 |
-
print("
|
26 |
except Exception as e:
|
27 |
-
print(f"
|
28 |
raise
|
29 |
def load_qa_from_file(path="examples.txt"):
|
30 |
qa_pairs = []
|
@@ -45,14 +45,14 @@ def load_qa_from_file(path="examples.txt"):
|
|
45 |
qa_pairs.append({"q": question, "a": answer})
|
46 |
break
|
47 |
except Exception as e:
|
48 |
-
print(f"
|
49 |
return qa_pairs
|
|
|
50 |
qa_data = load_qa_from_file("examples.txt")
|
51 |
questions = [qa["q"] for qa in qa_data]
|
52 |
embeddings = embedder.encode(questions, convert_to_tensor=True)
|
53 |
|
54 |
-
# ===
|
55 |
-
|
56 |
def build_prompt(question):
|
57 |
try:
|
58 |
with open("examples.txt", "r", encoding="utf-8") as file:
|
@@ -61,41 +61,47 @@ def build_prompt(question):
|
|
61 |
example_block = "<|system|>Bạn là một trợ lý thương mại điện tử chuyên nghiệp tại Việt Nam."
|
62 |
return example_block + f"\n<|human|>Hỏi: {question}\n<|assistant|>"
|
63 |
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
66 |
cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
|
67 |
top_idx = torch.argmax(cos_scores).item()
|
68 |
top_score = cos_scores[top_idx].item()
|
69 |
-
|
70 |
if top_score >= 0.75:
|
71 |
return qa_data[top_idx]["a"]
|
|
|
72 |
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
with torch.no_grad():
|
80 |
-
output = model.generate(
|
81 |
-
**inputs,
|
82 |
-
max_new_tokens=120,
|
83 |
-
temperature=0.6,
|
84 |
-
top_p=0.9,
|
85 |
-
do_sample=True,
|
86 |
-
repetition_penalty=1.15,
|
87 |
-
no_repeat_ngram_size=3,
|
88 |
-
pad_token_id=tokenizer.pad_token_id,
|
89 |
-
eos_token_id=tokenizer.eos_token_id,
|
90 |
-
)
|
91 |
-
|
92 |
-
output_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
|
93 |
-
lines = [line.strip() for line in output_text.splitlines() if line.strip()]
|
94 |
-
response_line = next((line for line in lines if "phản hồi" in line.lower()), None)
|
95 |
-
|
96 |
-
return response_line.split("**Phản hồi**:")[-1].strip() if response_line else "Vui lòng cung cấp thêm thông tin để được hỗ trợ!"
|
97 |
-
except:
|
98 |
-
return "Xin lỗi, hệ thống không thể xử lý câu hỏi hiện tại."
|
99 |
|
100 |
# === Giao diện Gradio
|
101 |
interface = gr.Interface(
|
@@ -103,7 +109,7 @@ interface = gr.Interface(
|
|
103 |
inputs=gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn..."),
|
104 |
outputs="text",
|
105 |
title="Vietnamese E-commerce Chatbot",
|
106 |
-
description="Trợ lý AI
|
107 |
examples=[
|
108 |
["Tôi muốn kiểm tra đơn hàng"],
|
109 |
["Có giảm giá khi mua số lượng lớn không?"],
|
|
|
11 |
|
12 |
# === Thiết bị
|
13 |
device = torch.device("cpu")
|
14 |
+
print("✅ Using device:", device)
|
15 |
|
16 |
# === Load mô hình sinh phản hồi
|
17 |
model_name = "vanhai123/vietnamese-ecom-chatbot"
|
|
|
22 |
).to(device)
|
23 |
base_model.resize_token_embeddings(len(tokenizer))
|
24 |
model = PeftModel.from_pretrained(base_model, model_name).to(device)
|
25 |
+
print("Model and tokenizer loaded successfully!")
|
26 |
except Exception as e:
|
27 |
+
print(f"Error loading model or tokenizer: {str(e)}")
|
28 |
raise
|
29 |
def load_qa_from_file(path="examples.txt"):
|
30 |
qa_pairs = []
|
|
|
45 |
qa_pairs.append({"q": question, "a": answer})
|
46 |
break
|
47 |
except Exception as e:
|
48 |
+
print(f"Lỗi đọc file: {e}")
|
49 |
return qa_pairs
|
50 |
+
|
51 |
qa_data = load_qa_from_file("examples.txt")
|
52 |
questions = [qa["q"] for qa in qa_data]
|
53 |
embeddings = embedder.encode(questions, convert_to_tensor=True)
|
54 |
|
55 |
+
# === Prompt builder
|
|
|
56 |
def build_prompt(question):
|
57 |
try:
|
58 |
with open("examples.txt", "r", encoding="utf-8") as file:
|
|
|
61 |
example_block = "<|system|>Bạn là một trợ lý thương mại điện tử chuyên nghiệp tại Việt Nam."
|
62 |
return example_block + f"\n<|human|>Hỏi: {question}\n<|assistant|>"
|
63 |
|
64 |
+
# === Sinh phản hồi từ mô hình
|
65 |
+
|
66 |
+
def generate_with_model(question):
|
67 |
+
prompt = build_prompt(question)
|
68 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
69 |
+
input_len = inputs["input_ids"].shape[-1]
|
70 |
+
|
71 |
+
with torch.no_grad():
|
72 |
+
output = model.generate(
|
73 |
+
**inputs,
|
74 |
+
max_new_tokens=120,
|
75 |
+
temperature=0.6,
|
76 |
+
top_p=0.9,
|
77 |
+
do_sample=True,
|
78 |
+
repetition_penalty=1.15,
|
79 |
+
no_repeat_ngram_size=3,
|
80 |
+
pad_token_id=tokenizer.pad_token_id,
|
81 |
+
eos_token_id=tokenizer.eos_token_id,
|
82 |
+
)
|
83 |
+
|
84 |
+
output_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
|
85 |
+
lines = [line.strip() for line in output_text.splitlines() if line.strip()]
|
86 |
+
for line in lines:
|
87 |
+
if "**Phản hồi**" in line:
|
88 |
+
return line.split("**Phản hồi**:")[-1].strip()
|
89 |
+
return None
|
90 |
+
|
91 |
+
def semantic_fallback(question):
|
92 |
+
query_embedding = embedder.encode(question, convert_to_tensor=True)
|
93 |
cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
|
94 |
top_idx = torch.argmax(cos_scores).item()
|
95 |
top_score = cos_scores[top_idx].item()
|
|
|
96 |
if top_score >= 0.75:
|
97 |
return qa_data[top_idx]["a"]
|
98 |
+
return "Vui lòng liên hệ CSKH để được hỗ trợ!"
|
99 |
|
100 |
+
def answer_question(user_question):
|
101 |
+
response = generate_with_model(user_question)
|
102 |
+
if response and len(response) > 30:
|
103 |
+
return response
|
104 |
+
return semantic_fallback(user_question)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
105 |
|
106 |
# === Giao diện Gradio
|
107 |
interface = gr.Interface(
|
|
|
109 |
inputs=gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn..."),
|
110 |
outputs="text",
|
111 |
title="Vietnamese E-commerce Chatbot",
|
112 |
+
description="Trợ lý AI thương mại điện tử: Trả lời từ mô hình ngôn ngữ hoặc tra cứu dữ liệu câu hỏi.",
|
113 |
examples=[
|
114 |
["Tôi muốn kiểm tra đơn hàng"],
|
115 |
["Có giảm giá khi mua số lượng lớn không?"],
|