kanninian commited on
Commit
d43ed98
·
verified ·
1 Parent(s): 4aa65bc

Update qa_vector_store.py

Browse files
Files changed (1) hide show
  1. qa_vector_store.py +161 -161
qa_vector_store.py CHANGED
@@ -1,161 +1,161 @@
1
- #%%
2
- from text2vec import SentenceModel
3
- from qdrant_client import QdrantClient
4
- from qdrant_client.models import VectorParams, Distance, PointStruct
5
-
6
-
7
- def deterministic_id(text):
8
- import hashlib
9
- return int(hashlib.sha256(text.encode('utf-8')).hexdigest(), 16) >> 128
10
-
11
- def build_qa_vector_store(model_name, collection_name):
12
- import pandas as pd
13
- # 讀取資料
14
- df = pd.read_excel("data/一百問三百答.xlsx", sheet_name=0)
15
- df.columns = ['Question', 'Answer']
16
- original_len = len(df)
17
-
18
- # 去除重複 QA 組合
19
- df = df.drop_duplicates(subset=["Question", "Answer"]).reset_index(drop=True)
20
- print(f"📊 原始資料筆數:{original_len},去除重複後筆數:{len(df)}")
21
-
22
- questions = df['Question'].tolist()
23
- answers = df['Answer'].tolist()
24
- # 初始化模型
25
- model = SentenceModel(model_name)
26
- question_vectors = model.encode(questions, normalize_embeddings=True)
27
- embedding_dim = len(question_vectors[0])
28
-
29
- # 初始化 Qdrant
30
- client = QdrantClient(path="./qadrant_data")
31
-
32
- # 建立新的 collection(重新指定向量維度)
33
- client.recreate_collection(
34
- collection_name=collection_name,
35
- vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE)
36
- )
37
-
38
- points = [
39
- PointStruct(
40
- id=deterministic_id(q + a),
41
- vector=vector.tolist(),
42
- payload={"question": q, "answer": a}
43
- )
44
- for q, a, vector in zip(questions, answers, question_vectors)
45
- ]
46
-
47
- client.upsert(collection_name=collection_name, points=points)
48
- print(f"✅ 向量資料庫建立完成,共嵌入 {len(points)} 筆 QA。")
49
- client.scroll(collection_name=collection_name, limit=100)
50
-
51
-
52
- # build_qa_vector_store(model_name, collection_name)
53
-
54
- # %%
55
- # model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
56
- # collection_name = model_name.split("/")[-1]
57
- # client = QdrantClient(path="./qadrant_data")
58
- # count = client.count(collection_name=collection_name, exact=True).count
59
- # print(f"📦 Collection {collection_name} 中有 {count} 筆資料")
60
- # # %%
61
- # from collections import Counter
62
-
63
- # records = client.scroll(collection_name=collection_name, limit=1000)[0]
64
- # answers = [rec.payload["answer"] for rec in records]
65
- # duplicates = [item for item, count in Counter(answers).items() if count > 1]
66
- # print("重複答案數量:", len(duplicates))
67
- # print("部分重複答案:", duplicates[:5])
68
-
69
- def retrieve_and_rerank(query, model_name, collection_name, cross_encoder_model, score_threshold=0.6, search_top_k=25, rerank_top_k=5):
70
- from semantic_reranker import rerank_results
71
- model = SentenceModel(model_name)
72
- client = QdrantClient(path="./qadrant_data")
73
-
74
- # 確認 collection 是否存在
75
- if collection_name not in [c.name for c in client.get_collections().collections]:
76
- print(f"⚠️ Collection {collection_name} 不存在,請先建立向量資料庫。")
77
- return
78
-
79
- query_vector = model.encode(query, normalize_embeddings=True)
80
- results = client.search(
81
- collection_name=collection_name,
82
- query_vector=query_vector.tolist(),
83
- limit=search_top_k,
84
- score_threshold=score_threshold
85
- )
86
-
87
- retrieved_answers = [r.payload["answer"] for r in results]
88
-
89
- # 用 cross-encoder rerank
90
- reranked = rerank_results(
91
- cross_encoder_model=cross_encoder_model,
92
- query=query,
93
- candidate_passages=retrieved_answers,
94
- top_k=rerank_top_k
95
- )
96
-
97
- if not results:
98
- print("❌ 找不到相關答案。")
99
- else:
100
- for i, hit in enumerate(results):
101
- print(f"✅ Top {i+1}:{hit.payload['answer']} (score={hit.score:.3f})")
102
-
103
- for i, (answer, score) in enumerate(reranked):
104
- print(f"🔥 Rerank Top {i+1}:{answer} (score={score:.3f})")
105
- return reranked
106
-
107
-
108
- def generate_response_from_local_llm(query, passages, tokenizer, model, max_new_tokens=256):
109
- context = "\n".join([f"- {p}" for p in passages])
110
- if not passages:
111
- return "抱歉,沒有找到相關資料。"
112
- prompt = f"""你是一個客服助理。請根據以下資料回答使用者問題。
113
- 問題:{query}
114
- 相關資料:
115
- {context}
116
- 回答時注意:。
117
- - 一定要針對問題回答
118
- - 若是問題太過負面一定要先反駁,再巧妙的轉移焦點或拒答。
119
- - 只回答問題中提到的內容,避免無關的話題。
120
- - 只使用相關資料的內容來回答問題,避免添加個人意見或無關的資訊,若有必要可拒答。
121
- - 只回答正面、積極的內容,避免使用負面或消極的語言。
122
- - 請以溫暖又充滿人性的方式回答問題。
123
- - 回答時平易近人,像和朋友交談一樣。
124
- - 精簡回答,避免冗長的解釋。
125
- 回答:"""
126
- print(prompt)
127
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
128
- outputs = model.generate(
129
- **inputs,
130
- max_new_tokens=max_new_tokens,
131
- do_sample=True,
132
- top_p=0.95,
133
- temperature=0.7
134
- )
135
- decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
136
-
137
- # 提取回答部分
138
- answer = decoded_output.split("回答:", 1)[-1].strip() if "回答:" in decoded_output else decoded_output
139
- return answer
140
- #%%
141
- # from sentence_transformers import CrossEncoder
142
- # from transformers import AutoTokenizer, AutoModelForCausalLM
143
- # model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
144
- # collection_name = model_name.split("/")[-1]
145
- # cross_encoder_model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
146
- # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B", trust_remote_code=True)
147
- # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", trust_remote_code=True)
148
- # # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B-Chat", trust_remote_code=True)
149
- # # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-7B-Chat", trust_remote_code=True)
150
-
151
- # #%%
152
- # user_query = "許智傑做過什麼壞事"
153
- # reranked = retrieve_and_rerank(user_query, model_name, collection_name, cross_encoder_model, score_threshold=0.6, search_top_k=20, rerank_top_k=5)
154
- # #%%
155
- # passages = [answer for answer, score in reranked]
156
- # answer = generate_response_from_local_llm(user_query, passages, tokenizer, model, max_new_tokens=256)
157
- # print("回答:", answer)
158
-
159
-
160
-
161
- # %%
 
