vanhai123 commited on
Commit
e19aff3
·
verified ·
1 Parent(s): 3ae87dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -36
app.py CHANGED
@@ -11,7 +11,7 @@ embedder = SentenceTransformer("keepitreal/vietnamese-sbert")
11
 
12
  # === Thiết bị
13
  device = torch.device("cpu")
14
- print("Using device:", device)
15
 
16
  # === Load mô hình sinh phản hồi
17
  model_name = "vanhai123/vietnamese-ecom-chatbot"
@@ -22,9 +22,9 @@ try:
22
  ).to(device)
23
  base_model.resize_token_embeddings(len(tokenizer))
24
  model = PeftModel.from_pretrained(base_model, model_name).to(device)
25
- print("Model and tokenizer loaded successfully!")
26
  except Exception as e:
27
- print(f"Error loading model or tokenizer: {str(e)}")
28
  raise
29
  def load_qa_from_file(path="examples.txt"):
30
  qa_pairs = []
@@ -45,14 +45,14 @@ def load_qa_from_file(path="examples.txt"):
45
  qa_pairs.append({"q": question, "a": answer})
46
  break
47
  except Exception as e:
48
- print(f"Lỗi đọc file: {e}")
49
  return qa_pairs
 
50
  qa_data = load_qa_from_file("examples.txt")
51
  questions = [qa["q"] for qa in qa_data]
52
  embeddings = embedder.encode(questions, convert_to_tensor=True)
53
 
54
- # === Xây dựng prompt sinh
55
-
56
  def build_prompt(question):
57
  try:
58
  with open("examples.txt", "r", encoding="utf-8") as file:
@@ -61,41 +61,47 @@ def build_prompt(question):
61
  example_block = "<|system|>Bạn là một trợ lý thương mại điện tử chuyên nghiệp tại Việt Nam."
62
  return example_block + f"\n<|human|>Hỏi: {question}\n<|assistant|>"
63
 
