# Spaces: Sleeping
import os | |
from langchain.vectorstores import FAISS | |
from langchain.embeddings import HuggingFaceEmbeddings | |
from langchain.chains import RetrievalQA | |
from langchain.llms import HuggingFacePipeline | |
from transformers import pipeline | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
class KnowledgeManager:
    """Answer questions over the ``.txt`` files found in one directory.

    Construction loads a local FLAN-T5 text2text pipeline and a
    sentence-transformer embedding model, reads every ``.txt`` file located
    directly in ``root_dir``, splits the text into chunks, indexes the chunks
    in a FAISS vector store and builds a ``RetrievalQA`` chain on top of it.
    :meth:`ask` then runs a query through that chain.
    """

    def __init__(self, root_dir="."):
        """Initialise the models, the vector index and the QA chain.

        Args:
            root_dir: Directory scanned (non-recursively) for ``.txt`` files.

        Raises:
            FileNotFoundError: If ``root_dir`` contains no ``.txt`` files.
            ValueError: If the ``.txt`` files yield no text chunks to index.
        """
        self.root_dir = root_dir
        self.docsearch = None
        self.qa_chain = None
        self.llm = None
        self.embeddings = None
        # Order matters: the knowledge base needs both the LLM and the
        # embedding model to exist before the index and chain are built.
        self._initialize_llm()
        self._initialize_embeddings()
        self._load_knowledge_base()

    def _initialize_llm(self):
        """Create the local text2text LLM (FLAN-T5 small) used by the chain."""
        local_pipe = pipeline(
            "text2text-generation",
            model="google/flan-t5-small",
            max_length=1024,
        )
        self.llm = HuggingFacePipeline(pipeline=local_pipe)

    def _initialize_embeddings(self):
        """Create the general-purpose sentence-transformer embedding model."""
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

    def _read_text_files(self):
        """Read the ``.txt`` files that sit directly in ``root_dir``.

        Returns:
            A list of ``(filename, text)`` tuples sorted by filename.

        Raises:
            FileNotFoundError: If no regular ``.txt`` file is present.
        """
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # resulting index reproducible between runs. Directories whose name
        # happens to end in ".txt" are skipped because open() would fail.
        names = sorted(
            name
            for name in os.listdir(self.root_dir)
            if name.endswith(".txt")
            and os.path.isfile(os.path.join(self.root_dir, name))
        )
        if not names:
            raise FileNotFoundError("No .txt files found in root directory.")

        contents = []
        for name in names:
            path = os.path.join(self.root_dir, name)
            with open(path, "r", encoding="utf-8") as handle:
                contents.append((name, handle.read()))
        return contents

    def _load_knowledge_base(self):
        """Index every ``.txt`` file and build the retrieval QA chain.

        Sets ``self.docsearch`` (FAISS store) and ``self.qa_chain``.

        Raises:
            FileNotFoundError: If ``root_dir`` contains no ``.txt`` files.
            ValueError: If the files produce no chunks (e.g. all are empty).
        """
        files = self._read_text_files()

        # Split each file on its own so a chunk never mixes two files, and
        # tag every chunk with its origin; the chain is configured with
        # return_source_documents=True, so this metadata travels with results.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500, chunk_overlap=50
        )
        docs = text_splitter.create_documents(
            [text for _, text in files],
            metadatas=[{"source": name} for name, _ in files],
        )
        if not docs:
            raise ValueError("The .txt files in root directory contain no text.")

        # Create FAISS vector store
        self.docsearch = FAISS.from_documents(docs, self.embeddings)

        # Build the QA chain
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.docsearch.as_retriever(),
            return_source_documents=True,
        )

    def ask(self, query):
        """Run ``query`` through the QA chain and return the answer text.

        Args:
            query: The question passed to the retrieval QA chain.

        Returns:
            The value stored under the ``'result'`` key of the chain output.

        Raises:
            ValueError: If the QA chain has not been built.
        """
        if self.qa_chain is None:
            raise ValueError("Knowledge base not initialized.")
        result = self.qa_chain(query)
        return result['result']