import os
import io
import base64
import gc

import torch
import numpy as np
import pandas as pd
import pymupdf
import gradio as gr
import chromadb
from PIL import Image
from pypdf import PdfReader
from dotenv import load_dotenv
from huggingface_hub.utils import HfHubHTTPError
from doctr.io import DocumentFile
from doctr.models import ocr_predictor
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
from chromadb.utils import embedding_functions
from chromadb.utils.data_loaders import ImageLoader
from langchain_core.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEndpoint

from welcome_text import WELCOME_INTRO
from utils import extract_pdfs, extract_images, clean_text, image_to_bytes

# ─────────────────────────────────────────────────────────────────────────────
# Load .env
load_dotenv()
HF_TOKEN = os.getenv("HUGGINGFACE_TOKEN")
# OCR + multimodal image description setup
ocr_model = ocr_predictor(
    "db_resnet50", "crnn_mobilenet_v3_large", pretrained=True, assume_straight_pages=True
)
processor = LlavaNextProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
vision_model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to("cpu")
def get_image_description(image: Image.Image) -> str:
    """Generate a one-sentence description via LlavaNext."""
    torch.cuda.empty_cache()
    gc.collect()
    prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
    inputs = processor(prompt, image, return_tensors="pt").to("cpu")
    output = vision_model.generate(**inputs, max_new_tokens=100)
    return processor.decode(output[0], skip_special_tokens=True)
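
# Usage sketch (hypothetical local file, not part of the app flow): the helper accepts
# any PIL image and returns the decoded LlavaNext output, which typically still echoes
# the "[INST] ... [/INST]" prompt before the caption itself.
#
#   sample = Image.open("figure1.png").convert("RGB")
#   print(get_image_description(sample))
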
# Vector DB setup
# 1) Create one shared embedding function (all-MiniLM-L6-v2, 384-dim)
SHARED_EMB_FN = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="all-MiniLM-L6-v2"
)
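
# Quick sanity check (assumption: chromadb embedding functions are directly callable on
# a list of strings and return one vector per string, 384-dim for all-MiniLM-L6-v2):
#
#   vecs = SHARED_EMB_FN(["a short test sentence"])
#   assert len(vecs) == 1 and len(vecs[0]) == 384
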
def get_vectordb(text: str, images: list[Image.Image], img_names: list[str]):
    """
    Build an in-memory ChromaDB instance with two collections:
      • text_db  (chunks of the PDF text)
      • image_db (image descriptions + raw image bytes)
    Returns the Chroma client for later querying.
    """
    # ——— 1) Init & wipe old ————————————————
    client = chromadb.EphemeralClient()
    for col in ("text_db", "image_db"):
        if col in [c.name for c in client.list_collections()]:
            client.delete_collection(col)

    # ——— 2) Create fresh collections —————————
    text_col = client.get_or_create_collection(
        name="text_db",
        embedding_function=SHARED_EMB_FN,
        data_loader=ImageLoader(),  # loader only matters for images, benign here
    )
    img_col = client.get_or_create_collection(
        name="image_db",
        embedding_function=SHARED_EMB_FN,
        metadata={"hnsw:space": "cosine"},
        data_loader=ImageLoader(),
    )

    # ——— 3) Add images if any ———————————————
    if images:
        descs = []
        metas = []
        for idx, img in enumerate(images):
            # build one-line caption (or fallback)
            try:
                caption = get_image_description(img)
            except Exception:
                caption = "⚠️ could not describe image"
            descs.append(f"{img_names[idx]}: {caption}")
            metas.append({"image": image_to_bytes(img)})
        img_col.add(
            ids=[str(i) for i in range(len(images))],
            documents=descs,
            metadatas=metas,
        )

    # ——— 4) Chunk & add text ———————————————
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
    docs = splitter.create_documents([text])
    text_col.add(
        ids=[str(i) for i in range(len(docs))],
        documents=[d.page_content for d in docs],
    )
    return client
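
# Usage sketch (hypothetical inputs): build a tiny index and query it the same way the
# chat handler below does, via get_collection(...).query(query_texts=...).
#
#   client = get_vectordb("Some extracted PDF text...", images=[], img_names=[])
#   hits = client.get_collection("text_db").query(
#       query_texts=["What is this document about?"], n_results=2
#   )
#   print(hits["documents"][0])
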
# Text extraction
def result_to_text(result, as_text=False):
    """Flatten a docTR OCR result into cleaned page strings (or one joined string)."""
    pages = []
    for pg in result.pages:
        txt = " ".join(w.value for block in pg.blocks for line in block.lines for w in line.words)
        pages.append(clean_text(txt))
    return "\n\n".join(pages) if as_text else pages
OCR_CHOICES = {
    "db_resnet50 + crnn_mobilenet_v3_large": ("db_resnet50", "crnn_mobilenet_v3_large"),
    "db_resnet50 + sar_resnet31": ("db_resnet50", "sar_resnet31"),
}
def extract_data_from_pdfs(
    docs,
    session,
    include_images,   # "Include Images" or "Exclude Images"
    do_ocr,           # "Get Text With OCR" or "Get Available Text Only"
    ocr_choice,       # key into OCR_CHOICES
    vlm_choice,       # HF repo ID for LlavaNext
    progress=gr.Progress(),
):
    """
    1) Dynamically instantiate the chosen OCR pipeline (if any)
    2) Dynamically instantiate the chosen vision-language model
    3) Override the global get_image_description to use that model for captions
    4) Extract text & images, index into ChromaDB
    """
    if not docs:
        raise gr.Error("No documents to process")

    # ——— 1) Set up OCR if requested ————————————————
    if do_ocr == "Get Text With OCR":
        db_m, crnn_m = OCR_CHOICES[ocr_choice]
        local_ocr = ocr_predictor(db_m, crnn_m, pretrained=True, assume_straight_pages=True)
    else:
        local_ocr = None
    # ——— 2) Set up vision-language model —————————————
    proc = LlavaNextProcessor.from_pretrained(vlm_choice)
    vis = LlavaNextForConditionalGeneration.from_pretrained(
        vlm_choice,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    ).to("cpu")

    # ——— 3) Monkey-patch global get_image_description ————
    def describe(img: Image.Image) -> str:
        torch.cuda.empty_cache()
        gc.collect()
        prompt = "[INST] <image>\nDescribe the image in a sentence [/INST]"
        inputs = proc(prompt, img, return_tensors="pt").to("cpu")
        output = vis.generate(**inputs, max_new_tokens=100)
        return proc.decode(output[0], skip_special_tokens=True)

    global get_image_description
    get_image_description = describe
    # ——— 4) Extract text & images ————————————————
    progress(0.2, "Extracting text and images…")
    all_text, images, names = "", [], []
    for path in docs:
        if local_ocr:
            pdf = DocumentFile.from_pdf(path)
            res = local_ocr(pdf)
            all_text += result_to_text(res, as_text=True) + "\n\n"
        else:
            # pull text from every page with pypdf, not just the first
            reader = PdfReader(path)
            txt = "\n".join((page.extract_text() or "") for page in reader.pages)
            all_text += "\n\n" + txt + "\n\n"
        if include_images == "Include Images":
            imgs = extract_images([path])
            images.extend(imgs)
            names.extend([os.path.basename(path)] * len(imgs))

    # ——— 5) Index into vector DB ————————————————
    progress(0.6, "Indexing in vector DB…")
    vdb = get_vectordb(all_text, images, names)

    session["processed"] = True
    sample_imgs = images[:4] if include_images == "Include Images" else []
    return (
        vdb,
        session,
        gr.Row(visible=True),
        all_text[:2000] + "...",
        sample_imgs,
        "<h3>Done!</h3>",
    )
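
# Standalone sketch (hypothetical file paths): these are the same arguments the Extract
# button passes below; the Gradio components in the return tuple can be ignored here.
#
#   vdb, sess, *_ = extract_data_from_pdfs(
#       ["a.pdf", "b.pdf"], {}, "Exclude Images", "Get Available Text Only",
#       "db_resnet50 + crnn_mobilenet_v3_large", "llava-hf/llava-v1.6-mistral-7b-hf",
#   )
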
# Chat function
def conversation(
    vdb, question: str, num_ctx, img_ctx,
    history: list, temp: float, max_tok: int, model_id: str
):
    # 0) Cast the context sliders to ints
    num_ctx = int(num_ctx)
    img_ctx = int(img_ctx)

    # 1) Guard: must have extracted first
    if vdb is None:
        raise gr.Error("Please extract data first")

    # 2) Instantiate the chosen HF endpoint
    llm = HuggingFaceEndpoint(
        repo_id=model_id,
        temperature=temp,
        max_new_tokens=max_tok,
        huggingfacehub_api_token=HF_TOKEN,
    )

    # 3) Query text collection
    text_col = vdb.get_collection("text_db")
    docs = text_col.query(
        query_texts=[question],
        n_results=num_ctx,  # now an int
        include=["documents"],
    )["documents"][0]

    # 4) Query image collection
    img_col = vdb.get_collection("image_db")
    img_q = img_col.query(
        query_texts=[question],
        n_results=img_ctx,  # now an int
        include=["metadatas", "documents"],
    )
    images, img_descs = [], img_q["documents"][0] or ["No images found"]
    for meta in img_q["metadatas"][0]:
        b64 = meta.get("image", "")
        try:
            images.append(Image.open(io.BytesIO(base64.b64decode(b64))))
        except Exception:
            # skip metadata entries without a decodable image payload
            pass
    img_desc = "\n".join(img_descs)

    # 5) Build prompt
    prompt = PromptTemplate(
        template="""
Context:
{text}
Included Images:
{img_desc}
Question:
{q}
Answer:
""",
        input_variables=["text", "img_desc", "q"],
    )
    context = "\n\n".join(docs)
    user_input = prompt.format(text=context, img_desc=img_desc, q=question)

    # 6) Call the model with error handling
    try:
        answer = llm.invoke(user_input)
    except HfHubHTTPError as e:
        if e.response.status_code == 404:
            answer = f"❌ Model `{model_id}` not hosted on HF Inference API."
        else:
            answer = f"⚠️ HF API error: {e}"
    except Exception as e:
        answer = f"⚠️ Unexpected error: {e}"

    # 7) Append to history
    new_history = history + [
        {"role": "user", "content": question},
        {"role": "assistant", "content": answer},
    ]

    # 8) Return updated history, docs, images
    return new_history, docs, images
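
# Usage sketch (hypothetical question; assumes `vdb` came from extract_data_from_pdfs):
# the returned history is in the messages format the Chatbot component expects.
#
#   history, ctx_docs, ctx_imgs = conversation(
#       vdb, "Summarise the introduction", num_ctx=3, img_ctx=2,
#       history=[], temp=0.4, max_tok=200, model_id="HuggingFaceH4/zephyr-7b-beta",
#   )
#   print(history[-1]["content"])
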
# ─────────────────────────────────────────────────────────────────────────────
# Gradio UI
CSS = """
footer {visibility:hidden;}
"""

MODEL_OPTIONS = [
    "HuggingFaceH4/zephyr-7b-beta",
    "mistralai/Mistral-7B-Instruct-v0.2",
    "openchat/openchat-3.5-0106",
    "google/gemma-7b-it",
    "deepseek-ai/deepseek-llm-7b-chat",
    "microsoft/Phi-3-mini-4k-instruct",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen1.5-7B-Chat",
    "tiiuae/falcon-7b-instruct",   # Falcon 7B Instruct
    "bigscience/bloomz-7b1",       # BLOOMZ 7B
    "facebook/opt-2.7b",
]
with gr.Blocks(css=CSS, theme=gr.themes.Soft()) as demo:
    vdb_state = gr.State()
    session_state = gr.State({})

    # ─── Welcome Screen ─────────────────────────────────────────────
    with gr.Column(visible=True) as welcome_col:
        gr.Markdown(
            f"<div style='text-align: center'>\n{WELCOME_INTRO}\n</div>",
            elem_id="welcome_md",
        )
        start_btn = gr.Button("🚀 Start")
    # ─── Main App (hidden until Start is clicked) ───────────────────
    with gr.Column(visible=False) as app_col:
        gr.Markdown("## 📚 Multimodal Chat-PDF Playground")

        with gr.Tabs():
            # Tab 1: Upload & Extract
            with gr.TabItem("1. Upload & Extract"):
                docs = gr.File(
                    file_count="multiple",
                    file_types=[".pdf"],
                    label="Upload PDFs",
                )
                include_dd = gr.Radio(
                    ["Include Images", "Exclude Images"],
                    value="Exclude Images",
                    label="Images",
                )
                ocr_dd = gr.Dropdown(
                    choices=[
                        "db_resnet50 + crnn_mobilenet_v3_large",
                        "db_resnet50 + sar_resnet31",
                    ],
                    value="db_resnet50 + crnn_mobilenet_v3_large",
                    label="OCR Model",
                )
                vlm_dd = gr.Dropdown(
                    choices=[
                        "llava-hf/llava-v1.6-mistral-7b-hf",
                        "llava-hf/llava-v1.5-mistral-7b",
                    ],
                    value="llava-hf/llava-v1.6-mistral-7b-hf",
                    label="Vision-Language Model",
                )
                extract_btn = gr.Button("Extract")
                preview_text = gr.Textbox(lines=10, label="Sample Text", interactive=False)
                preview_img = gr.Gallery(label="Sample Images", rows=2, value=[])
                extract_btn.click(
                    extract_data_from_pdfs,
                    inputs=[
                        docs,
                        session_state,
                        include_dd,
                        gr.Radio(
                            ["Get Text With OCR", "Get Available Text Only"],
                            value="Get Available Text Only",
                            label="OCR",
                        ),
                        ocr_dd,
                        vlm_dd,
                    ],
                    outputs=[
                        vdb_state,
                        session_state,
                        gr.Row(visible=False),
                        preview_text,
                        preview_img,
                        gr.HTML(),
                    ],
                )
            # Tab 2: Chat
            with gr.TabItem("2. Chat"):
                with gr.Row():
                    with gr.Column(scale=3):
                        chat = gr.Chatbot(type="messages", label="Chat")
                        msg = gr.Textbox(
                            placeholder="Ask about your PDF...",
                            label="Your question",
                        )
                        send = gr.Button("Send")
                    with gr.Column(scale=1):
                        model_dd = gr.Dropdown(
                            MODEL_OPTIONS,
                            value=MODEL_OPTIONS[0],
                            label="Choose Chat Model",
                        )
                        num_ctx = gr.Slider(1, 20, value=3, label="Text Contexts")
                        img_ctx = gr.Slider(1, 10, value=2, label="Image Contexts")
                        temp = gr.Slider(0.1, 1.0, step=0.1, value=0.4, label="Temperature")
                        max_tok = gr.Slider(10, 1000, step=10, value=200, label="Max Tokens")
                send.click(
                    conversation,
                    inputs=[
                        vdb_state,
                        msg,
                        num_ctx,
                        img_ctx,
                        chat,
                        temp,
                        max_tok,
                        model_dd,
                    ],
                    outputs=[
                        chat,
                        gr.Dataframe(),
                        gr.Gallery(label="Relevant Images", rows=2, value=[]),
                    ],
                )

        # Footer inside app_col
        gr.HTML("<center>Made with ❤️ by Zamal</center>")

    # ─── Wire the Start button ──────────────────────────────────────
    start_btn.click(
        fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
        inputs=[], outputs=[welcome_col, app_col],
    )

if __name__ == "__main__":
    demo.launch()