# xTwin / knowledge_engine.py
# Provenance: Hugging Face Space upload by aamirhameed (commit 0ccee0d, 2.36 kB).
import os
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain.llms import HuggingFacePipeline
from transformers import pipeline
from langchain.text_splitter import RecursiveCharacterTextSplitter
class KnowledgeManager:
    """Local retrieval-QA engine over plain-text files.

    On construction this loads every ``.txt`` file directly under
    ``root_dir``, splits the combined text into overlapping chunks,
    embeds them into an in-memory FAISS index, and wires a LangChain
    ``RetrievalQA`` chain around a local HuggingFace text2text pipeline.
    Everything runs locally; no API keys are required.
    """

    def __init__(
        self,
        root_dir: str = ".",
        llm_model: str = "google/flan-t5-small",
        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
    ):
        """Build the full pipeline eagerly.

        Args:
            root_dir: Directory scanned (non-recursively) for ``.txt`` files.
            llm_model: HuggingFace model id for the text2text-generation LLM.
            embedding_model: Sentence-transformer model id for embeddings.

        Raises:
            FileNotFoundError: If ``root_dir`` contains no ``.txt`` files.
        """
        self.root_dir = root_dir
        self.llm_model = llm_model
        self.embedding_model = embedding_model
        self.docsearch = None   # FAISS vector store, set by _load_knowledge_base
        self.qa_chain = None    # RetrievalQA chain, set by _load_knowledge_base
        self.llm = None         # HuggingFacePipeline wrapper, set by _initialize_llm
        self.embeddings = None  # embedding model, set by _initialize_embeddings
        self._initialize_llm()
        self._initialize_embeddings()
        self._load_knowledge_base()

    def _initialize_llm(self) -> None:
        """Load the local text2text model (default: FLAN-T5 small) as the LLM."""
        local_pipe = pipeline(
            "text2text-generation", model=self.llm_model, max_length=256
        )
        self.llm = HuggingFacePipeline(pipeline=local_pipe)

    def _initialize_embeddings(self) -> None:
        """Load a general-purpose sentence-transformer for chunk embeddings."""
        self.embeddings = HuggingFaceEmbeddings(model_name=self.embedding_model)

    def _load_knowledge_base(self) -> None:
        """Read, chunk, and index all .txt files, then build the QA chain.

        Raises:
            FileNotFoundError: If no ``.txt`` file exists in ``root_dir``.
        """
        # sorted() makes the chunk order — and therefore the FAISS index —
        # deterministic; raw os.listdir order is filesystem-dependent.
        txt_files = sorted(
            f for f in os.listdir(self.root_dir) if f.endswith(".txt")
        )
        if not txt_files:
            raise FileNotFoundError("No .txt files found in root directory.")
        all_texts = []
        for filename in txt_files:
            path = os.path.join(self.root_dir, filename)
            with open(path, "r", encoding="utf-8") as f:
                all_texts.append(f.read())
        full_text = "\n\n".join(all_texts)
        # Overlapping chunks (50 chars) preserve context across boundaries.
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=500, chunk_overlap=50
        )
        docs = text_splitter.create_documents([full_text])
        self.docsearch = FAISS.from_documents(docs, self.embeddings)
        # "stuff" chain type: retrieved chunks are concatenated into a
        # single prompt for the LLM.
        self.qa_chain = RetrievalQA.from_chain_type(
            llm=self.llm,
            chain_type="stuff",
            retriever=self.docsearch.as_retriever(),
            return_source_documents=True,
        )

    def ask(self, query: str) -> str:
        """Answer *query* from the indexed documents.

        Args:
            query: Natural-language question.

        Returns:
            The LLM's answer string (source documents are retrieved by the
            chain but only the answer text is returned).

        Raises:
            ValueError: If the knowledge base was never initialized.
        """
        if not self.qa_chain:
            raise ValueError("Knowledge base not initialized.")
        result = self.qa_chain(query)
        return result['result']