64
- def answer_question(user_question):
65
- query_embedding = embedder.encode(user_question, convert_to_tensor=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
67
  top_idx = torch.argmax(cos_scores).item()
68
  top_score = cos_scores[top_idx].item()
69
-
70
  if top_score >= 0.75:
71
  return qa_data[top_idx]["a"]
 
72
 
73
- # fallback bằng model nếu không khớp
74
- try:
75
- prompt = build_prompt(user_question)
76
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
77
- input_len = inputs["input_ids"].shape[-1]
78
-
79
- with torch.no_grad():
80
- output = model.generate(
81
- **inputs,
82
- max_new_tokens=120,
83
- temperature=0.6,
84
- top_p=0.9,
85
- do_sample=True,
86
- repetition_penalty=1.15,
87
- no_repeat_ngram_size=3,
88
- pad_token_id=tokenizer.pad_token_id,
89
- eos_token_id=tokenizer.eos_token_id,
90
- )
91
-
92
- output_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
93
- lines = [line.strip() for line in output_text.splitlines() if line.strip()]
94
- response_line = next((line for line in lines if "phản hồi" in line.lower()), None)
95
-
96
- return response_line.split("**Phản hồi**:")[-1].strip() if response_line else "Vui lòng cung cấp thêm thông tin để được hỗ trợ!"
97
- except:
98
- return "Xin lỗi, hệ thống không thể xử lý câu hỏi hiện tại."
99
 
100
  # === Giao diện Gradio
101
  interface = gr.Interface(
@@ -103,7 +109,7 @@ interface = gr.Interface(
103
  inputs=gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn..."),
104
  outputs="text",
105
  title="Vietnamese E-commerce Chatbot",
106
- description="Trợ lý AI trả lời câu hỏi thương mại điện tử từ sở dữ liệu hoặc sinh mới.",
107
  examples=[
108
  ["Tôi muốn kiểm tra đơn hàng"],
109
  ["Có giảm giá khi mua số lượng lớn không?"],
 
11
 
12
  # === Thiết bị
13
  device = torch.device("cpu")
14
+ print("Using device:", device)
15
 
16
  # === Load mô hình sinh phản hồi
17
  model_name = "vanhai123/vietnamese-ecom-chatbot"
 
22
  ).to(device)
23
  base_model.resize_token_embeddings(len(tokenizer))
24
  model = PeftModel.from_pretrained(base_model, model_name).to(device)
25
+ print("Model and tokenizer loaded successfully!")
26
  except Exception as e:
27
+ print(f"Error loading model or tokenizer: {str(e)}")
28
  raise
29
  def load_qa_from_file(path="examples.txt"):
30
  qa_pairs = []
 
45
  qa_pairs.append({"q": question, "a": answer})
46
  break
47
  except Exception as e:
48
+ print(f"Lỗi đọc file: {e}")
49
  return qa_pairs
50
+
51
  qa_data = load_qa_from_file("examples.txt")
52
  questions = [qa["q"] for qa in qa_data]
53
  embeddings = embedder.encode(questions, convert_to_tensor=True)
54
 
55
+ # === Prompt builder
 
56
  def build_prompt(question):
57
  try:
58
  with open("examples.txt", "r", encoding="utf-8") as file:
 
61
  example_block = "<|system|>Bạn là một trợ lý thương mại điện tử chuyên nghiệp tại Việt Nam."
62
  return example_block + f"\n<|human|>Hỏi: {question}\n<|assistant|>"
63
 
64
+ # === Sinh phản hồi từ mô hình
65
+
66
+ def generate_with_model(question):
67
+ prompt = build_prompt(question)
68
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
69
+ input_len = inputs["input_ids"].shape[-1]
70
+
71
+ with torch.no_grad():
72
+ output = model.generate(
73
+ **inputs,
74
+ max_new_tokens=120,
75
+ temperature=0.6,
76
+ top_p=0.9,
77
+ do_sample=True,
78
+ repetition_penalty=1.15,
79
+ no_repeat_ngram_size=3,
80
+ pad_token_id=tokenizer.pad_token_id,
81
+ eos_token_id=tokenizer.eos_token_id,
82
+ )
83
+
84
+ output_text = tokenizer.decode(output[0][input_len:], skip_special_tokens=True).strip()
85
+ lines = [line.strip() for line in output_text.splitlines() if line.strip()]
86
+ for line in lines:
87
+ if "**Phản hồi**" in line:
88
+ return line.split("**Phản hồi**:")[-1].strip()
89
+ return None
90
+
91
+ def semantic_fallback(question):
92
+ query_embedding = embedder.encode(question, convert_to_tensor=True)
93
  cos_scores = util.pytorch_cos_sim(query_embedding, embeddings)[0]
94
  top_idx = torch.argmax(cos_scores).item()
95
  top_score = cos_scores[top_idx].item()
 
96
  if top_score >= 0.75:
97
  return qa_data[top_idx]["a"]
98
+ return "Vui lòng liên hệ CSKH để được hỗ trợ!"
99
 
100
+ def answer_question(user_question):
101
+ response = generate_with_model(user_question)
102
+ if response and len(response) > 30:
103
+ return response
104
+ return semantic_fallback(user_question)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
 
106
  # === Giao diện Gradio
107
  interface = gr.Interface(
 
109
  inputs=gr.Textbox(lines=2, placeholder="Nhập câu hỏi của bạn..."),
110
  outputs="text",
111
  title="Vietnamese E-commerce Chatbot",
112
+ description="Trợ lý AI thương mại điện tử: Trả lời từ hình ngôn ngữ hoặc tra cứu dữ liệu câu hỏi.",
113
  examples=[
114
  ["Tôi muốn kiểm tra đơn hàng"],
115
  ["Có giảm giá khi mua số lượng lớn không?"],