kanninian committed on
Commit
452f50d
·
verified ·
1 Parent(s): b1bdc38

Update qa_vector_store.py

Browse files
Files changed (1) hide show
  1. qa_vector_store.py +3 -61
qa_vector_store.py CHANGED
@@ -15,10 +15,6 @@ def build_qa_vector_store(model_name, collection_name):
15
  df.columns = ['Question', 'Answer']
16
  original_len = len(df)
17
 
18
- # 去除重複 QA 組合
19
- df = df.drop_duplicates(subset=["Question", "Answer"]).reset_index(drop=True)
20
- print(f"📊 原始資料筆數:{original_len},去除重複後筆數:{len(df)}")
21
-
22
  questions = df['Question'].tolist()
23
  answers = df['Answer'].tolist()
24
  # 初始化模型
@@ -48,33 +44,10 @@ def build_qa_vector_store(model_name, collection_name):
48
  print(f"✅ 向量資料庫建立完成,共嵌入 {len(points)} 筆 QA。")
49
  client.scroll(collection_name=collection_name, limit=100)
50
 
51
-
52
- # build_qa_vector_store(model_name, collection_name)
53
-
54
- # %%
55
- # model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
56
- # collection_name = model_name.split("/")[-1]
57
- # client = QdrantClient(path="./qadrant_data")
58
- # count = client.count(collection_name=collection_name, exact=True).count
59
- # print(f"📦 Collection {collection_name} 中有 {count} 筆資料")
60
- # # %%
61
- # from collections import Counter
62
-
63
- # records = client.scroll(collection_name=collection_name, limit=1000)[0]
64
- # answers = [rec.payload["answer"] for rec in records]
65
- # duplicates = [item for item, count in Counter(answers).items() if count > 1]
66
- # print("重複答案數量:", len(duplicates))
67
- # print("部分重複答案:", duplicates[:5])
68
-
69
  def retrieve_and_rerank(query, model_name, collection_name, cross_encoder_model, score_threshold=0.6, search_top_k=25, rerank_top_k=5):
70
  from semantic_reranker import rerank_results
71
  model = SentenceModel(model_name)
72
  client = QdrantClient(path="./qadrant_data")
73
-
74
- # 確認 collection 是否存在
75
- if collection_name not in [c.name for c in client.get_collections().collections]:
76
- print(f"⚠️ Collection {collection_name} 不存在,請先建立向量資料庫。")
77
- return
78
 
79
  query_vector = model.encode(query, normalize_embeddings=True)
80
  results = client.search(
@@ -93,15 +66,6 @@ def retrieve_and_rerank(query, model_name, collection_name, cross_encoder_model,
93
  candidate_passages=retrieved_answers,
94
  top_k=rerank_top_k
95
  )
96
-
97
- if not results:
98
- print("❌ 找不到相關答案。")
99
- else:
100
- for i, hit in enumerate(results):
101
- print(f"✅ Top {i+1}:{hit.payload['answer']} (score={hit.score:.3f})")
102
-
103
- for i, (answer, score) in enumerate(reranked):
104
- print(f"🔥 Rerank Top {i+1}:{answer} (score={score:.3f})")
105
  return reranked
106
 
107
 
@@ -123,7 +87,6 @@ def generate_response_from_local_llm(query, passages, tokenizer, model, max_new_
123
  - 回答時平易近人,像和朋友交談一樣。
124
  - 精簡回答,避免冗長的解釋。
125
  回答:"""
126
- print(prompt)
127
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
128
  outputs = model.generate(
129
  **inputs,
@@ -135,27 +98,6 @@ def generate_response_from_local_llm(query, passages, tokenizer, model, max_new_
135
  decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
136
 
137
  # 提取回答部分
138
- answer = decoded_output.split("回答:", 1)[-1].strip() if "回答:" in decoded_output else decoded_output
139
- return answer
140
- #%%
141
- # from sentence_transformers import CrossEncoder
142
- # from transformers import AutoTokenizer, AutoModelForCausalLM
143
- # model_name = "sentence-transformers/paraphrase-multilingual-mpnet-base-v2"
144
- # collection_name = model_name.split("/")[-1]
145
- # cross_encoder_model = CrossEncoder("cross-encoder/mmarco-mMiniLMv2-L12-H384-v1")
146
- # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B", trust_remote_code=True)
147
- # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B", trust_remote_code=True)
148
- # # tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen1.5-7B-Chat", trust_remote_code=True)
149
- # # model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen1.5-7B-Chat", trust_remote_code=True)
150
-
151
- # #%%
152
- # user_query = "許智傑做過什麼壞事"
153
- # reranked = retrieve_and_rerank(user_query, model_name, collection_name, cross_encoder_model, score_threshold=0.6, search_top_k=20, rerank_top_k=5)
154
- # #%%
155
- # passages = [answer for answer, score in reranked]
156
- # answer = generate_response_from_local_llm(user_query, passages, tokenizer, model, max_new_tokens=256)
157
- # print("回答:", answer)
158
-
159
-
160
-
161
- # %%
 
15
  df.columns = ['Question', 'Answer']
16
  original_len = len(df)
17
 
 
 
 
 
18
  questions = df['Question'].tolist()
19
  answers = df['Answer'].tolist()
20
  # 初始化模型
 
44
  print(f"✅ 向量資料庫建立完成,共嵌入 {len(points)} 筆 QA。")
45
  client.scroll(collection_name=collection_name, limit=100)
46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  def retrieve_and_rerank(query, model_name, collection_name, cross_encoder_model, score_threshold=0.6, search_top_k=25, rerank_top_k=5):
48
  from semantic_reranker import rerank_results
49
  model = SentenceModel(model_name)
50
  client = QdrantClient(path="./qadrant_data")
 
 
 
 
 
51
 
52
  query_vector = model.encode(query, normalize_embeddings=True)
53
  results = client.search(
 
66
  candidate_passages=retrieved_answers,
67
  top_k=rerank_top_k
68
  )
 
 
 
 
 
 
 
 
 
69
  return reranked
70
 
71
 
 
87
  - 回答時平易近人,像和朋友交談一樣。
88
  - 精簡回答,避免冗長的解釋。
89
  回答:"""
 
90
  inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
91
  outputs = model.generate(
92
  **inputs,
 
98
  decoded_output = tokenizer.decode(outputs[0], skip_special_tokens=True)
99
 
100
  # 提取回答部分
101
+ answer = decoded_output.split("回答:", 1)[-1].strip()
102
+ answer = answer + "大罷免!大成功!"
103
+ return answer