1
+ #%%
2
+ from text2vec import SentenceModel
3
+ from qdrant_client import QdrantClient
4
+ from qdrant_client.models import VectorParams, Distance, PointStruct
5
+
6
+
7
+ def deterministic_id(text):
8
+ import hashlib
9
+ return int(hashlib.sha256(text.encode('utf-8')).hexdigest(), 16) >> 128
10
+
11
+ def build_qa_vector_store(model_name, collection_name):
12
+ import pandas as pd
13
+ # 讀取資料
14
+ df = pd.read_excel("一百問三百答.xlsx", sheet_name=0)
15
+ df.columns = ['Question', 'Answer']
16
+ original_len = len(df)
17
+
18
+ # 去除重複 QA 組合
19
+ df = df.drop_duplicates(subset=["Question", "Answer"]).reset_index(drop=True)
20
+ print(f"📊 原始資料筆數:{original_len},去除重複後筆數:{len(df)}")
21
+
22
+ questions = df['Question'].tolist()
23
+ answers = df['Answer'].tolist()
24
+ # 初始化模型
25
+ model = SentenceModel(model_name)
26
+ question_vectors = model.encode(questions, normalize_embeddings=True)
27
+ embedding_dim = len(question_vectors[0])
28
+
29
+ # 初始化 Qdrant
30
+ client = QdrantClient(path="./qadrant_data")
31
+
32
+ # 建立新的 collection(重新指定向量維度)
33
+ client.recreate_collection(
34
+ collection_name=collection_name,
35
+ vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE)
36
+ )
37
+
38
+ points = [
39
+ PointStruct(
40
+ id=deterministic_id(q + a),
41
+ vector=vector.tolist(),
42
+ payload={"question": q, "answer": a}
43
+ )
44
+ for q, a, vector in zip(questions, answers, question_vectors)
45
+ ]
46
+
47
+ client.upsert(collection_name=collection_name, points=points)
48
+ print(f"✅ 向量資料庫建立完成,共嵌入 {len(points)} 筆 QA。")
49
+ client.scroll(collection_name=collection_name, limit=100)
50
+
51
+
52
+ # build_qa_vector_store(model_name, collection_name)
53
+
54
+ # %%
55
+ # model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
56
+ # collection_name = model_name.split("/")[-1]
57
+ # client = QdrantClient(path="./qadrant_data")
58
+ # count = client.count(collection_name=collection_name, exact=True).count
59
+ # print(f"📦 Collection {collection_name} 中有 {count} 筆資料")
60
+ # # %%
61
+ # from collections import Counter
62
+
63
+ # records = client.scroll(collection_name=collection_name, limit=1000)[0]
64
+ # answers = [rec.payload["answer"] for rec in records]
65
+ # duplicates = [item for item, count in Counter(answers).items() if count > 1]
66
+ # print("重複答案數量:", len(duplicates))
67
+ # print("部分重複答案:", duplicates[:5])
68
+
69
+ def retrieve_and_rerank(query, model_name, collection_name, cross_encoder_model, score_threshold=0.6, search_top_k=25, rerank_top_k=5):
70
+ from semantic_reranker import rerank_results
71
+ model = SentenceModel(model_name)
72
+ client = QdrantClient(path="./qadrant_data")
73
+
74
+ # 確認 collection 是否存在
75
+ if collection_name not in [c.name for c in client.get_collections().collections]:
76
+ print(f"⚠️ Collection {collection_name} 不存在,請先建立向量資料庫。")
77
+ return
78
+
79
+ query_vector = model.encode(query, normalize_embeddings=True)
80
+ results = client.search(
81
+ collection_name=collection_name,
82
+ query_vector=query_vector.tolist(),
83
+ limit=search_top_k,
84
+ score_threshold=score_threshold
85
+ )
86
+
87
+ retrieved_answers = [r.payload["answer"] for r in results]
88
+
89
+ # 用 cross-encoder rerank
90
+ reranked = rerank_results(
91
+ cross_encoder_model=cross_encoder_model,
92
+ query=query,
93
+ candidate_passages=retrieved_answers,
94
+ top_k=rerank_top_k
95
+ )
96
+
97
+ if not results:
98
+ print("❌ 找不到相關答案。")
99
+ else:
100
+ for i, hit in enumerate(results):
101
+ print(f"✅ Top {i+1}:{hit.payload['answer']} (score={hit.score:.3f})")
102
+
103
+ for i, (answer, score) in enumerate(reranked):
104
+ print(f"🔥 Rerank Top {i+1}:{answer} (score={score:.3f})")
105
+ return reranked
106
+
107
+
108
+ def generate_response_from_local_llm(query, passages, tokenizer, model, max_new_tokens=256):
109
+ context = "\n".join([f"- {p}" for p in passages])
110
+ if not passages:
111
+ return "抱歉,沒有找到相關資料。"
112
+ prompt = f"""你是一個客服助理。請根據以下資料回答使用者問題。
113
+ 問題:{query}
114
+ 相關資料:
115
+ {context}
116
+ 回答時注意:。
117
+ - 一定要針對問題回答
118
+ - 若是問題太過負面一定要先反駁,再巧妙的轉移焦點或拒答。
119
+ - 只回答問題中提到的內容,避免無關的話題。
120
+ - 只使用相關資料的內容來回答問題,避免添加個人意見或無關的資訊,若有必要可拒答。
121
+ - 只回答正面、積極的內容,避免使用負面或消極的語言。
122
+ - 請以溫暖又充滿人性的方式回答問題。
123
+ - 回答時平易近人,像和朋友交談一樣。
124
+ - 精簡回答,避免冗長的解釋。
125
+ 回答:"""
126
+ print(prompt)
127
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
128
+ outputs = model.generate(
129
+ **inputs,
130
+ max_new_tokens=max_new_tokens,
131
+ do_sample=True,
132
+ top_p=0.95,
133
+ temperature=0.7
134
+ )
135
+ decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
136
+
137
+ # 提取回答部分
138
+ answer = decoded_output.split("回答:", 1)[-1].strip() if "回答:" in decoded_output else decoded_output
139
+ return answer
140
+ #%%
141
+ # from sentence_transformers import CrossEncoder
142
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
143
+ # model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
144
+ # collection_name = model_name.split("/")[-1]
145
+ # cross_encoder_model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
146
+ # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B", trust_remote_code=True)
147
+ # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", trust_remote_code=True)
148
+ # # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B-Chat", trust_remote_code=True)
149
+ # # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-7B-Chat", trust_remote_code=True)
150
+
151
+ # #%%
152
+ # user_query = "許智傑做過什麼壞事"
153
+ # reranked = retrieve_and_rerank(user_query, model_name, collection_name, cross_encoder_model, score_threshold=0.6, search_top_k=20, rerank_top_k=5)
154
+ # #%%
155
+ # passages = [answer for answer, score in reranked]
156
+ # answer = generate_response_from_local_llm(user_query, passages, tokenizer, model, max_new_tokens=256)
157
+ # print("回答:", answer)
158
+
159
+
160
+
161
+ # %%