aamirhameed committed on
Commit
5fce8b9
·
verified ·
1 Parent(s): 6302c50

Update knowledge_engine.py

Browse files
Files changed (1) hide show
  1. knowledge_engine.py +25 -12
knowledge_engine.py CHANGED
@@ -3,7 +3,7 @@ from langchain.vectorstores import FAISS
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
  from langchain.chains import RetrievalQA
5
  from langchain.llms import HuggingFacePipeline
6
- from transformers import pipeline
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
 
9
  class KnowledgeManager:
@@ -12,23 +12,38 @@ class KnowledgeManager:
12
  self.docsearch = None
13
  self.qa_chain = None
14
  self.llm = None
15
- self.embeddings = None
16
 
17
  self._initialize_llm()
18
  self._initialize_embeddings()
19
  self._load_knowledge_base()
20
 
21
  def _initialize_llm(self):
22
- # Load local text2text model using HuggingFace pipeline (FLAN-T5 small)
23
- local_pipe = pipeline("text2text-generation", model="google/flan-t5-small", max_length=500)
24
- self.llm = HuggingFacePipeline(pipeline=local_pipe)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  def _initialize_embeddings(self):
27
- # Use general-purpose sentence transformer
28
  self.embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
29
 
30
  def _load_knowledge_base(self):
31
- # Automatically find all .txt files in the root directory
32
  txt_files = [f for f in os.listdir(self.root_dir) if f.endswith(".txt")]
33
 
34
  if not txt_files:
@@ -38,18 +53,16 @@ class KnowledgeManager:
38
  for filename in txt_files:
39
  path = os.path.join(self.root_dir, filename)
40
  with open(path, "r", encoding="utf-8") as f:
41
- all_texts.append(f.read())
 
42
 
43
  full_text = "\n\n".join(all_texts)
44
 
45
- # Split text into chunks for embedding
46
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
47
  docs = text_splitter.create_documents([full_text])
48
 
49
- # Create FAISS vector store
50
  self.docsearch = FAISS.from_documents(docs, self.embeddings)
51
 
52
- # Build the QA chain
53
  self.qa_chain = RetrievalQA.from_chain_type(
54
  llm=self.llm,
55
  chain_type="stuff",
@@ -61,4 +74,4 @@ class KnowledgeManager:
61
  if not self.qa_chain:
62
  raise ValueError("Knowledge base not initialized.")
63
  result = self.qa_chain(query)
64
- return result['result']
 
3
  from langchain.embeddings import HuggingFaceEmbeddings
4
  from langchain.chains import RetrievalQA
5
  from langchain.llms import HuggingFacePipeline
6
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
7
  from langchain.text_splitter import RecursiveCharacterTextSplitter
8
 
9
  class KnowledgeManager:
 
12
  self.docsearch = None
13
  self.qa_chain = None
14
  self.llm = None
 
15
 
16
  self._initialize_llm()
17
  self._initialize_embeddings()
18
  self._load_knowledge_base()
19
 
20
def _initialize_llm(self):
    """Load Falcon-7B-Instruct as a LangChain-compatible LLM.

    Builds a HuggingFace text-generation pipeline around the model and
    wraps it in ``HuggingFacePipeline``, storing the result on ``self.llm``.
    Downloads model weights on first use; placement is delegated to
    ``device_map="auto"``.
    """
    model_id = "tiiuae/falcon-7b-instruct"

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,  # Falcon ships custom modeling code on the Hub
        torch_dtype="auto",      # Will use float16 on GPU, float32 on CPU
        device_map="auto",
    )

    falcon_pipeline = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=512,
        # Fix: without do_sample=True, transformers defaults to greedy
        # decoding and silently ignores temperature/top_p (it also logs a
        # warning). Enable sampling so these settings take effect.
        do_sample=True,
        temperature=0.7,
        top_p=0.95,
        repetition_penalty=1.1,
        # NOTE(review): consider return_full_text=False so the prompt is
        # not echoed back into the RetrievalQA answer — confirm desired.
    )

    self.llm = HuggingFacePipeline(pipeline=falcon_pipeline)
42
 
43
def _initialize_embeddings(self):
    """Create the sentence-transformer embedding model used to index documents."""
    embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
    self.embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
45
 
46
  def _load_knowledge_base(self):
 
47
  txt_files = [f for f in os.listdir(self.root_dir) if f.endswith(".txt")]
48
 
49
  if not txt_files:
 
53
  for filename in txt_files:
54
  path = os.path.join(self.root_dir, filename)
55
  with open(path, "r", encoding="utf-8") as f:
56
+ content = f.read()
57
+ all_texts.append(content)
58
 
59
  full_text = "\n\n".join(all_texts)
60
 
 
61
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
62
  docs = text_splitter.create_documents([full_text])
63
 
 
64
  self.docsearch = FAISS.from_documents(docs, self.embeddings)
65
 
 
66
  self.qa_chain = RetrievalQA.from_chain_type(
67
  llm=self.llm,
68
  chain_type="stuff",
 
74
  if not self.qa_chain:
75
  raise ValueError("Knowledge base not initialized.")
76
  result = self.qa_chain(query)
77
+ return result["result"]