# xTwin / knowledge_engine.py
# Last commit by aamirhameed: "Update knowledge_engine.py" (15e4bac, verified)
import os
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
class KnowledgeManager:
    """Retrieval-augmented question answering over the .txt files in a directory.

    Construction eagerly loads every model and builds the index:

    * a local ``google/flan-t5-small`` text2text-generation pipeline (the LLM),
    * ``sentence-transformers/all-MiniLM-L6-v2`` embeddings,
    * a FAISS vector store over chunks of every regular ``.txt`` file in
      ``root_dir``,
    * a ``RetrievalQA`` chain ("stuff" type) combining the retriever and the LLM.

    Use ``ask(query)`` to get an answer string.
    """

    def __init__(self, root_dir="."):
        """Load the models and index the knowledge base found in *root_dir*.

        Args:
            root_dir: Directory whose top-level ``.txt`` files form the knowledge
                base (subdirectories are not scanned).

        Raises:
            FileNotFoundError: if *root_dir* contains no regular ``.txt`` file.
            ValueError: if the ``.txt`` files contain no indexable text.
        """
        self.root_dir = root_dir
        self.docsearch = None  # FAISS vector store, set by _load_knowledge_base
        self.qa_chain = None  # RetrievalQA chain, set by _load_knowledge_base
        self.llm = None
        self.embeddings = None
        # Order matters: the knowledge base needs both the embeddings and the LLM.
        self._initialize_llm()
        self._initialize_embeddings()
        self._load_knowledge_base()

    def _initialize_llm(self):
        """Wrap a local FLAN-T5-small text2text pipeline as a LangChain LLM."""
        local_pipe = pipeline(
            "text2text-generation", model="google/flan-t5-small", max_length=1024
        )
        self.llm = HuggingFacePipeline(pipeline=local_pipe)

    def _initialize_embeddings(self):
        """Create the general-purpose sentence-transformer embedding model."""
        self.embeddings = HuggingFaceEmbeddings(
            model_name="sentence-transformers/all-MiniLM-L6-v2"
        )

    def _read_txt_files(self):
        """Return ``(filename, text)`` pairs for each regular ``.txt`` file in root_dir.

        Files are returned sorted by name so the index is built deterministically
        (``os.listdir`` order is arbitrary). Entries that merely end in ``.txt``
        but are not regular files (e.g. directories) are skipped.

        Raises:
            FileNotFoundError: if no regular ``.txt`` file exists in root_dir.
        """
        filenames = sorted(
            name
            for name in os.listdir(self.root_dir)
            if name.endswith(".txt")
            and os.path.isfile(os.path.join(self.root_dir, name))
        )
        if not filenames:
            raise FileNotFoundError("No .txt files found in root directory.")
        contents = []
        for name in filenames:
            with open(os.path.join(self.root_dir, name), "r", encoding="utf-8") as f:
                contents.append((name, f.read()))
        return contents

    def _load_knowledge_base(self):
        """Chunk the .txt files, embed them into FAISS and build the QA chain.

        Raises:
            FileNotFoundError: if root_dir contains no regular ``.txt`` file.
            ValueError: if the files yield no text chunks (all empty/whitespace).
        """
        files = self._read_txt_files()
        # chunk_size/chunk_overlap are measured by the splitter's length
        # function (characters by default).
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
        # Split each file separately rather than one concatenated blob: no chunk
        # can straddle two files, and every chunk records its origin so the
        # source documents returned by the chain are attributable.
        docs = text_splitter.create_documents(
            [text for _, text in files],
            metadatas=[{"source": name} for name, _ in files],
        )
        if not docs:
            # FAISS.from_documents fails obscurely on an empty list.
            raise ValueError("The .txt files in root directory contain no text.")
        self.docsearch = FAISS.from_documents(docs, self.embeddings)
        # "stuff" puts all retrieved chunks into a single prompt for the LLM.
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.docsearch.as_retriever(),
            return_source_documents=True,
        )

    def ask(self, query):
        """Answer *query* using the retrieval QA chain.

        Returns:
            The chain's ``'result'`` entry (the answer text); the source
            documents the chain also returns are discarded here.

        Raises:
            ValueError: if the knowledge base has not been built.
        """
        if self.qa_chain is None:
            raise ValueError("Knowledge base not initialized.")
        result = self.qa_chain(query)
        return result["result"]