Xalt8 committed on
Commit
1dcc70b
·
1 Parent(s): 466b7d1

reranking working

Browse files
rag_app/loading_data/load_chroma_db_cross_platform.py CHANGED
@@ -8,9 +8,6 @@ import sys
8
  import zipfile
9
 
10
 
11
- S3_LOCATION = os.getenv("S3_LOCATION")
12
-
13
-
14
  def download_chroma_from_s3(s3_location:str,
15
  chroma_vs_name:str,
16
  vectorstore_folder:str,
@@ -32,20 +29,27 @@ def download_chroma_from_s3(s3_location:str,
32
  # Initialize an S3 client with unsigned configuration for public access
33
  s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
34
  s3.download_file(s3_location, chroma_vs_name, vs_save_path)
 
35
 
36
  # Extract the zip file
37
  with zipfile.ZipFile(file=str(vs_save_path), mode='r') as zip_ref:
38
  zip_ref.extractall(path=vectorstore_folder)
39
-
 
40
  except Exception as e:
41
  print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
42
 
43
  # Delete the zip file
44
  vs_save_path.unlink()
 
45
 
46
  if __name__ == "__main__":
 
 
 
47
  chroma_vs_name = "vectorstores/chroma-zurich-mpnet-1500.zip"
48
- project_dir = Path().cwd().parent
 
49
  vs_destination = str(project_dir / 'vectorstore')
50
  assert Path(vs_destination).is_dir(), "Cannot find vectorstore folder"
51
 
 
8
  import zipfile
9
 
10
 
 
 
 
11
  def download_chroma_from_s3(s3_location:str,
12
  chroma_vs_name:str,
13
  vectorstore_folder:str,
 
29
  # Initialize an S3 client with unsigned configuration for public access
30
  s3 = boto3.client('s3', config=Config(signature_version=UNSIGNED))
31
  s3.download_file(s3_location, chroma_vs_name, vs_save_path)
32
+ print('Downloaded file from S3')
33
 
34
  # Extract the zip file
35
  with zipfile.ZipFile(file=str(vs_save_path), mode='r') as zip_ref:
36
  zip_ref.extractall(path=vectorstore_folder)
37
+ print("Extracted zip file")
38
+
39
  except Exception as e:
40
  print(f"Error during downloading or extracting from S3: {e}", file=sys.stderr)
41
 
42
  # Delete the zip file
43
  vs_save_path.unlink()
44
+ print("Deleting zip file")
45
 
46
  if __name__ == "__main__":
47
+
48
+ S3_LOCATION = os.getenv("S3_LOCATION")
49
+
50
  chroma_vs_name = "vectorstores/chroma-zurich-mpnet-1500.zip"
51
+
52
+ project_dir = Path().cwd().parent.parent
53
  vs_destination = str(project_dir / 'vectorstore')
54
  assert Path(vs_destination).is_dir(), "Cannot find vectorstore folder"
55
 
rag_app/reranking.py CHANGED
@@ -80,31 +80,29 @@ def get_reranked_docs_chroma(query:str,
80
 
81
  Returns: A list of documents with the highest rank
82
  """
83
- assert num_docs <= 10, "num_docs should be less than similarity search results"
84
-
85
  embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
86
  model_name=embedding_model)
87
  # Load the vectorstore database
88
  db = Chroma(persist_directory=path_to_db, embedding_function=embeddings)
89
 
90
- # Get 10 documents based on similarity search
91
  sim_docs = db.similarity_search(query=query, k=10)
92
 
93
- # Add the page_content, description and title together
94
  passages = [doc.page_content for doc in sim_docs]
95
 
96
  # Prepare the payload
97
  payload = {"inputs":
98
  {"source_sentence": query,
99
  "sentences": passages}}
100
-
101
 
102
  headers = {"Authorization": f"Bearer {hf_api_key}"}
103
 
104
  response = requests.post(url=reranking_hf_url, headers=headers, json=payload)
 
105
  if response.status_code != 200:
106
  print('Something went wrong with the response')
107
  return
 
108
  similarity_scores = response.json()
109
  ranked_results = sorted(zip(sim_docs, passages, similarity_scores), key=lambda x: x[2], reverse=True)
110
  top_k_results = ranked_results[:num_docs]
@@ -113,16 +111,17 @@ def get_reranked_docs_chroma(query:str,
113
 
114
 
115
  if __name__ == "__main__":
116
-
 
117
  HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
118
  EMBEDDING_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
119
 
120
  project_dir = Path().cwd().parent
121
  path_to_vector_db = str(project_dir/'vectorstore/chroma-zurich-mpnet-1500')
 
122
 
123
  query = "I'm looking for student insurance"
124
 
125
-
126
  re_ranked_docs = get_reranked_docs_chroma(query=query,
127
  path_to_db= path_to_vector_db,
128
  embedding_model=EMBEDDING_MODEL,
 
80
 
81
  Returns: A list of documents with the highest rank
82
  """
 
 
83
  embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
84
  model_name=embedding_model)
85
  # Load the vectorstore database
86
  db = Chroma(persist_directory=path_to_db, embedding_function=embeddings)
87
 
88
+ # Get k documents based on similarity search
89
  sim_docs = db.similarity_search(query=query, k=10)
90
 
 
91
  passages = [doc.page_content for doc in sim_docs]
92
 
93
  # Prepare the payload
94
  payload = {"inputs":
95
  {"source_sentence": query,
96
  "sentences": passages}}
 
97
 
98
  headers = {"Authorization": f"Bearer {hf_api_key}"}
99
 
100
  response = requests.post(url=reranking_hf_url, headers=headers, json=payload)
101
+ print(f'{response = }')
102
  if response.status_code != 200:
103
  print('Something went wrong with the response')
104
  return
105
+
106
  similarity_scores = response.json()
107
  ranked_results = sorted(zip(sim_docs, passages, similarity_scores), key=lambda x: x[2], reverse=True)
108
  top_k_results = ranked_results[:num_docs]
 
111
 
112
 
113
  if __name__ == "__main__":
114
+
115
+
116
  HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
117
  EMBEDDING_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
118
 
119
  project_dir = Path().cwd().parent
120
  path_to_vector_db = str(project_dir/'vectorstore/chroma-zurich-mpnet-1500')
121
+ assert Path(path_to_vector_db).exists(), "Cannot access path_to_vector_db "
122
 
123
  query = "I'm looking for student insurance"
124
 
 
125
  re_ranked_docs = get_reranked_docs_chroma(query=query,
126
  path_to_db= path_to_vector_db,
127
  embedding_model=EMBEDDING_MODEL,