import os import gradio as gr import fitz # PyMuPDF import shutil import json import torch from PIL import Image import re # Import multimodal and Qwen2-VL models and processor from your dependencies. from byaldi import RAGMultiModalModel from transformers import Qwen2VLForConditionalGeneration, AutoProcessor from qwen_vl_utils import process_vision_info # --- Model Initialization --- def initialize_models(): """ Loads and returns the RAG multimodal and Qwen2-VL models along with the processor. """ multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali") qwen_model = Qwen2VLForConditionalGeneration.from_pretrained( "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True, torch_dtype=torch.float32 ) qwen_processor = AutoProcessor.from_pretrained( "Qwen/Qwen2-VL-2B-Instruct", trust_remote_code=True ) return multimodal_rag, qwen_model, qwen_processor multimodal_rag, qwen_model, qwen_processor = initialize_models() # --- OCR Function --- def perform_ocr(image: Image.Image) -> str: """ Extracts text from an image using the Qwen2-VL model. """ query = "Extract text from the image in its original language." user_input = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": query} ] } ] input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(user_input) model_inputs = qwen_processor( text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt" ).to("cpu") # Use CPU for inference with torch.no_grad(): generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000) # Remove the prompt tokens from the generated output trimmed_ids = [output[len(model_inputs.input_ids):] for model_inputs.input_ids, output in zip(model_inputs.input_ids, generated_ids)] ocr_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] return ocr_result # --- Product Parsing Function --- def parse_product_info(text: str) -> dict: """ Parses the combined OCR text into structured product information using Qwen2-VL. """ prompt = f"""Extract product specifications from the following text. If no product information is found, return an empty JSON object with keys. Text: {text} Return JSON format exactly as: {{ "name": "product name", "description": "product description", "price": numeric_price, "attributes": {{"key": "value"}} }}""" user_input = [{"role": "user", "content": prompt}] input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs = process_vision_info(user_input) model_inputs = qwen_processor( text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt" ).to("cpu") with torch.no_grad(): generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=512) trimmed_ids = [output[len(model_inputs.input_ids):] for model_inputs.input_ids, output in zip(model_inputs.input_ids, generated_ids)] parsed_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] try: json_start = parsed_result.find('{') json_end = parsed_result.rfind('}') + 1 data = json.loads(parsed_result[json_start:json_end]) except Exception as e: data = {} return data # --- PDF Processing Function --- def process_pdf(pdf_file) -> dict: """ Processes a PDF file by converting each page to an image, performing OCR on each page, and then parsing the combined text into structured product information. """ # Create a temporary directory for the PDF file temp_dir = "./temp_pdf/" os.makedirs(temp_dir, exist_ok=True) pdf_path = os.path.join(temp_dir, pdf_file.name) with open(pdf_path, "wb") as f: if hasattr(pdf_file, "file"): shutil.copyfileobj(pdf_file.file, f) elif hasattr(pdf_file, "name"): # In case pdf_file is a path string (unlikely in Gradio, but safe-guard) shutil.copy(pdf_file.name, pdf_path) else: raise TypeError("Invalid file input type.") # Open the PDF file using PyMuPDF try: doc = fitz.open(pdf_path) except Exception as e: raise RuntimeError(f"Cannot open PDF file: {e}") combined_text = "" # Iterate over each page and extract text via OCR for page in doc: try: # Render page as image; adjust dpi as needed for quality/speed balance pix = page.get_pixmap(dpi=150) img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples) page_text = perform_ocr(img) combined_text += page_text + "\n" except Exception as e: print(f"Warning: Failed to process page {page.number + 1}: {e}") # Parse the combined OCR text into structured product info product_info = parse_product_info(combined_text) return product_info # --- Gradio Interface --- with gr.Blocks() as interface: gr.Markdown("