Spaces:
Running
Running
import os | |
from pathlib import Path | |
from langchain.document_loaders import TextLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.chains import RetrievalQA | |
from langchain.llms import HuggingFaceHub | |
class KnowledgeManager:
    """Question answering over the ``.txt`` files of a single directory.

    Construction runs the whole pipeline: load every ``*.txt`` file found
    directly inside ``knowledge_dir`` (the glob is not recursive), split the
    text into overlapping chunks, embed the chunks into a FAISS vector store,
    create a Hugging Face Hub LLM and combine retriever and LLM in a
    ``RetrievalQA`` chain of type ``"stuff"``.

    If loading yields no chunks at all, the embedding/LLM/chain steps are
    skipped and :meth:`ask` answers with a fixed "not initialized" message
    instead of raising.
    """

    def __init__(
        self,
        knowledge_dir=".",  # root dir by default
        *,
        chunk_size=1000,
        chunk_overlap=100,
        embedding_model="sentence-transformers/all-MiniLM-L6-v2",
        llm_repo_id="google/flan-t5-small",
        llm_kwargs=None,
        encoding="utf-8",
    ):
        """Build the knowledge base from ``knowledge_dir``.

        Args:
            knowledge_dir: Directory holding the ``.txt`` knowledge files.
            chunk_size: ``chunk_size`` passed to the text splitter.
            chunk_overlap: ``chunk_overlap`` passed to the text splitter.
            embedding_model: Model name given to ``HuggingFaceEmbeddings``.
            llm_repo_id: ``repo_id`` given to ``HuggingFaceHub``.
            llm_kwargs: ``model_kwargs`` for ``HuggingFaceHub``; ``None``
                selects ``{"temperature": 0, "max_length": 256}``.
            encoding: Text encoding used to read the files. Pass ``None`` to
                fall back to the loader's own default.

        Raises:
            FileNotFoundError: If the directory does not exist or contains
                no ``.txt`` files.
        """
        self.knowledge_dir = Path(knowledge_dir)
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.embedding_model = embedding_model
        self.llm_repo_id = llm_repo_id
        # Copy so that later mutation of the caller's dict cannot leak in.
        self.llm_kwargs = (
            {"temperature": 0, "max_length": 256}
            if llm_kwargs is None
            else dict(llm_kwargs)
        )
        self.encoding = encoding

        self.documents = []
        self.embeddings = None
        self.vectorstore = None
        self.retriever = None
        self.llm = None
        self.qa_chain = None

        self._load_documents()
        # Files may exist but be empty; without chunks there is nothing to
        # index, so the chain is left unset and ask() reports that.
        if self.documents:
            self._initialize_embeddings()
            self._initialize_vectorstore()
            self._initialize_llm()
            self._initialize_qa_chain()

    def _load_documents(self):
        """Load and chunk every ``.txt`` file in ``knowledge_dir``.

        Replaces ``self.documents`` with the freshly split chunks, so calling
        this again reloads rather than appends.

        Raises:
            FileNotFoundError: If the directory is missing or has no
                ``.txt`` files.
        """
        if not self.knowledge_dir.exists():
            raise FileNotFoundError(f"Directory {self.knowledge_dir} does not exist.")

        # Sorted for a reproducible chunk order; is_file() keeps a directory
        # that merely happens to be named "*.txt" away from the loader.
        files = sorted(
            path for path in self.knowledge_dir.glob("*.txt") if path.is_file()
        )
        if not files:
            raise FileNotFoundError(f"No .txt files found in {self.knowledge_dir}. Please upload your knowledge base files in root.")

        # Collect into a local list so a failure part-way through does not
        # leave self.documents half-filled.
        raw_documents = []
        for file in files:
            loader = TextLoader(str(file), encoding=self.encoding)
            raw_documents.extend(loader.load())

        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
        )
        self.documents = splitter.split_documents(raw_documents)

    def _initialize_embeddings(self):
        """Create the sentence-embedding model used to index the chunks."""
        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

    def _initialize_vectorstore(self):
        """Index the chunks in FAISS and expose the index as a retriever."""
        self.vectorstore = FAISS.from_documents(self.documents, self.embeddings)
        self.retriever = self.vectorstore.as_retriever()

    def _initialize_llm(self):
        """Create the Hugging Face Hub LLM that generates the answers."""
        # NOTE(review): HuggingFaceHub presumably needs a Hub API token in
        # the environment - confirm for the deployment target.
        self.llm = HuggingFaceHub(
            repo_id=self.llm_repo_id,
            model_kwargs=self.llm_kwargs,
        )

    def _initialize_qa_chain(self):
        """Combine the LLM and the retriever into a ``RetrievalQA`` chain."""
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.retriever,
        )

    def ask(self, query):
        """Answer ``query`` from the knowledge base.

        Returns:
            The chain's answer, or a fixed explanatory message when the
            chain was never built (no chunks were loaded).
        """
        if self.qa_chain is None:
            return "Knowledge base not initialized properly."
        return self.qa_chain.run(query)

    def get_knowledge_summary(self):
        """Return a one-line description of how many chunks are loaded."""
        return f"Loaded {len(self.documents)} document chunks from {self.knowledge_dir}"