import os
import io
import gc
import base64
import shutil

import gradio as gr
import numpy as np
import pandas as pd
import pymupdf
import spaces
import torch
from dotenv import load_dotenv
from PIL import Image
from pypdf import PdfReader

import chromadb
from chromadb.config import Settings, DEFAULT_TENANT, DEFAULT_DATABASE
from chromadb.utils import embedding_functions
from chromadb.utils.data_loaders import ImageLoader

from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

from huggingface_hub.utils import HfHubHTTPError
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.prompts import PromptTemplate
from langchain_huggingface import HuggingFaceEndpoint

from welcome_text import WELCOME_INTRO
from utils import extract_pdfs, extract_images, clean_text, image_to_bytes
# ─────────────────────────────────────────────────────────────────────────────
# Load .env
load_dotenv()
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
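# HF_TOKEN is forwarded to HuggingFaceEndpoint in conversation() below for the
# hosted text-generation calls; OCR and image captioning run locally.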
# OCR + multimodal image description setup.
# The LLaVA processor/model are left unloaded here: on a ZeroGPU Space CUDA is only
# available inside the GPU worker, so get_image_description() below lazy-loads them
# on first use instead of at import time.
processor = None
vision_model = None

# Default OCR pipeline (detection + recognition models), downloaded once at startup
ocr_model = ocr_predictor(
    "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
)
# Module-level global: location of the persistent Chroma store
PERSIST_DIR = "./chroma_db"
if os.path.exists(PERSIST_DIR):
    shutil.rmtree(PERSIST_DIR)
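# Wiping PERSIST_DIR at startup means every fresh launch begins with an empty index;
# get_vectordb() below rebuilds it from whatever PDFs the user uploads.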
# ZeroGPU: request a GPU for the duration of this call
@spaces.GPU
def get_image_description(image: Image.Image) -> str:
    """
    Lazy-loads the LLaVA processor + model inside the GPU worker,
    runs captioning, and returns a one-sentence description.
    """
    global processor, vision_model

    # On first call, instantiate + move to CUDA
    if processor is None or vision_model is None:
        processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
        vision_model = LlavaNextForConditionalGeneration.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
        ).to("cuda")

    torch.cuda.empty_cache()
    gc.collect()

    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(prompt, image, return_tensors="pt").to("cuda")
    output = vision_model.generate(**inputs, max_new_tokens=100)
    return processor.decode(output[0], skip_special_tokens=True)
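# Example (assuming `img` is a PIL.Image pulled out of a PDF, e.g. via extract_images):
#   caption = get_image_description(img)   # -> e.g. "A bar chart comparing quarterly revenue."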
# Vector DB setup

# One shared embedding function (defaults to all-MiniLM-L6-v2, 384-dim)
SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
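# Both collections below use this same text embedding function, so a text question
# can be matched against image *captions* as well as PDF text chunks.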
def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
    """
    Build a *persistent* ChromaDB instance on disk, with two collections:
      • text_db  (chunks of the PDF text)
      • image_db (image descriptions + raw image bytes)
    """
    # 1) Make or clean the on-disk folder
    shutil.rmtree(PERSIST_DIR, ignore_errors=True)
    os.makedirs(PERSIST_DIR, exist_ok=True)

    # 2) Open a persistent client on that folder
    client = chromadb.PersistentClient(
        path=PERSIST_DIR,
        settings=Settings(),
        tenant=DEFAULT_TENANT,
        database=DEFAULT_DATABASE,
    )

    # 3) Create / wipe collections
    for col in ("text_db", "image_db"):
        if col in [c.name for c in client.list_collections()]:
            client.delete_collection(col)
    text_col = client.get_or_create_collection(
        name="text_db",
        embedding_function=SHARED_EMB_FN,
    )
    img_col = client.get_or_create_collection(
        name="image_db",
        embedding_function=SHARED_EMB_FN,
        metadata={"hnsw:space": "cosine"},
    )

    # 4) Caption and add images
    if images:
        descs, metas = [], []
        for idx, img in enumerate(images):
            try:
                cap = get_image_description(img)
            except Exception:
                cap = "⚠️ could not describe image"
            descs.append(f"{img_names[idx]}: {cap}")
            metas.append({"image": image_to_bytes(img)})
        img_col.add(
            ids=[str(i) for i in range(len(images))],
            documents=descs,
            metadatas=metas,
        )

    # 5) Chunk & add text
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.create_documents([text])
    text_col.add(
        ids=[str(i) for i in range(len(docs))],
        documents=[d.page_content for d in docs],
    )

    return client
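# conversation() reopens this same on-disk store via chromadb.PersistentClient, so
# retrieval works across Gradio callbacks without keeping the client object in gr.State.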
# Text extraction
def result_to_text(result, as_text=False):
    pages = []
    for pg in result.pages:
        txt = " ".join(w.value for block in pg.blocks for line in block.lines for w in line.words)
        pages.append(clean_text(txt))
    return "\n\n".join(pages) if as_text else pages


OCR_CHOICES = {
    "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
    "db_resnet50 + crnn_resnet31": ("db_resnet50", "crnn_resnet31"),
}
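# Each entry maps a UI label to a (detection_model, recognition_model) pair for doctr's
# ocr_predictor; extract_data_from_pdfs() looks the tuple up by the dropdown value.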
# Runs in the ZeroGPU worker: the OCR and vision-language models below need CUDA
@spaces.GPU
def extract_data_from_pdfs(
    docs: list[str],
    session: dict,
    include_images: str,
    do_ocr: str,
    ocr_choice: str,
    vlm_choice: str,
    progress=gr.Progress(),
):
    if not docs:
        raise gr.Error("No documents to process")

    # 1) OCR pipeline, if requested
    if do_ocr == "Get Text With OCR":
        db_m, crnn_m = OCR_CHOICES[ocr_choice]
        local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
    else:
        local_ocr = None

    # 2) Vision-language model chosen in the UI
    proc = LlavaNextProcessor.from_pretrained(vlm_choice)
    vis = (
        LlavaNextForConditionalGeneration
        .from_pretrained(vlm_choice, torch_dtype=torch.float16, low_cpu_mem_usage=True)
        .to("cuda")
    )

    # 3) Monkey-patch the caption fn
    def describe(img):
        torch.cuda.empty_cache()
        gc.collect()
        prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inp = proc(prompt, img, return_tensors="pt").to("cuda")
        out = vis.generate(**inp, max_new_tokens=100)
        return proc.decode(out[0], skip_special_tokens=True)

    global get_image_description
    get_image_description = describe
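    # Rebinding the module-level get_image_description means get_vectordb() (called
    # below) captions images with the VLM picked in the dropdown, not the default one.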
    # 4) Extract text & images
    progress(0.2, "Extracting text and images…")
    all_text = ""
    images, names = [], []
    for path in docs:
        if local_ocr:
            pdf = DocumentFile.from_pdf(path)
            res = local_ocr(pdf)
            all_text += result_to_text(res, as_text=True) + "\n\n"
        else:
            # Pull the embedded text layer from every page, not just the first one
            reader = PdfReader(path)
            all_text += "\n\n".join((p.extract_text() or "") for p in reader.pages) + "\n\n"

        if include_images == "Include Images":
            imgs = extract_images([path])
            images.extend(imgs)
            names.extend([os.path.basename(path)] * len(imgs))

    # 5) Build + persist the vector DB
    progress(0.6, "Indexing in vector DB…")
    client = get_vectordb(all_text, images, names)

    # 6) Mark session and return UI outputs
    session["processed"] = True
    session["persist_directory"] = PERSIST_DIR
    sample_imgs = images[:4] if include_images == "Include Images" else []

    return (
        session,                  # gr.State
        all_text[:2000] + "...",
        sample_imgs,
        "<h3>Done!</h3>",
    )
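# The four return values line up with the outputs wired to extract_btn.click() below:
# session state, a text preview, a small image gallery, and an HTML status banner.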
# Chat function
def conversation(
    session: dict,
    question: str,
    num_ctx: int,
    img_ctx: int,
    history: list,
    temp: float,
    max_tok: int,
    model_id: str,
):
    persist_dir = session.get("persist_directory")
    if not session.get("processed") or not persist_dir:
        raise gr.Error("Please extract data first")

    # 1) Reopen the same persistent client (new API)
    client = chromadb.PersistentClient(
        path=persist_dir,
        settings=Settings(),
        tenant=DEFAULT_TENANT,
        database=DEFAULT_DATABASE,
    )

    # 2) Text retrieval
    text_col = client.get_collection("text_db")
    docs = text_col.query(
        query_texts=[question],
        n_results=int(num_ctx),
        include=["documents"],
    )["documents"][0]
    # 3) Image retrieval
    img_col = client.get_collection("image_db")
    img_q = img_col.query(
        query_texts=[question],
        n_results=int(img_ctx),
        include=["metadatas", "documents"],
    )
    img_descs = img_q["documents"][0] or ["No images found"]
    images = []
    for meta in img_q["metadatas"][0]:
        b64 = meta.get("image", "")
        try:
            images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
        except Exception:
            # Skip entries whose stored bytes cannot be decoded back into an image
            pass
    img_desc = "\n".join(img_descs)

    # 4) Build prompt & call LLM
    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        task="text-generation",
        temperature=temp,
        max_new_tokens=max_tok,
        huggingfacehub_api_token=HF_TOKEN,
    )
    prompt = PromptTemplate(
        template="""
Context:
{text}

Included Images:
{img_desc}

Question:
{q}

Answer:
""",
        input_variables=["text", "img_desc", "q"],
    )
    inp = prompt.format(text="\n\n".join(docs), img_desc=img_desc, q=question)

    try:
        answer = llm.invoke(inp)
    except HfHubHTTPError as e:
        answer = "❌ Model not hosted" if e.response.status_code == 404 else f"⚠️ HF error: {e}"
    except Exception as e:
        answer = f"⚠️ Unexpected error: {e}"

    new_history = history + [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]
    return new_history, docs, images
# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
CSS = """
footer {visibility:hidden;}
"""

MODEL_OPTIONS = [
    "HuggingFaceH4/zephyr-7b-beta",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "openchat/openchat-3.5-0106",
    "google/gemma-7b-it",
    "deepseek-ai/deepseek-llm-7b-chat",
    "microsoft/Phi-3-mini-4k-instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen1.5-7B-Chat",
    "tiiuae/falcon-7b-instruct",   # Falcon 7B Instruct
    "bigscience/bloomz-7b1",       # BLOOMZ 7B
    "facebook/opt-2.7b",
]
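# These repo IDs are passed straight to HuggingFaceEndpoint; a model that is not
# currently hosted for inference surfaces as the 404 "Model not hosted" message above.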
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    session_state = gr.State({})

    with gr.Column(visible=True) as welcome_col:
        gr.Markdown(f"<div style='text-align:center'>{WELCOME_INTRO}</div>")
        start_btn = gr.Button("🚀 Start")

    with gr.Column(visible=False) as app_col:
        gr.Markdown("## 📚 Multimodal Chat-PDF Playground")
        extract_event = None

        with gr.Tabs() as tabs:
            with gr.TabItem("1. Upload & Extract"):
                docs = gr.File(file_count="multiple", file_types=[".pdf"], label="Upload PDFs")
                include_dd = gr.Radio(["Include Images", "Exclude Images"], value="Exclude Images", label="Images")
                ocr_radio = gr.Radio(["Get Text With OCR", "Get Available Text Only"], value="Get Available Text Only", label="OCR")
                ocr_dd = gr.Dropdown(list(OCR_CHOICES.keys()), value=list(OCR_CHOICES.keys())[0], label="OCR Model")
                vlm_dd = gr.Dropdown(
                    ["llava-hf/llava-v1.6-mistral-7b-hf", "llava-hf/llava-v1.5-mistral-7b"],
                    value="llava-hf/llava-v1.6-mistral-7b-hf",
                    label="Vision-Language Model",
                )
                extract_btn = gr.Button("Extract")
                preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
                preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
                preview_html = gr.HTML()

                extract_event = extract_btn.click(
                    fn=extract_data_from_pdfs,
                    inputs=[docs, session_state, include_dd, ocr_radio, ocr_dd, vlm_dd],
                    outputs=[session_state, preview_text, preview_img, preview_html],
                )

            with gr.TabItem("2. Chat", visible=False) as chat_tab:
                with gr.Row():
                    with gr.Column(scale=3):
                        chat = gr.Chatbot(type="messages", label="Chat")
                        msg = gr.Textbox(placeholder="Ask about your PDF...", label="Your question")
                        send = gr.Button("Send")
                    with gr.Column(scale=1):
                        model_dd = gr.Dropdown(MODEL_OPTIONS, value=MODEL_OPTIONS[0], label="Choose Chat Model")
                        num_ctx = gr.Slider(1, 20, value=3, label="Text Contexts")
                        img_ctx = gr.Slider(1, 10, value=2, label="Image Contexts")
                        temp = gr.Slider(0.1, 1.0, step=0.1, value=0.4, label="Temperature")
                        max_tok = gr.Slider(10, 1000, step=10, value=200, label="Max Tokens")

                send.click(
                    fn=conversation,
                    inputs=[session_state, msg, num_ctx, img_ctx, chat, temp, max_tok, model_dd],
                    outputs=[chat, gr.Dataframe(), gr.Gallery(label="Relevant Images", rows=2, value=[])],
                )

        # Unhide the Chat tab once extraction completes
        extract_event.then(
            fn=lambda: gr.update(visible=True),
            inputs=[],
            outputs=[chat_tab],
        )

        gr.HTML("<center>Made with ❤️ by Zamal</center>")

    start_btn.click(
        fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
        outputs=[welcome_col, app_col],
    )

if __name__ == "__main__":
    demo.launch()