zamal committed on
Commit ca125f5 · verified · 1 Parent(s): af1ccb2

Update app.py

Files changed (1)
  1. app.py +247 -96
app.py CHANGED
@@ -5,143 +5,294 @@ import gc
  from huggingface_hub.utils import HfHubHTTPError
  from langchain_core.prompts import PromptTemplate
  from langchain_huggingface import HuggingFaceEndpoint
- from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
- from doctr.io import DocumentFile
- from doctr.models import ocr_predictor
- from pypdf import PdfReader
  from PIL import Image
- from utils import extract_images, image_to_bytes, clean_text
  from welcome_text import WELCOME_INTRO
  import chromadb
  from chromadb.utils import embedding_functions
  from langchain.text_splitter import RecursiveCharacterTextSplitter
- import gradio as gr

  # ─────────────────────────────────────────────────────────────────────────────
- # Globals
- CURRENT_VDB = None
  processor = None
  vision_model = None
-
- # OCR & V+L defaults
- OCR_CHOICES = {
-     "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
-     "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
- }
- SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
-     model_name="all-MiniLM-L6-v2"
  )
-
- def get_image_description(img: Image.Image) -> str:
      global processor, vision_model
      if processor is None or vision_model is None:
-         # use the same default V+L model everywhere
-         vlm = "llava-hf/llava-v1.6-mistral-7b-hf"
-         processor = LlavaNextProcessor.from_pretrained(vlm)
          vision_model = LlavaNextForConditionalGeneration.from_pretrained(
-             vlm, torch_dtype=torch.float16, low_cpu_mem_usage=True
          ).to("cuda")
-     torch.cuda.empty_cache(); gc.collect()
      prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
-     inputs = processor(prompt, img, return_tensors="pt").to("cuda")
-     out = vision_model.generate(**inputs, max_new_tokens=100)
-     return processor.decode(out[0], skip_special_tokens=True)

  def extract_data_from_pdfs(
-     docs, session, include_images, do_ocr, ocr_choice, vlm_choice, progress=gr.Progress()
  ):
      if not docs:
          raise gr.Error("No documents to process")

-     # 1) Optional OCR
-     local_ocr = None
      if do_ocr == "Get Text With OCR":
          db_m, crnn_m = OCR_CHOICES[ocr_choice]
          local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)

-     # 2) Prepare V+L
      proc = LlavaNextProcessor.from_pretrained(vlm_choice)
-     vis = LlavaNextForConditionalGeneration.from_pretrained(
-         vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True
-     ).to("cuda")

-     # 3) Patch get_image_description to use this choice
-     def describe(img: Image.Image) -> str:
          torch.cuda.empty_cache(); gc.collect()
          prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
          inp = proc(prompt, img, return_tensors="pt").to("cuda")
          out = vis.generate(**inp, max_new_tokens=100)
          return proc.decode(out[0], skip_special_tokens=True)
-     global get_image_description, CURRENT_VDB
      get_image_description = describe

-     # 4) Pull text + images
      progress(0.2, "Extracting text and images…")
-     full_text, images, names = "", [], []
-     for p in docs:
          if local_ocr:
-             pdf = DocumentFile.from_pdf(p)
              res = local_ocr(pdf)
-             full_text += " ".join(w.value for blk in res.pages for line in blk.lines for w in line.words) + "\n\n"
          else:
-             full_text += (PdfReader(p).pages[0].extract_text() or "") + "\n\n"

          if include_images == "Include Images":
-             imgs = extract_images([p])
              images.extend(imgs)
-             names.extend([os.path.basename(p)] * len(imgs))

-     # 5) Build in-memory Chroma
      progress(0.6, "Indexing in vector DB…")
-     client = chromadb.EphemeralClient()
-     for col in ("text_db", "image_db"):
-         if col in [c.name for c in client.list_collections()]:
-             client.delete_collection(col)
-     text_col = client.get_or_create_collection("text_db", embedding_function=SHARED_EMB_FN)
-     img_col = client.get_or_create_collection("image_db", embedding_function=SHARED_EMB_FN,
-                                               metadata={"hnsw:space":"cosine"})

-     if images:
-         descs, metas = [], []
-         for i, im in enumerate(images):
-             cap = get_image_description(im)
-             descs.append(f"{names[i]}: {cap}")
-             metas.append({"image": image_to_bytes(im)})
-         img_col.add(ids=[str(i) for i in range(len(images))],
-                     documents=descs, metadatas=metas)

-     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
-     docs_ = splitter.create_documents([full_text])
-     text_col.add(ids=[str(i) for i in range(len(docs_))],
-                  documents=[d.page_content for d in docs_])

-     CURRENT_VDB = client
-     session["processed"] = True
-     sample = images[:4] if include_images=="Include Images" else []
-     return session, full_text[:2000]+"...", sample, "<h3>Done!</h3>"

- def conversation(session, question, num_ctx, img_ctx, history, temp, max_tok, model_id):
-     global CURRENT_VDB
-     if not session.get("processed") or CURRENT_VDB is None:
          raise gr.Error("Please extract data first")

-     # a) text retrieval
-     docs = CURRENT_VDB.get_collection("text_db")\
-               .query(query_texts=[question], n_results=int(num_ctx), include=["documents"])["documents"][0]

-     # b) image retrieval
-     img_q = CURRENT_VDB.get_collection("image_db")\
-               .query(query_texts=[question], n_results=int(img_ctx),
-                      include=["metadatas","documents"])
      img_descs = img_q["documents"][0] or ["No images found"]
      images = []
-     for m in img_q["metadatas"][0]:
-         b = m.get("image","")
-         try: images.append(Image.open(io.BytesIO(base64.b64decode(b))))
-         except: pass
      img_desc = "\n".join(img_descs)

-     # c) prompt & LLM
      prompt = PromptTemplate(
          template="""
  Context:
@@ -154,23 +305,23 @@ Question:
  {q}

  Answer:
- """, input_variables=["text","img_desc","q"])
      inp = prompt.format(text="\n\n".join(docs), img_desc=img_desc, q=question)

-     llm = HuggingFaceEndpoint(
-         repo_id=model_id, task="text-generation",
-         temperature=temp, max_new_tokens=max_tok,
-         huggingfacehub_api_token=HF_TOKEN
-     )
-     try: ans = llm.invoke(inp)
      except HfHubHTTPError as e:
-         ans = f"❌ Model `{model_id}` not hosted." if e.response.status_code==404 else f"⚠️ HF API error: {e}"
      except Exception as e:
-         ans = f"⚠️ Unexpected error: {e}"

-     new_hist = history + [{"role":"user","content":question},
-                           {"role":"assistant","content":ans}]
-     return new_hist, docs, images

@@ -258,4 +409,4 @@ with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
  )

  if __name__ == "__main__":
-     demo.launch()

  from huggingface_hub.utils import HfHubHTTPError
  from langchain_core.prompts import PromptTemplate
  from langchain_huggingface import HuggingFaceEndpoint
+ import io, base64
+ from PIL import Image
+ import torch
+ import gradio as gr
+ import spaces
+ import numpy as np
+ import pandas as pd
+ import pymupdf
  from PIL import Image
+ from pypdf import PdfReader
+ from dotenv import load_dotenv
+ import shutil
+ from chromadb.config import Settings, DEFAULT_TENANT, DEFAULT_DATABASE
  from welcome_text import WELCOME_INTRO
+
+ from doctr.io import DocumentFile
+ from doctr.models import ocr_predictor
+ from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
+
  import chromadb
  from chromadb.utils import embedding_functions
+ from chromadb.utils.data_loaders import ImageLoader
+
+ from langchain_core.prompts import PromptTemplate
  from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_huggingface import HuggingFaceEndpoint
+
+ from utils import extract_pdfs, extract_images, clean_text, image_to_bytes
+ from utils import *

  # ─────────────────────────────────────────────────────────────────────────────
+ # Load .env
+ load_dotenv()
+ HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
  processor = None
  vision_model = None
+ # OCR + multimodal image description setup
+ ocr_model = ocr_predictor(
+     "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
  )
+ processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
+ vision_model = LlavaNextForConditionalGeneration.from_pretrained(
+     "llava-hf/llava-v1.6-mistral-7b-hf",
+     torch_dtype=torch.float16,
+     low_cpu_mem_usage=True
+ ).to("cuda")
+
+
+ # Add at the top of your module, alongside your other globals
+ PERSIST_DIR = "./chroma_db"
+ if os.path.exists(PERSIST_DIR):
+     shutil.rmtree(PERSIST_DIR)
+
+ @spaces.GPU()
+ def get_image_description(image: Image.Image) -> str:
+     """
+     Lazy-loads the Llava processor + model inside the GPU worker,
+     runs captioning, and returns a one-sentence description.
+     """
      global processor, vision_model
+
+     # On first call, instantiate + move to CUDA
      if processor is None or vision_model is None:
+         processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
          vision_model = LlavaNextForConditionalGeneration.from_pretrained(
+             "llava-hf/llava-v1.6-mistral-7b-hf",
+             torch_dtype=torch.float16,
+             low_cpu_mem_usage=True
          ).to("cuda")
+
+     torch.cuda.empty_cache()
+     gc.collect()
+
      prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
+     inputs = processor(prompt, image, return_tensors="pt").to("cuda")
+     output = vision_model.generate(**inputs, max_new_tokens=100)
+     return processor.decode(output[0], skip_special_tokens=True)
+
+ # Vector DB setup
+ # at top of file, alongside your other imports
+ from chromadb.utils import embedding_functions
+ from chromadb.utils.data_loaders import ImageLoader
+ import chromadb
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from utils import image_to_bytes  # your helper
+
+ # 1) Create one shared embedding function (defaulting to All-MiniLM-L6-v2, 384-dim)
+ SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
+     model_name="all-MiniLM-L6-v2"
+ )
+
+ def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
+     """
+     Build a *persistent* ChromaDB instance on disk, with two collections:
+       • text_db  (chunks of the PDF text)
+       • image_db (image descriptions + raw image bytes)
+     """
+     # 1) Make or clean the on-disk folder
+     shutil.rmtree(PERSIST_DIR, ignore_errors=True)
+     os.makedirs(PERSIST_DIR, exist_ok=True)
+
+     client = chromadb.PersistentClient(
+         path=PERSIST_DIR,
+         settings=Settings(),
+         tenant=DEFAULT_TENANT,
+         database=DEFAULT_DATABASE
+     )
+
+     # 3) Create / wipe collections
+     for col in ("text_db", "image_db"):
+         if col in [c.name for c in client.list_collections()]:
+             client.delete_collection(col)
+
+     text_col = client.get_or_create_collection(
+         name="text_db",
+         embedding_function=SHARED_EMB_FN
+     )
+     img_col = client.get_or_create_collection(
+         name="image_db",
+         embedding_function=SHARED_EMB_FN,
+         metadata={"hnsw:space": "cosine"}
+     )
+
+     # 4) Add images
+     if images:
+         descs, metas = [], []
+         for idx, img in enumerate(images):
+             try:
+                 cap = get_image_description(img)
+             except:
+                 cap = "⚠️ could not describe image"
+             descs.append(f"{img_names[idx]}: {cap}")
+             metas.append({"image": image_to_bytes(img)})
+         img_col.add(ids=[str(i) for i in range(len(images))],
+                     documents=descs,
+                     metadatas=metas)
+
+     # 5) Chunk & add text
+     splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
+     docs = splitter.create_documents([text])
+     text_col.add(ids=[str(i) for i in range(len(docs))],
+                  documents=[d.page_content for d in docs])
+
+     return client
+
+

+
+ # Text extraction
+ def result_to_text(result, as_text=False):
+     pages = []
+     for pg in result.pages:
+         txt = " ".join(w.value for block in pg.blocks for line in block.lines for w in line.words)
+         pages.append(clean_text(txt))
+     return "\n\n".join(pages) if as_text else pages
+
+ OCR_CHOICES = {
+     "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
+     "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
+ }
+
+ @spaces.GPU()
  def extract_data_from_pdfs(
+     docs: list[str],
+     session: dict,
+     include_images: str,
+     do_ocr: str,
+     ocr_choice: str,
+     vlm_choice: str,
+     progress=gr.Progress()
  ):
      if not docs:
          raise gr.Error("No documents to process")

+     # 1) OCR pipeline if requested
      if do_ocr == "Get Text With OCR":
          db_m, crnn_m = OCR_CHOICES[ocr_choice]
          local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
+     else:
+         local_ocr = None

+     # 2) Vision–language model
      proc = LlavaNextProcessor.from_pretrained(vlm_choice)
+     vis = (LlavaNextForConditionalGeneration
+            .from_pretrained(vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True)
+            .to("cuda"))

+     # 3) Monkey-patch caption fn
+     def describe(img):
          torch.cuda.empty_cache(); gc.collect()
          prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
          inp = proc(prompt, img, return_tensors="pt").to("cuda")
          out = vis.generate(**inp, max_new_tokens=100)
          return proc.decode(out[0], skip_special_tokens=True)
+
+     global get_image_description
      get_image_description = describe

+     # 4) Extract text & images
      progress(0.2, "Extracting text and images…")
+     all_text = ""
+     images, names = [], []
+     for path in docs:
          if local_ocr:
+             pdf = DocumentFile.from_pdf(path)
              res = local_ocr(pdf)
+             all_text += result_to_text(res, as_text=True) + "\n\n"
          else:
+             all_text += (PdfReader(path).pages[0].extract_text() or "") + "\n\n"

          if include_images == "Include Images":
+             imgs = extract_images([path])
              images.extend(imgs)
+             names.extend([os.path.basename(path)] * len(imgs))

+     # 5) Build + persist the vectordb
      progress(0.6, "Indexing in vector DB…")
+     client = get_vectordb(all_text, images, names)
+
+     # 6) Mark session and return UI outputs
+     session["processed"] = True
+     session["persist_directory"] = PERSIST_DIR
+     sample_imgs = images[:4] if include_images == "Include Images" else []
+
+     return (
+         session,                  # gr.State
+         all_text[:2000] + "...",
+         sample_imgs,
+         "<h3>Done!</h3>"
+     )


+ # Chat function
+ def conversation(
+     session: dict,
+     question: str,
+     num_ctx: int,
+     img_ctx: int,
+     history: list,
+     temp: float,
+     max_tok: int,
+     model_id: str
+ ):
+     pd = session.get("persist_directory")
+     if not session.get("processed") or not pd:
          raise gr.Error("Please extract data first")

+     # 1) Reopen the same persistent client (new API)
+     client = chromadb.PersistentClient(
+         path=pd,
+         settings=Settings(),
+         tenant=DEFAULT_TENANT,
+         database=DEFAULT_DATABASE
+     )
+

+     # 2) Text retrieval
+     text_col = client.get_collection("text_db")
+     docs = text_col.query(query_texts=[question],
+                           n_results=int(num_ctx),
+                           include=["documents"])["documents"][0]
+
+     # 3) Image retrieval
+     img_col = client.get_collection("image_db")
+     img_q = img_col.query(query_texts=[question],
+                           n_results=int(img_ctx),
+                           include=["metadatas","documents"])
      img_descs = img_q["documents"][0] or ["No images found"]
      images = []
+     for meta in img_q["metadatas"][0]:
+         b64 = meta.get("image","")
+         try:
+             images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
+         except:
+             pass
      img_desc = "\n".join(img_descs)

+     # 4) Build prompt & call LLM
+     llm = HuggingFaceEndpoint(
+         repo_id=model_id,
+         task="text-generation",
+         temperature=temp,
+         max_new_tokens=max_tok,
+         huggingfacehub_api_token=HF_TOKEN
+     )
+
      prompt = PromptTemplate(
          template="""
  Context:

  {q}

  Answer:
+ """, input_variables=["text","img_desc","q"]
+     )
      inp = prompt.format(text="\n\n".join(docs), img_desc=img_desc, q=question)

+     try:
+         answer = llm.invoke(inp)
      except HfHubHTTPError as e:
+         answer = "❌ Model not hosted" if e.response.status_code==404 else f"⚠️ HF error: {e}"
      except Exception as e:
+         answer = f"⚠️ Unexpected error: {e}"
+
+     new_history = history + [
+         {"role":"user", "content":question},
+         {"role":"assistant","content":answer}
+     ]
+     return new_history, docs, images


  )

  if __name__ == "__main__":
+     demo.launch()
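
For reference, a minimal sketch of how the on-disk store produced by the new get_vectordb could be queried outside the Gradio app, mirroring what conversation does after this change. This is an assumption-laden illustration, not part of the commit: it assumes extract_data_from_pdfs has already populated ./chroma_db in the same environment, that utils.image_to_bytes stores each image as a base64 string (implied by the base64.b64decode call in conversation), and that reopening the collections without an explicit embedding function is acceptable (as the committed conversation code also does); the query strings are placeholders.

import base64, io
import chromadb
from PIL import Image

# Reopen the persisted store created by get_vectordb (assumed path ./chroma_db)
client = chromadb.PersistentClient(path="./chroma_db")

# Text retrieval: top-3 chunks for an ad-hoc question
text_hits = client.get_collection("text_db").query(
    query_texts=["What is the document's main topic?"],   # placeholder question
    n_results=3,
    include=["documents"],
)
for chunk in text_hits["documents"][0]:
    print(chunk[:120], "…")

# Image retrieval: captions plus the stored images decoded back to PIL
img_hits = client.get_collection("image_db").query(
    query_texts=["architecture diagram"],                  # placeholder question
    n_results=2,
    include=["documents", "metadatas"],
)
for caption, meta in zip(img_hits["documents"][0], img_hits["metadatas"][0]):
    img = Image.open(io.BytesIO(base64.b64decode(meta["image"])))
    print(caption, img.size)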