Spaces:
Running
Running
#load & split data | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# embed data | |
from langchain_mistralai.embeddings import MistralAIEmbeddings | |
# vector store | |
from langchain_community.vectorstores import FAISS | |
# prompt | |
from langchain.prompts import PromptTemplate | |
# memory | |
from langchain.memory import ConversationBufferMemory | |
#llm | |
from langchain_mistralai.chat_models import ChatMistralAI | |
#chain modules | |
from langchain.chains import RetrievalQA | |
# import PyPDF2 | |
import os | |
import re | |
from dotenv import load_dotenv | |
load_dotenv() | |
from collections import defaultdict | |
class RagModule:
    """Retrieval-augmented Q&A helpers built on Mistral AI models and a FAISS store.

    Holds the embedding model, the text splitter, the vector-store path and the
    LLM sampling parameters, and builds LangChain ``RetrievalQA`` chains from them.
    """

    def __init__(self):
        # Never commit API keys to source control: read the key from the
        # environment (the module calls load_dotenv(), so a local .env works too).
        self.mistral_api_key = os.environ.get("MISTRAL_API_KEY")
        self.model_name_embedding = "mistral-embed"
        self.embedding_model = MistralAIEmbeddings(
            model=self.model_name_embedding,
            mistral_api_key=self.mistral_api_key,
        )
        self.chunk_size = 1000
        self.chunk_overlap = 120
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        self.db_faiss_path = "data/vector_store"
        # LLM parameters
        self.llm_model = "mistral-small"
        self.max_new_tokens = 512
        self.top_p = 0.5
        self.temperature = 0.1

    def split_text(self, text: str) -> list:
        """Split a text into chunks using the configured text splitter.

        Args:
            text (str): raw text to split.

        Returns:
            list: the text chunks (``chunk_size``/``chunk_overlap`` from __init__).
        """
        return self.text_splitter.split_text(text)

    def get_metadata(self, texts: list) -> list:
        """Build one metadata dict per chunk, numbering chunks as paragraphs.

        Args:
            texts (list): the chunks, typically the output of ``split_text``.

        Returns:
            list: ``[{"source": "Paragraphe: 0"}, {"source": "Paragraphe: 1"}, ...]``
        """
        return [{"source": f'Paragraphe: {i}'} for i, _ in enumerate(texts)]

    def get_faiss_db(self):
        """Load the local FAISS vector store that holds all embeddings.

        Returns:
            The FAISS store read from ``self.db_faiss_path``, using
            ``self.embedding_model`` to embed queries.
        """
        # NOTE(review): FAISS.load_local unpickles the docstore; only load
        # stores this application wrote itself. Recent langchain_community
        # versions also require allow_dangerous_deserialization=True — confirm
        # against the installed version.
        return FAISS.load_local(self.db_faiss_path, self.embedding_model)

    def set_custom_prompt(self, prompt_template: str):
        """Instantiate the prompt template used for Q&A retrieval.

        Args:
            prompt_template (str): template text; its ``{...}`` placeholders
                become the prompt's input variables.

        Returns:
            PromptTemplate: the compiled prompt.
        """
        return PromptTemplate.from_template(template=prompt_template)

    def load_mistral(self):
        """Instantiate the Mistral chat model with this module's parameters.

        Returns:
            ChatMistralAI: the configured chat model.
        """
        return ChatMistralAI(
            mistral_api_key=self.mistral_api_key,
            model=self.llm_model,
            # ChatMistralAI names its generation cap `max_tokens`;
            # `max_new_tokens` is not one of its fields.
            max_tokens=self.max_new_tokens,
            top_p=self.top_p,
            temperature=self.temperature,
        )

    def _build_qa_chain(self, llm, db, chain_type_kwargs: dict, k: int):
        """Build a 'stuff' RetrievalQA chain over `db` that also returns its source documents."""
        return RetrievalQA.from_chain_type(
            llm=llm,
            chain_type='stuff',
            retriever=db.as_retriever(search_kwargs={"k": k}),
            chain_type_kwargs=chain_type_kwargs,
            return_source_documents=True,
        )

    def retrieval_qa_memory_chain(self, db, prompt_template, k: int = 5):
        """Build a RetrievalQA chain that keeps a conversation history.

        A ConversationBufferMemory is attached with ``memory_key='history'`` and
        ``input_key='question'``, so the prompt template is expected to use the
        ``{history}`` and ``{question}`` placeholders. The memory lives inside the
        returned chain, so history accumulates across calls on that chain.

        Args:
            db: vector store exposing ``as_retriever`` (e.g. from ``get_faiss_db``).
            prompt_template (str): prompt text, see ``set_custom_prompt``.
            k (int): number of documents to retrieve per question.

        Returns:
            RetrievalQA: the chain; its outputs include the source documents.
        """
        memory = ConversationBufferMemory(
            memory_key='history',
            input_key='question',
        )
        chain_type_kwargs = {
            "prompt": self.set_custom_prompt(prompt_template),
            "memory": memory,
        }
        return self._build_qa_chain(self.load_mistral(), db, chain_type_kwargs, k)

    def retrieval_qa_chain(self, db, prompt_template, k: int = 3):
        """Build a stateless RetrievalQA chain (no conversation memory).

        Args:
            db: vector store exposing ``as_retriever`` (e.g. from ``get_faiss_db``).
            prompt_template (str): prompt text, see ``set_custom_prompt``.
            k (int): number of documents to retrieve per question.

        Returns:
            RetrievalQA: the chain; its outputs include the source documents.
        """
        chain_type_kwargs = {
            "prompt": self.set_custom_prompt(prompt_template),
        }
        # The LLM factory of this class is load_mistral (there is no load_llm).
        return self._build_qa_chain(self.load_mistral(), db, chain_type_kwargs, k)

    def get_sources_document(self, source_documents: list) -> dict:
        """Group the pages of the RAG source documents by source path.

        Args:
            source_documents (list): documents returned with the RAG response;
                each must have ``metadata["source"]`` and may have
                ``metadata["page"]``.

        Returns:
            dict: mapping of source path to the list of its pages, e.g.
            ``{"path/to/file1": [0, 1, 3], "path/to/file2": [5, 2]}``.
            Sources without page metadata map to an empty list.
        """
        sources = defaultdict(list)
        for doc in source_documents:
            metadata = doc.metadata
            pages = sources[metadata["source"]]
            # Chunks built with get_metadata carry no "page" key.
            if "page" in metadata:
                pages.append(metadata["page"])
        return sources

    def shape_answer_with_source(self, answer: str, sources: dict) -> str:
        """Append a "Fichier / Page" line per source to the answer.

        Args:
            answer (str): the LLM answer.
            sources (dict): source path -> list of pages, as returned by
                ``get_sources_document``.

        Returns:
            str: the answer followed by a blank line and one line per source, or
            the answer unchanged when there are no sources.
        """
        lines = []
        for path, pages in sources.items():
            # Show only the file name; sources without "/" (e.g. "Paragraphe: 0")
            # are shown whole instead of crashing.
            file_name = path.rsplit("/", 1)[-1] or path
            if pages:
                lines.append(f"\nFichier: {file_name} - Page: {pages}")
            else:
                lines.append(f"\nFichier: {file_name}")
        if not lines:
            return answer
        return answer + "\n" + "".join(lines)