sabazo committed on
Commit 7034c3c · unverified · 2 Parent(s): 76dab82 47feab3

Merge pull request #56 from almutareb/52-move-metadata-creation-and-supplement-into-own-module

.gitignore CHANGED
@@ -168,4 +168,8 @@ cython_debug/
 
 # Databases
 
-*.db
+*.db
+
+
+# editor related files
+.vscode/
config.py CHANGED
@@ -2,16 +2,31 @@ import os
 from dotenv import load_dotenv
 from rag_app.database.db_handler import DataBaseHandler
 from langchain_huggingface import HuggingFaceEndpoint
+# from langchain_huggingface import HuggingFaceHubEmbeddings
+from langchain_huggingface import HuggingFaceEmbeddings
 
 load_dotenv()
 
 SQLITE_FILE_NAME = os.getenv('SOURCES_CACHE')
-PERSIST_DIRECTORY = os.getenv('VECTOR_DATABASE_LOCATION')
+VECTOR_DATABASE_LOCATION = os.getenv('VECTOR_DATABASE_LOCATION')
 EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
 SEVEN_B_LLM_MODEL = os.getenv("SEVEN_B_LLM_MODEL")
 BERT_MODEL = os.getenv("BERT_MODEL")
+FAISS_INDEX_PATH = os.getenv("FAISS_INDEX_PATH")
+HUGGINGFACEHUB_API_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN")
 
 
+
+# embeddings = HuggingFaceHubEmbeddings(repo_id=EMBEDDING_MODEL)
+
+model_kwargs = {'device': 'cpu'}
+encode_kwargs = {'normalize_embeddings': False}
+embeddings = HuggingFaceEmbeddings(
+    model_name=EMBEDDING_MODEL,
+    model_kwargs=model_kwargs,
+    encode_kwargs=encode_kwargs
+)
+
 db = DataBaseHandler()
 
 db.create_all_tables()
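
config.py now builds a single shared HuggingFaceEmbeddings instance at import time instead of each module constructing its own. A minimal usage sketch of how a downstream module can consume it, assuming .env provides EMBEDDING_MODEL and VECTOR_DATABASE_LOCATION (the query string is an illustrative placeholder):

# Illustrative sketch only: reuse the shared embedder against the shared
# on-disk Chroma store, mirroring how structured_tools.py uses these names.
from langchain_community.vectorstores import Chroma
from config import embeddings, VECTOR_DATABASE_LOCATION

vector_db = Chroma(
    persist_directory=VECTOR_DATABASE_LOCATION,  # shared persist location from .env
    embedding_function=embeddings,               # module-level embedder from config.py
)
docs = vector_db.similarity_search("student insurance", k=5)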
test_this.py → cookbook/sample_get_faiss.py RENAMED
File without changes
pytest.ini ADDED
@@ -0,0 +1,2 @@
+[pytest]
+pythonpath = .
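
The pythonpath ini option (available since pytest 7) adds the repository root to sys.path during collection, so test modules can import top-level project modules directly. An illustrative test that relies on it (file name and assertion are placeholders, not part of this commit):

# tests/test_imports.py -- illustrative only; depends on pytest.ini's pythonpath = .
import config
import rag_app

def test_project_modules_importable():
    assert hasattr(config, "EMBEDDING_MODEL")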
rag_app/__init__.py CHANGED
@@ -0,0 +1,7 @@
+import sys
+from pathlib import Path
+
+# Add the project root to the Python path
+project_root = str(Path(__file__).parent.parent)
+if project_root not in sys.path:
+    sys.path.append(project_root)
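
This mirrors the pytest.ini setting for plain interpreter runs: importing the package makes top-level modules such as config resolvable regardless of the working directory. A minimal sketch of the effect (illustrative only):

# Importing rag_app runs the __init__.py above, which appends the
# project root to sys.path, so a sibling top-level module resolves.
import rag_app
import config  # found via the appended path entry

print(config.SQLITE_FILE_NAME)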
rag_app/get_db_retriever.py DELETED
@@ -1,30 +0,0 @@
-# retriever and qa_chain function
-
-# HF libraries
-from langchain.llms import HuggingFaceHub
-from langchain_huggingface import HuggingFaceHubEmbeddings
-# vectorestore
-from langchain_community.vectorstores import FAISS
-# retrieval chain
-from langchain.chains import RetrievalQA
-# prompt template
-from langchain.prompts import PromptTemplate
-from langchain.memory import ConversationBufferMemory
-
-
-def get_db_retriever(vector_db:str=None):
-    model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
-    embeddings = HuggingFaceHubEmbeddings(repo_id=model_name)
-
-    #db = Chroma(persist_directory="./vectorstore/lc-chroma-multi-mpnet-500", embedding_function=embeddings)
-    #db.get()
-    if not vector_db:
-        FAISS_INDEX_PATH='./vectorstore/py-faiss-multi-mpnet-500'
-    else:
-        FAISS_INDEX_PATH=vector_db
-    db = FAISS.load_local(FAISS_INDEX_PATH, embeddings)
-
-    retriever = db.as_retriever()
-
-    return retriever
-
rag_app/hybrid_search.py DELETED
@@ -1,63 +0,0 @@
-from pathlib import Path
-from langchain_community.vectorstores import FAISS
-from dotenv import load_dotenv
-import os
-from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
-from langchain.retrievers import EnsembleRetriever
-from langchain_community.retrievers import BM25Retriever
-
-
-def get_hybrid_search_results(query:str,
-                              path_to_db:str,
-                              embedding_model:str,
-                              hf_api_key:str,
-                              num_docs:int=5) -> list:
-    """ Uses an ensemble retriever of BM25 and FAISS to return k num documents
-
-    Args:
-        query (str): The search query
-        path_to_db (str): Path to the vectorstore database
-        embedding_model (str): Embedding model used in the vector store
-        num_docs (int): Number of documents to return
-
-    Returns
-        List of documents
-
-    """
-
-    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
-                                                   model_name=embedding_model)
-    # Load the vectorstore database
-    db = FAISS.load_local(folder_path=path_to_db,
-                          embeddings=embeddings,
-                          allow_dangerous_deserialization=True)
-
-    all_docs = db.similarity_search("", k=db.index.ntotal)
-
-    bm25_retriever = BM25Retriever.from_documents(all_docs)
-    bm25_retriever.k = num_docs # How many results you want
-
-    faiss_retriever = db.as_retriever(search_kwargs={'k': num_docs})
-
-    ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever],
-                                           weights=[0.5,0.5])
-
-    results = ensemble_retriever.invoke(input=query)
-    return results
-
-
-if __name__ == "__main__":
-    query = "Haustierversicherung"
-    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
-    EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL")
-
-    path_to_vector_db = Path("..")/'vectorstore/faiss-insurance-agent-500'
-
-    results = get_hybrid_search_results(query=query,
-                                        path_to_db=path_to_vector_db,
-                                        embedding_model=EMBEDDING_MODEL,
-                                        hf_api_key=HUGGINGFACEHUB_API_TOKEN)
-
-    for doc in results:
-        print(doc)
-        print()
rag_app/knowledge_base/build_vector_store.py DELETED
@@ -1,85 +0,0 @@
-# vectorization functions
-from langchain_community.vectorstores import FAISS
-from langchain_community.vectorstores import Chroma
-#from langchain_community.document_loaders import DirectoryLoader
-#from langchain_text_splitters import RecursiveCharacterTextSplitter
-#from langchain_community.embeddings.sentence_transformer import (
-#    SentenceTransformerEmbeddings,
-#)
-#from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_community.retrievers import BM25Retriever
-from rag_app.knowledge_base.create_embedding import create_embeddings
-from rag_app.utils.generate_summary import generate_description, generate_keywords
-import time
-import os
-#from dotenv import load_dotenv
-
-def build_vector_store(
-    docs: list,
-    db_path: str,
-    embedding_model: str,
-    new_db:bool=False,
-    chunk_size:int=500,
-    chunk_overlap:int=50,
-):
-    """
-
-    """
-
-    if db_path is None:
-        FAISS_INDEX_PATH = os.getenv("FAISS_INDEX_PATH")
-    else:
-        FAISS_INDEX_PATH = db_path
-
-    embeddings,chunks = create_embeddings(docs, chunk_size, chunk_overlap, embedding_model)
-    # for chunk in chunks:
-    #     keywords=generate_keywords(chunk)
-    #     description=generate_description(chunk)
-    #     chunk.metadata['keywords']=keywords
-    #     chunk.metadata['description']=description
-
-    #load chunks into vector store
-    print(f'Loading chunks into faiss vector store ...')
-    st = time.time()
-    if new_db:
-        db_faiss = FAISS.from_documents(chunks, embeddings)
-        bm25_retriever = BM25Retriever.from_documents(chunks)
-    else:
-        db_faiss = FAISS.add_documents(chunks, embeddings)
-        bm25_retriever = BM25Retriever.add_documents(chunks)
-    db_faiss.save_local(FAISS_INDEX_PATH)
-    et = time.time() - st
-    print(f'Time taken: {et} seconds.')
-
-    print(f'Loading chunks into chroma vector store ...')
-    st = time.time()
-    persist_directory='./vectorstore/chroma-insurance-agent-1500'
-    db_chroma = Chroma.from_documents(chunks, embeddings, persist_directory=persist_directory)
-    et = time.time() - st
-    print(f'Time taken: {et} seconds.')
-    result = f"built vectore store at {FAISS_INDEX_PATH}"
-    return result
-
-
-# # Path for saving the FAISS index
-# FAISS_INDEX_PATH = "./vectorstore/lc-faiss-multi-mpnet-500"
-
-# try:
-#     # Stage two: Vectorization of the document chunks
-#     model_name = "sentence-transformers/multi-qa-mpnet-base-dot-v1" # Model used for embedding
-
-#     # Initialize HuggingFace embeddings with the specified model
-#     embeddings = HuggingFaceEmbeddings(model_name=model_name)
-
-#     print(f'Loading chunks into vector store ...')
-#     st = time.time() # Start time for performance measurement
-#     # Create a FAISS vector store from the document chunks and save it locally
-#     db = FAISS.from_documents(filter_complex_metadata(chunks), embeddings)
-#     db.save_local(FAISS_INDEX_PATH)
-#     et = time.time() - st # Calculate time taken for vectorization
-#     print(f'Time taken for vectorization and saving: {et} seconds.')
-# except Exception as e:
-#     print(f"Error during vectorization or FAISS index saving: {e}", file=sys.stderr)
-
-# alternatively download a preparaed vectorized index from S3 and load the index into vectorstore
-# Import necessary libraries for AWS S3 interaction, file handling, and FAISS vector stores
rag_app/knowledge_base/create_embedding.py DELETED
@@ -1,54 +0,0 @@
-# embeddings functions
-#from langchain_community.vectorstores import FAISS
-#from langchain_community.document_loaders import ReadTheDocsLoader
-#from langchain_community.vectorstores.utils import filter_complex_metadata
-from langchain_text_splitters import RecursiveCharacterTextSplitter
-# from langchain_huggingface import HuggingFaceEmbeddings
-from langchain_community.embeddings.sentence_transformer import (
-    SentenceTransformerEmbeddings,
-)
-import time
-from langchain_core.documents import Document
-
-
-def create_embeddings(
-    docs: list[Document],
-    chunk_size:int = 500,
-    chunk_overlap:int = 50,
-    embedding_model: str = "sentence-transformers/multi-qa-mpnet-base-dot-v1",
-):
-    """given a sequence of `Document` objects this fucntion will
-    generate embeddings for it.
-
-    ## argument
-    :params docs (list[Document]) -> list of `list[Document]`
-    :params chunk_size (int) -> chunk size in which documents are chunks, defaults to 500
-    :params chunk_overlap (int) -> the amount of token that will be overlapped between chunks, defaults to 50
-    :params embedding_model (str) -> the huggingspace model that will embed the documents
-    ## Return
-    Tuple of embedding and chunks
-    """
-
-
-    text_splitter = RecursiveCharacterTextSplitter(
-        separators=["\n\n", "\n", "(?<=\. )", " ", ""],
-        chunk_size = chunk_size,
-        chunk_overlap = chunk_overlap,
-        length_function = len,
-    )
-
-    # Stage one: read all the docs, split them into chunks.
-    st = time.time()
-    print('Loading documents and creating chunks ...')
-
-    # Split each document into chunks using the configured text splitter
-    chunks = text_splitter.create_documents([doc.page_content for doc in docs], metadatas=[doc.metadata for doc in docs])
-    et = time.time() - st
-    print(f'Time taken to chunk {len(docs)} documents: {et} seconds.')
-
-    #Stage two: embed the docs.
-    #embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
-    embeddings = SentenceTransformerEmbeddings(model_name=embedding_model)
-    print(f"created a total of {len(chunks)} chunks")
-
-    return embeddings,chunks
rag_app/{multi_index_search.py → knowledge_base/multi_index_search.py} RENAMED
File without changes
rag_app/knowledge_base/utils.py CHANGED
@@ -1,10 +1,75 @@
 from langchain_core.documents import Document
 from chains import generate_document_summary_prompt
-from config import SEVEN_B_LLM_MODEL
+# embeddings functions
+from langchain_text_splitters import RecursiveCharacterTextSplitter
+from langchain_community.embeddings.sentence_transformer import (
+    SentenceTransformerEmbeddings,
+)
+import time
+from langchain_core.language_models import BaseChatModel
+from langchain.retrievers import VectorStoreRetriever
+from langchain_core.vectorstores import VectorStoreRetriever
+# vectorization functions
+from langchain_community.vectorstores import FAISS
+from langchain_community.vectorstores import Chroma
+from langchain_community.retrievers import BM25Retriever
+from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
+
+
+from pathlib import Path
+from langchain_community.vectorstores import FAISS
+from dotenv import load_dotenv
+import os
+import requests
+
+from rag_app.knowledge_base.utils import create_embeddings
+from rag_app.utils.generate_summary import generate_description, generate_keywords
+from config import EMBEDDING_MODEL, FAISS_INDEX_PATH, SEVEN_B_LLM_MODEL
+
+def create_embeddings(
+    docs: list[Document],
+    chunk_size:int = 500,
+    chunk_overlap:int = 50,
+):
+    """given a sequence of `Document` objects this function will
+    generate embeddings for it.
+
+    ## argument
+    :params docs (list[Document]) -> list of `list[Document]`
+    :params chunk_size (int) -> chunk size in which documents are chunked, defaults to 500
+    :params chunk_overlap (int) -> the number of tokens that will be overlapped between chunks, defaults to 50
+    :params embedding_model (str) -> the Hugging Face model that will embed the documents
+    ## Return
+    Tuple of embedding and chunks
+    """
+
+
+    text_splitter = RecursiveCharacterTextSplitter(
+        separators=["\n\n", "\n", "(?<=\. )", " ", ""],
+        chunk_size = chunk_size,
+        chunk_overlap = chunk_overlap,
+        length_function = len,
+    )
+
+    # Stage one: read all the docs, split them into chunks.
+    st = time.time()
+    print('Loading documents and creating chunks ...')
+
+    # Split each document into chunks using the configured text splitter
+    chunks = text_splitter.create_documents([doc.page_content for doc in docs], metadatas=[doc.metadata for doc in docs])
+    et = time.time() - st
+    print(f'Time taken to chunk {len(docs)} documents: {et} seconds.')
+
+    #Stage two: embed the docs.
+    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)
+    print(f"created a total of {len(chunks)} chunks")
+
+    return embeddings,chunks
 
 
 def generate_document_summaries(
-    docs: list[Document]
+    docs: list[Document],
+    llm:BaseChatModel= SEVEN_B_LLM_MODEL,
 ) -> list[Document]:
     """
     Generates summaries for a list of Document objects and updates their metadata with the summaries.
@@ -27,7 +92,7 @@ def generate_document_summaries(
 
     for doc in new_docs:
 
-        genrate_summary_chain = generate_document_summary_prompt | SEVEN_B_LLM_MODEL
+        genrate_summary_chain = generate_document_summary_prompt | llm
        summary = genrate_summary_chain.invoke(
             {"document":str(doc.metadata)}
         )
@@ -36,4 +101,51 @@ def generate_document_summaries(
             {"summary":summary}
         )
 
-    return new_docs
+    return new_docs
+
+
+def build_vector_store(
+    docs: list,
+    embedding_model: str,
+    new_db:bool=False,
+    chunk_size:int=500,
+    chunk_overlap:int=50,
+):
+    """
+
+    """
+
+    embeddings,chunks = create_embeddings(
+        docs,
+        chunk_size,
+        chunk_overlap,
+        embedding_model
+    )
+
+    #load chunks into vector store
+    print(f'Loading chunks into faiss vector store ...')
+
+    st = time.time()
+    if new_db:
+        db_faiss = FAISS.from_documents(chunks, embeddings)
+        bm25_retriever = BM25Retriever.from_documents(chunks)
+    else:
+        db_faiss = FAISS.add_documents(chunks, embeddings)
+        bm25_retriever = BM25Retriever.add_documents(chunks)
+
+    db_faiss.save_local(FAISS_INDEX_PATH)
+    et = time.time() - st
+    print(f'Time taken: {et} seconds.')
+
+    print(f'Loading chunks into chroma vector store ...')
+
+    st = time.time()
+    persist_directory='./vectorstore/chroma-insurance-agent-1500'
+    db_chroma = Chroma.from_documents(chunks, embeddings, persist_directory=persist_directory)
+    et = time.time() - st
+
+    print(f'Time taken: {et} seconds.')
+    result = f"built vector store at {FAISS_INDEX_PATH}"
+    return result
+
+
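
A note on the new_db=False branch of build_vector_store: FAISS.add_documents and BM25Retriever.add_documents are invoked on the classes rather than on instances, which will fail at runtime, since add_documents is an instance method and BM25Retriever has no incremental add at all. A hedged sketch of what that branch could do instead; loading the saved index first is an assumption on my part, not part of this commit:

# Hypothetical replacement for the new_db=False branch; chunks, embeddings and
# FAISS_INDEX_PATH are the names already in scope inside build_vector_store().
from langchain_community.vectorstores import FAISS
from langchain_community.retrievers import BM25Retriever

db_faiss = FAISS.load_local(
    FAISS_INDEX_PATH,
    embeddings,
    allow_dangerous_deserialization=True,  # required to unpickle FAISS metadata
)
db_faiss.add_documents(chunks)  # instance method: embeds and appends the new chunks
db_faiss.save_local(FAISS_INDEX_PATH)

# BM25 keeps no on-disk state here, so rebuild it over the documents at hand.
bm25_retriever = BM25Retriever.from_documents(chunks)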
rag_app/reranking.py DELETED
@@ -1,131 +0,0 @@
-# from get_db_retriever import get_db_retriever
-from pathlib import Path
-from langchain_community.vectorstores import FAISS
-from dotenv import load_dotenv
-import os
-from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
-import requests
-from langchain_community.vectorstores import Chroma
-
-
-load_dotenv()
-
-
-def get_reranked_docs_faiss(query:str,
-                            path_to_db:str,
-                            embedding_model:str,
-                            hf_api_key:str,
-                            num_docs:int=5) -> list:
-    """ Re-ranks the similarity search results and returns top-k highest ranked docs
-
-    Args:
-        query (str): The search query
-        path_to_db (str): Path to the vectorstore database
-        embedding_model (str): Embedding model used in the vector store
-        num_docs (int): Number of documents to return
-
-    Returns: A list of documents with the highest rank
-    """
-    assert num_docs <= 10, "num_docs should be less than similarity search results"
-
-    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
-                                                   model_name=embedding_model)
-    # Load the vectorstore database
-    db = FAISS.load_local(folder_path=path_to_db,
-                          embeddings=embeddings,
-                          allow_dangerous_deserialization=True)
-
-    # Get 10 documents based on similarity search
-    docs = db.similarity_search(query=query, k=10)
-
-    # Add the page_content, description and title together
-    passages = [doc.page_content + "\n" + doc.metadata.get('title', "") +"\n"+ doc.metadata.get('description', "")
-                for doc in docs]
-
-    # Prepare the payload
-    inputs = [{"text": query, "text_pair": passage} for passage in passages]
-
-    API_URL = "https://api-inference.huggingface.co/models/deepset/gbert-base-germandpr-reranking"
-    headers = {"Authorization": f"Bearer {hf_api_key}"}
-
-    response = requests.post(API_URL, headers=headers, json=inputs)
-    scores = response.json()
-
-    try:
-        relevance_scores = [item[1]['score'] for item in scores]
-    except ValueError as e:
-        print('Could not get the relevance_scores -> something might be wrong with the json output')
-        return
-
-    if relevance_scores:
-        ranked_results = sorted(zip(docs, passages, relevance_scores), key=lambda x: x[2], reverse=True)
-        top_k_results = ranked_results[:num_docs]
-        return [doc for doc, _, _ in top_k_results]
-
-
-
-def get_reranked_docs_chroma(query:str,
-                             path_to_db:str,
-                             embedding_model:str,
-                             hf_api_key:str,
-                             reranking_hf_url:str = "https://api-inference.huggingface.co/models/sentence-transformers/all-mpnet-base-v2",
-                             num_docs:int=5) -> list:
-    """ Re-ranks the similarity search results and returns top-k highest ranked docs
-
-    Args:
-        query (str): The search query
-        path_to_db (str): Path to the vectorstore database
-        embedding_model (str): Embedding model used in the vector store
-        num_docs (int): Number of documents to return
-
-    Returns: A list of documents with the highest rank
-    """
-    embeddings = HuggingFaceInferenceAPIEmbeddings(api_key=hf_api_key,
-                                                   model_name=embedding_model)
-    # Load the vectorstore database
-    db = Chroma(persist_directory=path_to_db, embedding_function=embeddings)
-
-    # Get k documents based on similarity search
-    sim_docs = db.similarity_search(query=query, k=10)
-
-    passages = [doc.page_content for doc in sim_docs]
-
-    # Prepare the payload
-    payload = {"inputs":
-               {"source_sentence": query,
-                "sentences": passages}}
-
-    headers = {"Authorization": f"Bearer {hf_api_key}"}
-
-    response = requests.post(url=reranking_hf_url, headers=headers, json=payload)
-    print(f'{response = }')
-    if response.status_code != 200:
-        print('Something went wrong with the response')
-        return
-
-    similarity_scores = response.json()
-    ranked_results = sorted(zip(sim_docs, passages, similarity_scores), key=lambda x: x[2], reverse=True)
-    top_k_results = ranked_results[:num_docs]
-    return [doc for doc, _, _ in top_k_results]
-
-
-
-if __name__ == "__main__":
-
-
-    HUGGINGFACEHUB_API_TOKEN = os.getenv('HUGGINGFACEHUB_API_TOKEN')
-    EMBEDDING_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"
-
-    project_dir = Path().cwd().parent
-    path_to_vector_db = str(project_dir/'vectorstore/chroma-zurich-mpnet-1500')
-    assert Path(path_to_vector_db).exists(), "Cannot access path_to_vector_db "
-
-    query = "I'm looking for student insurance"
-
-    re_ranked_docs = get_reranked_docs_chroma(query=query,
-                                              path_to_db= path_to_vector_db,
-                                              embedding_model=EMBEDDING_MODEL,
-                                              hf_api_key=HUGGINGFACEHUB_API_TOKEN)
-
-
-    print(f"{re_ranked_docs=}")
rag_app/structured_tools/structured_tools.py CHANGED
@@ -13,9 +13,9 @@ from rag_app.utils.utils import (
 )
 import chromadb
 import os
-from config import db, PERSIST_DIRECTORY, EMBEDDING_MODEL
+from config import db, VECTOR_DATABASE_LOCATION, EMBEDDING_MODEL
 
-if not os.path.exists(PERSIST_DIRECTORY):
+if not os.path.exists(VECTOR_DATABASE_LOCATION):
     get_chroma_vs()
 
 @tool
@@ -24,7 +24,7 @@ def memory_search(query:str) -> str:
     This is your primary source to start your search with checking what you already have learned from the past, before going online."""
     # Since we have more than one collections we should change the name of this tool
     client = chromadb.PersistentClient(
-        path=PERSIST_DIRECTORY,
+        path=VECTOR_DATABASE_LOCATION,
     )
 
     collection_name = os.getenv('CONVERSATION_COLLECTION_NAME')
@@ -71,7 +71,7 @@ def knowledgeBase_search(query:str) -> str:
     # #collection_name=collection_name,
     # embedding_function=embedding_function,
     # )
-    vector_db = Chroma(persist_directory=PERSIST_DIRECTORY, embedding_function=embedding_function)
+    vector_db = Chroma(persist_directory=VECTOR_DATABASE_LOCATION, embedding_function=embedding_function)
     retriever = vector_db.as_retriever(search_type="mmr", search_kwargs={'k':5, 'fetch_k':10})
     # This is deprecated, changed to invoke
     # LangChainDeprecationWarning: The method `BaseRetriever.get_relevant_documents` was deprecated in langchain-core 0.1.46 and will be removed in 0.3.0. Use invoke instead.
rag_app/vector_store_handler/__init__.py ADDED
File without changes
rag_app/vector_store_handler/vectorstores.py ADDED
@@ -0,0 +1,325 @@
+from abc import ABC, abstractmethod
+from langchain.vectorstores import Chroma, FAISS
+from langchain.embeddings import OpenAIEmbeddings
+from langchain.text_splitter import CharacterTextSplitter
+from langchain.document_loaders import TextLoader
+
+
+from langchain_community.embeddings.sentence_transformer import (
+    SentenceTransformerEmbeddings,
+)
+import time
+from langchain_core.documents import Document
+from config import EMBEDDING_MODEL, HUGGINGFACEHUB_API_TOKEN
+from langchain.retrievers import EnsembleRetriever
+from langchain_community.retrievers import BM25Retriever
+import requests
+
+class BaseVectorStore(ABC):
+    """
+    Abstract base class for vector stores.
+
+    This class defines the interface for vector stores and implements
+    common functionality.
+    """
+
+    def __init__(self, embedding_model, persist_directory=None):
+        """
+        Initialize the BaseVectorStore.
+
+        Args:
+            embedding_model: The embedding model to use for vectorizing text.
+            persist_directory (str, optional): Directory to persist the vector store.
+        """
+        self.persist_directory = persist_directory
+        self.embeddings = embedding_model
+        self.vectorstore = None
+
+    def load_and_process_documents(self, file_path, chunk_size=1000, chunk_overlap=0):
+        """
+        Load and process documents from a file.
+
+        Args:
+            file_path (str): Path to the file to load.
+            chunk_size (int): Size of text chunks for processing.
+            chunk_overlap (int): Overlap between chunks.
+
+        Returns:
+            list: Processed documents.
+        """
+        loader = TextLoader(file_path)
+        documents = loader.load()
+        text_splitter = CharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        return text_splitter.split_documents(documents)
+
+    def get_hybrid_search_result(self,query:str):
+        pass
+
+    @abstractmethod
+    def create_vectorstore(self, texts):
+        """
+        Create a new vector store from the given texts.
+
+        Args:
+            texts (list): List of texts to vectorize and store.
+        """
+        pass
+
+    @abstractmethod
+    def load_existing_vectorstore(self):
+        """
+        Load an existing vector store from the persist directory.
+        """
+        pass
+
+    def similarity_search(self, query):
+        """
+        Perform a similarity search on the vector store.
+
+        Args:
+            query (str): The query text to search for.
+
+        Returns:
+            list: Search results.
+
+        Raises:
+            ValueError: If the vector store is not initialized.
+        """
+        if self.vectorstore is None:
+            raise ValueError("Vector store not initialized. Call create_vectorstore or load_existing_vectorstore first.")
+        return self.vectorstore.similarity_search(query)
+
+    @abstractmethod
+    def save(self):
+        """
+        Save the current state of the vector store.
+        """
+        pass
+
+
+class ChromaVectorStore(BaseVectorStore):
+    """
+    Implementation of BaseVectorStore using Chroma as the backend.
+    """
+
+    def create_vectorstore(self, texts):
+        """
+        Create a new Chroma vector store from the given texts.
+
+        Args:
+            texts (list): List of texts to vectorize and store.
+        """
+        self.vectorstore = Chroma.from_documents(
+            texts,
+            self.embeddings,
+            persist_directory=self.persist_directory
+        )
+
+    def load_existing_vectorstore(self):
+        """
+        Load an existing Chroma vector store from the persist directory.
+
+        Raises:
+            ValueError: If persist_directory is not set.
+        """
+        if self.persist_directory is not None:
+            self.vectorstore = Chroma(
+                persist_directory=self.persist_directory,
+                embedding_function=self.embeddings
+            )
+        else:
+            raise ValueError("Persist directory is required for loading Chroma.")
+
+    def save(self):
+        """
+        Save the current state of the Chroma vector store.
+
+        Raises:
+            ValueError: If the vector store is not initialized.
+        """
+        if not self.vectorstore:
+            raise ValueError("Vector store not initialized. Nothing to save.")
+        self.vectorstore.persist()
+
+    def get_reranked_docs(
+        self,
+        query:str,
+        num_docs:int=5
+    ):
+        """ Re-ranks the similarity search results and returns top-k highest ranked docs
+
+        Args:
+            query (str): The search query
+            path_to_db (str): Path to the vectorstore database
+            embedding_model (str): Embedding model used in the vector store
+            num_docs (int): Number of documents to return
+
+        Returns: A list of documents with the highest rank
+        """
+
+        # Get k documents based on similarity search
+        sim_docs = self.vectorstore.similarity_search(query=query, k=10)
+
+        # Add the page_content, description and title together
+        passages = [doc.page_content for doc in sim_docs]
+
+        # Prepare the payload
+        payload = {"inputs":
+                   {"source_sentence": query,
+                    "sentences": passages}}
+
+        headers = {"Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}"}
+        reranking_hf_url:str = "https://api-inference.huggingface.co/models/sentence-transformers/all-mpnet-base-v2"
+
+        response = requests.post(url=reranking_hf_url, headers=headers, json=payload)
+        print(f'{response = }')
+        if response.status_code != 200:
+            print('Something went wrong with the response')
+            return
+
+        similarity_scores = response.json()
+        ranked_results = sorted(zip(sim_docs, passages, similarity_scores), key=lambda x: x[2], reverse=True)
+        top_k_results = ranked_results[:num_docs]
+        return [doc for doc, _, _ in top_k_results]
+
+
+
+class FAISSVectorStore(BaseVectorStore):
+    """
+    Implementation of BaseVectorStore using FAISS as the backend.
+    """
+
+    def create_vectorstore(self, texts):
+        """
+        Create a new FAISS vector store from the given texts.
+
+        Args:
+            texts (list): List of texts to vectorize and store.
+        """
+        self.vectorstore = FAISS.from_documents(texts, self.embeddings)
+
+    def load_existing_vectorstore(self,allow_dangerous_deserialization:bool=False):
+        """
+        Load an existing FAISS vector store from the persist directory.
+
+        Raises:
+            ValueError: If persist_directory is not set.
+        """
+        if self.persist_directory:
+            self.vectorstore = FAISS.load_local(self.persist_directory, self.embeddings, allow_dangerous_deserialization)
+        else:
+            raise ValueError("Persist directory is required for loading FAISS.")
+
+    def save(self):
+        """
+        Save the current state of the FAISS vector store.
+
+        Raises:
+            ValueError: If the vector store is not initialized.
+        """
+        if self.vectorstore is None:
+            raise ValueError("Vector store not initialized. Nothing to save.")
+        self.vectorstore.save_local(self.persist_directory)
+
+    def get_hybrid_search_result(
+        self,
+        query:str,
+        num_docs:int=5
+    )-> list[Document]:
+        """ Uses an ensemble retriever of BM25 and FAISS to return k num documents
+
+        Args:
+            query (str): The search query
+            path_to_db (str): Path to the vectorstore database
+            embedding_model (str): Embedding model used in the vector store
+            num_docs (int): Number of documents to return
+
+        Returns
+            List of documents
+
+        """
+        all_docs = self.vectorstore.similarity_search("", k=self.vectorstore.index.ntotal)
+        bm25_retriever = BM25Retriever.from_documents(all_docs)
+        bm25_retriever.k = num_docs # How many results you want
+
+        faiss_retriever = self.vectorstore.as_retriever(search_kwargs={'k': num_docs})
+
+        ensemble_retriever = EnsembleRetriever(retrievers=[bm25_retriever, faiss_retriever],
+                                               weights=[0.5,0.5])
+
+        results = ensemble_retriever.invoke(input=query)
+        return results
+
+    def get_reranked_docs(
+        self,
+        query:str,
+        num_docs:int=5
+    ):
+
+        # Get 10 documents based on similarity search
+        docs = self.vectorstore.similarity_search(query=query, k=10)
+
+        # Add the page_content, description and title together
+        passages = [doc.page_content + "\n" + doc.metadata.get('title', "") +"\n"+ doc.metadata.get('description', "")
+                    for doc in docs]
+        # Prepare the payload
+        inputs = [{"text": query, "text_pair": passage} for passage in passages]
+
+        API_URL = "https://api-inference.huggingface.co/models/deepset/gbert-base-germandpr-reranking"
+        headers = {"Authorization": f"Bearer {HUGGINGFACEHUB_API_TOKEN}"}
+
+        response = requests.post(API_URL, headers=headers, json=inputs)
+        scores = response.json()
+
+        try:
+            relevance_scores = [item[1]['score'] for item in scores]
+        except ValueError as e:
+            print('Could not get the relevance_scores -> something might be wrong with the json output')
+            return
+
+        if relevance_scores:
+            ranked_results = sorted(zip(docs, passages, relevance_scores), key=lambda x: x[2], reverse=True)
+            top_k_results = ranked_results[:num_docs]
+            return [doc for doc, _, _ in top_k_results]
+
+# Usage example:
+def main():
+    """
+    Example usage of the vector store classes.
+    """
+    # Create an embedding model
+    embedding_model = OpenAIEmbeddings()
+
+    embeddings = SentenceTransformerEmbeddings(model_name=EMBEDDING_MODEL)
+
+
+    # Using Chroma
+    chroma_store = ChromaVectorStore(embedding_model, persist_directory="./chroma_store")
+    texts = chroma_store.load_and_process_documents("docs/placeholder.txt")
+    chroma_store.create_vectorstore(texts)
+    results = chroma_store.similarity_search("Your query here")
+    print("Chroma results:", results[0].page_content)
+    chroma_store.save()
+
+    # Load existing Chroma store
+    existing_chroma = ChromaVectorStore(embedding_model, persist_directory="./chroma_store")
+    existing_chroma.load_existing_vectorstore()
+    results = existing_chroma.similarity_search("Another query")
+    print("Existing Chroma results:", results[0].page_content)
+
+    # Using FAISS
+    faiss_store = FAISSVectorStore(embedding_model, persist_directory="./faiss_store")
+    texts = faiss_store.load_and_process_documents("path/to/your/file.txt")
+    faiss_store.create_vectorstore(texts)
+    results = faiss_store.similarity_search("Your query here")
+    print("FAISS results:", results[0].page_content)
+    faiss_store.save()
+
+    # Load existing FAISS store
+    existing_faiss = FAISSVectorStore(embedding_model, persist_directory="./faiss_store")
+    existing_faiss.load_existing_vectorstore()
+    results = existing_faiss.similarity_search("Another query")
+    print("Existing FAISS results:", results[0].page_content)
+
+if __name__ == "__main__":
+    main()
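
main() above walks through both backends end to end. As a more focused sketch, the hybrid-search path on a freshly built FAISS store might look like the following; the documents and query are illustrative placeholders, and .env is assumed to provide EMBEDDING_MODEL:

# Illustrative sketch only, not part of this commit.
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.documents import Document
from rag_app.vector_store_handler.vectorstores import FAISSVectorStore
from config import EMBEDDING_MODEL

store = FAISSVectorStore(
    HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL),
    persist_directory="./faiss_store",
)
store.create_vectorstore([
    Document(page_content="Pet insurance covers veterinary costs."),
    Document(page_content="Student insurance bundles health and liability cover."),
])

# Blends lexical BM25 ranking with dense FAISS similarity at equal weight.
for doc in store.get_hybrid_search_result("student insurance", num_docs=2):
    print(doc.page_content)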
tests/integration/test_vector_store_integration.py ADDED
@@ -0,0 +1,89 @@
+import pytest
+from langchain.schema import Document
+from rag_app.vector_store_handler.vectorstores import ChromaVectorStore, FAISSVectorStore
+# from rag_app.database.init_db import db
+from config import EMBEDDING_MODEL, VECTOR_DATABASE_LOCATION
+from langchain.embeddings import HuggingFaceEmbeddings # Or whatever embedding you're using
+
+@pytest.fixture(scope="module")
+def embedding_model():
+    return HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+
+@pytest.fixture(params=[ChromaVectorStore, FAISSVectorStore])
+def vector_store(request, embedding_model, tmp_path):
+    store = request.param(embedding_model, persist_directory=str(tmp_path))
+    yield store
+    # Clean up (if necessary)
+    if hasattr(store, 'vectorstore'):
+        store.vectorstore.delete_collection()
+
+@pytest.fixture
+def sample_documents():
+    return [
+        Document(page_content="This is a test document about AI."),
+        Document(page_content="Another document discussing machine learning."),
+        Document(page_content="A third document about natural language processing.")
+    ]
+
+def test_create_and_search(vector_store, sample_documents):
+    # Create vector store
+    vector_store.create_vectorstore(sample_documents)
+
+    # Perform a search
+    results = vector_store.similarity_search("AI and machine learning")
+
+    assert len(results) > 0
+    assert any("AI" in doc.page_content for doc in results)
+    assert any("machine learning" in doc.page_content for doc in results)
+
+def test_save_and_load(vector_store, sample_documents, tmp_path):
+    # Create and save vector store
+    vector_store.create_vectorstore(sample_documents)
+    vector_store.save()
+
+    # Load the vector store
+    loaded_store = type(vector_store)(vector_store.embeddings, persist_directory=str(tmp_path))
+    loaded_store.load_existing_vectorstore()
+
+    # Perform a search on the loaded store
+    results = loaded_store.similarity_search("natural language processing")
+
+    assert len(results) > 0
+    assert any("natural language processing" in doc.page_content for doc in results)
+
+def test_update_vectorstore(vector_store, sample_documents):
+    # Create initial vector store
+    vector_store.create_vectorstore(sample_documents)
+
+    # Add a new document
+    new_doc = Document(page_content="A new document about deep learning.")
+    vector_store.vectorstore.add_documents([new_doc])
+
+    # Search for the new content
+    results = vector_store.similarity_search("deep learning")
+
+    assert len(results) > 0
+    assert any("deep learning" in doc.page_content for doc in results)
+
+@pytest.mark.parametrize("query,expected_content", [
+    ("AI", "AI"),
+    ("machine learning", "machine learning"),
+    ("natural language processing", "natural language processing")
+])
+def test_search_accuracy(vector_store, sample_documents, query, expected_content):
+    vector_store.create_vectorstore(sample_documents)
+    results = vector_store.similarity_search(query)
+    assert any(expected_content in doc.page_content for doc in results)
+
+# def test_database_integration(vector_store, sample_documents):
+#     # This test assumes your vector store interacts with the database in some way
+#     # You may need to adjust this based on your actual implementation
+#     vector_store.create_vectorstore(sample_documents)
+
+#     # Here, you might add some assertions about how the vector store interacts with the database
+#     # For example, if you're storing metadata about the documents in the database:
+#     for doc in sample_documents:
+#         result = db.session.query(YourDocumentModel).filter_by(content=doc.page_content).first()
+#         assert result is not None
+
+# Add more integration tests as needed
tests/vector_store_handler/test_vectorstores.py ADDED
@@ -0,0 +1,88 @@
+import unittest
+from unittest.mock import MagicMock, patch
+# from langchain.embeddings import OpenAIEmbeddings
+from langchain_huggingface import HuggingFaceEmbeddings
+# from langchain.schema import Document
+from langchain_core.documents import Document
+
+# Update the import to reflect your project structure
+from rag_app.vector_store_handler.vectorstores import BaseVectorStore, ChromaVectorStore, FAISSVectorStore
+
+class TestBaseVectorStore(unittest.TestCase):
+    def setUp(self):
+        self.embedding_model = MagicMock(spec=HuggingFaceEmbeddings)
+        self.base_store = BaseVectorStore(self.embedding_model, "test_dir")
+
+    def test_init(self):
+        self.assertEqual(self.base_store.persist_directory, "test_dir")
+        self.assertEqual(self.base_store.embeddings, self.embedding_model)
+        self.assertIsNone(self.base_store.vectorstore)
+
+    @patch('rag_app.vector_store_handler.vectorstores.TextLoader')
+    @patch('rag_app.vector_store_handler.vectorstores.CharacterTextSplitter')
+    def test_load_and_process_documents(self, mock_splitter, mock_loader):
+        mock_loader.return_value.load.return_value = ["doc1", "doc2"]
+        mock_splitter.return_value.split_documents.return_value = ["split1", "split2"]
+
+        result = self.base_store.load_and_process_documents("test.txt")
+
+        mock_loader.assert_called_once_with("test.txt")
+        mock_splitter.assert_called_once_with(chunk_size=1000, chunk_overlap=0)
+        self.assertEqual(result, ["split1", "split2"])
+
+    def test_similarity_search_not_initialized(self):
+        with self.assertRaises(ValueError):
+            self.base_store.similarity_search("query")
+
+class TestChromaVectorStore(unittest.TestCase):
+    def setUp(self):
+        self.embedding_model = MagicMock(spec=HuggingFaceEmbeddings)
+        self.chroma_store = ChromaVectorStore(self.embedding_model, "test_dir")
+
+    @patch('rag_app.vector_store_handler.vectorstores.Chroma')
+    def test_create_vectorstore(self, mock_chroma):
+        texts = [Document(page_content="test")]
+        self.chroma_store.create_vectorstore(texts)
+        mock_chroma.from_documents.assert_called_once_with(
+            texts,
+            self.embedding_model,
+            persist_directory="test_dir"
+        )
+
+    @patch('rag_app.vector_store_handler.vectorstores.Chroma')
+    def test_load_existing_vectorstore(self, mock_chroma):
+        self.chroma_store.load_existing_vectorstore()
+        mock_chroma.assert_called_once_with(
+            persist_directory="test_dir",
+            embedding_function=self.embedding_model
+        )
+
+    def test_save(self):
+        self.chroma_store.vectorstore = MagicMock()
+        self.chroma_store.save()
+        self.chroma_store.vectorstore.persist.assert_called_once()
+
+class TestFAISSVectorStore(unittest.TestCase):
+    def setUp(self):
+        self.embedding_model = MagicMock(spec=HuggingFaceEmbeddings)
+        self.faiss_store = FAISSVectorStore(self.embedding_model, "test_dir")
+
+    @patch('rag_app.vector_store_handler.vectorstores.FAISS')
+    def test_create_vectorstore(self, mock_faiss):
+        texts = [Document(page_content="test")]
+        self.faiss_store.create_vectorstore(texts)
+        mock_faiss.from_documents.assert_called_once_with(texts, self.embedding_model)
+
+    @patch('rag_app.vector_store_handler.vectorstores.FAISS')
+    def test_load_existing_vectorstore(self, mock_faiss):
+        self.faiss_store.load_existing_vectorstore()
+        mock_faiss.load_local.assert_called_once_with("test_dir", self.embedding_model)
+
+    @patch('rag_app.vector_store_handler.vectorstores.FAISS')
+    def test_save(self, mock_faiss):
+        self.faiss_store.vectorstore = MagicMock()
+        self.faiss_store.save()
+        self.faiss_store.vectorstore.save_local.assert_called_once_with("test_dir")
+
+if __name__ == '__main__':
+    unittest.main()