aamirhameed committed on
Commit
15e4bac
·
verified ·
1 Parent(s): 85a2e09

Update knowledge_engine.py

Browse files
Files changed (1) hide show
  1. knowledge_engine.py +12 -25
knowledge_engine.py CHANGED
@@ -3,7 +3,7 @@ from langchain.vectorstores import FAISS
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
  from langchain.chains import RetrievalQA
5
  from langchain.llms import HuggingFacePipeline
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
 
9
  class KnowledgeManager:
@@ -12,38 +12,23 @@ class KnowledgeManager:
12
  self.docsearch = None
13
  self.qa_chain = None
14
  self.llm = None
 
15
 
16
  self._initialize_llm()
17
  self._initialize_embeddings()
18
  self._load_knowledge_base()
19
 
20
  def _initialize_llm(self):
21
- model_id = "tiiuae/falcon-7b-instruct"
22
-
23
- tokenizer = AutoTokenizer.from_pretrained(model_id)
24
- model = AutoModelForCausalLM.from_pretrained(
25
- model_id,
26
- trust_remote_code=True,
27
- torch_dtype="auto", # Will use float16 on GPU, float32 on CPU
28
- device_map="auto"
29
- )
30
-
31
- falcon_pipeline = pipeline(
32
- "text-generation",
33
- model=model,
34
- tokenizer=tokenizer,
35
- max_new_tokens=512,
36
- temperature=0.7,
37
- top_p=0.95,
38
- repetition_penalty=1.1
39
- )
40
-
41
- self.llm = HuggingFacePipeline(pipeline=falcon_pipeline)
42
 
43
  def _initialize_embeddings(self):
 
44
  self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
45
 
46
  def _load_knowledge_base(self):
 
47
  txt_files = [f for f in os.listdir(self.root_dir) if f.endswith(".txt")]
48
 
49
  if not txt_files:
@@ -53,16 +38,18 @@ class KnowledgeManager:
53
  for filename in txt_files:
54
  path = os.path.join(self.root_dir, filename)
55
  with open(path, "r", encoding="utf-8") as f:
56
- content = f.read()
57
- all_texts.append(content)
58
 
59
  full_text = "\n\n".join(all_texts)
60
 
 
61
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
62
  docs = text_splitter.create_documents([full_text])
63
 
 
64
  self.docsearch = FAISS.from_documents(docs, self.embeddings)
65
 
 
66
  self.qa_chain = RetrievalQA.from_chain_type(
67
  llm=self.llm,
68
  chain_type="stuff",
@@ -74,4 +61,4 @@ class KnowledgeManager:
74
  if not self.qa_chain:
75
  raise ValueError("Knowledge base not initialized.")
76
  result = self.qa_chain(query)
77
- return result["result"]
 
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
  from langchain.chains import RetrievalQA
5
  from langchain.llms import HuggingFacePipeline
6
+ from transformers import pipeline
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
 
9
  class KnowledgeManager:
 
12
  self.docsearch = None
13
  self.qa_chain = None
14
  self.llm = None
15
+ self.embeddings = None
16
 
17
  self._initialize_llm()
18
  self._initialize_embeddings()
19
  self._load_knowledge_base()
20
 
21
  def _initialize_llm(self):
22
+ # Load local text2text model using HuggingFace pipeline (FLAN-T5 small)
23
+ local_pipe = pipeline("text2text-generation", model="google/flan-t5-small", max_length=1024)
24
+ self.llm = HuggingFacePipeline(pipeline=local_pipe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def _initialize_embeddings(self):
27
+ # Use general-purpose sentence transformer
28
  self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
29
 
30
  def _load_knowledge_base(self):
31
+ # Automatically find all .txt files in the root directory
32
  txt_files = [f for f in os.listdir(self.root_dir) if f.endswith(".txt")]
33
 
34
  if not txt_files:
 
38
  for filename in txt_files:
39
  path = os.path.join(self.root_dir, filename)
40
  with open(path, "r", encoding="utf-8") as f:
41
+ all_texts.append(f.read())
 
42
 
43
  full_text = "\n\n".join(all_texts)
44
 
45
+ # Split text into chunks for embedding
46
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
47
  docs = text_splitter.create_documents([full_text])
48
 
49
+ # Create FAISS vector store
50
  self.docsearch = FAISS.from_documents(docs, self.embeddings)
51
 
52
+ # Build the QA chain
53
  self.qa_chain = RetrievalQA.from_chain_type(
54
  llm=self.llm,
55
  chain_type="stuff",
 
61
  if not self.qa_chain:
62
  raise ValueError("Knowledge base not initialized.")
63
  result = self.qa_chain(query)
64
+ return result['result']