kanninian commited on
Commit
bb1f9ea
·
verified ·
1 Parent(s): e56d196

Upload 2 files

Browse files
Files changed (2) hide show
  1. qa_vector_store.py +161 -0
  2. semantic_reranker.py +8 -0
qa_vector_store.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #%%
2
+ from text2vec import SentenceModel
3
+ from qdrant_client import QdrantClient
4
+ from qdrant_client.models import VectorParams, Distance, PointStruct
5
+
6
+
7
+ def deterministic_id(text):
8
+ import hashlib
9
+ return int(hashlib.sha256(text.encode('utf-8')).hexdigest(), 16) >> 128
10
+
11
+ def build_qa_vector_store(model_name, collection_name):
12
+ import pandas as pd
13
+ # 讀取資料
14
+ df = pd.read_excel("data/一百問三百答.xlsx", sheet_name=0)
15
+ df.columns = ['Question', 'Answer']
16
+ original_len = len(df)
17
+
18
+ # 去除重複 QA 組合
19
+ df = df.drop_duplicates(subset=["Question", "Answer"]).reset_index(drop=True)
20
+ print(f"📊 原始資料筆數:{original_len},去除重複後筆數:{len(df)}")
21
+
22
+ questions = df['Question'].tolist()
23
+ answers = df['Answer'].tolist()
24
+ # 初始化模型
25
+ model = SentenceModel(model_name)
26
+ question_vectors = model.encode(questions, normalize_embeddings=True)
27
+ embedding_dim = len(question_vectors[0])
28
+
29
+ # 初始化 Qdrant
30
+ client = QdrantClient(path="./qadrant_data")
31
+
32
+ # 建立新的 collection(重新指定向量維度)
33
+ client.recreate_collection(
34
+ collection_name=collection_name,
35
+ vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE)
36
+ )
37
+
38
+ points = [
39
+ PointStruct(
40
+ id=deterministic_id(q + a),
41
+ vector=vector.tolist(),
42
+ payload={"question": q, "answer": a}
43
+ )
44
+ for q, a, vector in zip(questions, answers, question_vectors)
45
+ ]
46
+
47
+ client.upsert(collection_name=collection_name, points=points)
48
+ print(f"✅ 向量資料庫建立完成,共嵌入 {len(points)} 筆 QA。")
49
+ client.scroll(collection_name=collection_name, limit=100)
50
+
51
+
52
+ # build_qa_vector_store(model_name, collection_name)
53
+
54
+ # %%
55
+ # model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
56
+ # collection_name = model_name.split("/")[-1]
57
+ # client = QdrantClient(path="./qadrant_data")
58
+ # count = client.count(collection_name=collection_name, exact=True).count
59
+ # print(f"📦 Collection {collection_name} 中有 {count} 筆資料")
60
+ # # %%
61
+ # from collections import Counter
62
+
63
+ # records = client.scroll(collection_name=collection_name, limit=1000)[0]
64
+ # answers = [rec.payload["answer"] for rec in records]
65
+ # duplicates = [item for item, count in Counter(answers).items() if count > 1]
66
+ # print("重複答案數量:", len(duplicates))
67
+ # print("部分重複答案:", duplicates[:5])
68
+
69
+ def retrieve_and_rerank(query, model_name, collection_name, cross_encoder_model, score_threshold=0.6, search_top_k=25, rerank_top_k=5):
70
+ from semantic_reranker import rerank_results
71
+ model = SentenceModel(model_name)
72
+ client = QdrantClient(path="./qadrant_data")
73
+
74
+ # 確認 collection 是否存在
75
+ if collection_name not in [c.name for c in client.get_collections().collections]:
76
+ print(f"⚠️ Collection {collection_name} 不存在,請先建立向量資料庫。")
77
+ return
78
+
79
+ query_vector = model.encode(query, normalize_embeddings=True)
80
+ results = client.search(
81
+ collection_name=collection_name,
82
+ query_vector=query_vector.tolist(),
83
+ limit=search_top_k,
84
+ score_threshold=score_threshold
85
+ )
86
+
87
+ retrieved_answers = [r.payload["answer"] for r in results]
88
+
89
+ # 用 cross-encoder rerank
90
+ reranked = rerank_results(
91
+ cross_encoder_model=cross_encoder_model,
92
+ query=query,
93
+ candidate_passages=retrieved_answers,
94
+ top_k=rerank_top_k
95
+ )
96
+
97
+ if not results:
98
+ print("❌ 找不到相關答案。")
99
+ else:
100
+ for i, hit in enumerate(results):
101
+ print(f"✅ Top {i+1}:{hit.payload['answer']} (score={hit.score:.3f})")
102
+
103
+ for i, (answer, score) in enumerate(reranked):
104
+ print(f"🔥 Rerank Top {i+1}:{answer} (score={score:.3f})")
105
+ return reranked
106
+
107
+
108
+ def generate_response_from_local_llm(query, passages, tokenizer, model, max_new_tokens=256):
109
+ context = "\n".join([f"- {p}" for p in passages])
110
+ if not passages:
111
+ return "抱歉,沒有找到相關資料。"
112
+ prompt = f"""你是一個客服助理。請根據以下資料回答使用者問題。
113
+ 問題:{query}
114
+ 相關資料:
115
+ {context}
116
+ 回答時注意:。
117
+ - 一定要針對問題回答
118
+ - 若是問題太過負面一定要先反駁,再巧妙的轉移焦點或拒答。
119
+ - 只回答問題中提到的內容,避免無關的話題。
120
+ - 只使用相關資料的內容來回答問題,避免添加個人意見或無關的資訊,若有必要可拒答。
121
+ - 只回答正面、積極的內容,避免使用負面或消極的語言。
122
+ - 請以溫暖又充滿人性的方式回答問題。
123
+ - 回答時平易近人,像和朋友交談一樣。
124
+ - 精簡回答,避免冗長的解釋。
125
+ 回答:"""
126
+ print(prompt)
127
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
128
+ outputs = model.generate(
129
+ **inputs,
130
+ max_new_tokens=max_new_tokens,
131
+ do_sample=True,
132
+ top_p=0.95,
133
+ temperature=0.7
134
+ )
135
+ decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
136
+
137
+ # 提取回答部分
138
+ answer = decoded_output.split("回答:", 1)[-1].strip() if "回答:" in decoded_output else decoded_output
139
+ return answer
140
+ #%%
141
+ # from sentence_transformers import CrossEncoder
142
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
143
+ # model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
144
+ # collection_name = model_name.split("/")[-1]
145
+ # cross_encoder_model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
146
+ # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B", trust_remote_code=True)
147
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", trust_remote_code=True)
148
+ # # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B-Chat", trust_remote_code=True)
149
+ # # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-7B-Chat", trust_remote_code=True)
150
+
151
+ # #%%
152
+ # user_query = "許智傑做過什麼壞事"
153
+ # reranked = retrieve_and_rerank(user_query, model_name, collection_name, cross_encoder_model, score_threshold=0.6, search_top_k=20, rerank_top_k=5)
154
+ # #%%
155
+ # passages = [answer for answer, score in reranked]
156
+ # answer = generate_response_from_local_llm(user_query, passages, tokenizer, model, max_new_tokens=256)
157
+ # print("回答:", answer)
158
+
159
+
160
+
161
+ # %%
semantic_reranker.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ def rerank_results(cross_encoder_model, query, candidate_passages, top_k=10):
2
+ if not candidate_passages:
3
+ return []
4
+ pairs = [(query, passage) for passage in candidate_passages]
5
+ scores = cross_encoder_model.predict(pairs)
6
+ reranked = sorted(zip(candidate_passages, scores), key=lambda x: x[1], reverse=True)
7
+
8
+ return reranked[:top_k]