minar09 committed
Commit d27c873 · verified · 1 Parent(s): 0362a74

Update app.py

Files changed (1)
  1. app.py +137 -37
app.py CHANGED
@@ -1,48 +1,148 @@
 import os
 import gradio as gr
-import main
+import fitz  # PyMuPDF
 import shutil
+import json
+import torch
+from PIL import Image
+import re
 
+# Import the multimodal RAG and Qwen2-VL model classes and processor.
+from byaldi import RAGMultiModalModel
+from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
+from qwen_vl_utils import process_vision_info
 
-def predict_from_pdf(pdf_file):
-    # Create a temporary directory for file uploads
-    upload_dir = "./catalogue/"
-    os.makedirs(upload_dir, exist_ok=True)
-
-    # Use the provided file path from Gradio's file object
-    dest_file_path = os.path.join(upload_dir, os.path.basename(pdf_file.name))
-
-    try:
-        # Save the uploaded file using shutil.copy
-        shutil.copy(pdf_file, dest_file_path)
-
-        # Check if the file was saved successfully
-        if not os.path.exists(dest_file_path):
-            return None, f"Error: The file {dest_file_path} could not be found or opened."
-
-        # Process the PDF and retrieve the product info
-        df, response = main.process_pdf_catalog(dest_file_path)
-        return df, response
+# --- Model Initialization ---
+def initialize_models():
+    """
+    Loads and returns the RAG multimodal and Qwen2-VL models along with the processor.
+    """
+    multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
+    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        trust_remote_code=True,
+        torch_dtype=torch.float32
+    )
+    qwen_processor = AutoProcessor.from_pretrained(
+        "Qwen/Qwen2-VL-2B-Instruct",
+        trust_remote_code=True
+    )
+    return multimodal_rag, qwen_model, qwen_processor
+
+multimodal_rag, qwen_model, qwen_processor = initialize_models()
+
+# --- OCR Function ---
+def perform_ocr(image: Image.Image) -> str:
+    """
+    Extracts text from an image using the Qwen2-VL model.
+    """
+    query = "Extract text from the image in its original language."
+    user_input = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "image": image},
+                {"type": "text", "text": query}
+            ]
+        }
+    ]
+    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
+    image_inputs, video_inputs = process_vision_info(user_input)
+    model_inputs = qwen_processor(
+        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
+    ).to("cpu")  # Use CPU for inference
+    with torch.no_grad():
+        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)
+    # Remove the prompt tokens from the generated output
+    trimmed_ids = [output[len(input_ids):] for input_ids, output in zip(model_inputs.input_ids, generated_ids)]
+    ocr_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    return ocr_result
+
+# --- Product Parsing Function ---
+def parse_product_info(text: str) -> dict:
+    """
+    Parses the combined OCR text into structured product information using Qwen2-VL.
+    """
+    prompt = f"""Extract product specifications from the following text. If no product information is found, return a JSON object with empty values for the keys below.
+
+Text:
+{text}
+
+Return JSON format exactly as:
+{{
+    "name": "product name",
+    "description": "product description",
+    "price": numeric_price,
+    "attributes": {{"key": "value"}}
+}}"""
+    user_input = [{"role": "user", "content": prompt}]
+    input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
+    image_inputs, video_inputs = process_vision_info(user_input)
+    model_inputs = qwen_processor(
+        text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
+    ).to("cpu")
+    with torch.no_grad():
+        generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=512)
+    trimmed_ids = [output[len(input_ids):] for input_ids, output in zip(model_inputs.input_ids, generated_ids)]
+    parsed_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+    try:
+        # Bracket the first JSON object in the raw model output
+        json_start = parsed_result.find('{')
+        json_end = parsed_result.rfind('}') + 1
+        data = json.loads(parsed_result[json_start:json_end])
+    except Exception:
+        data = {}
+    return data
+
+# --- PDF Processing Function ---
+def process_pdf(pdf_file) -> dict:
+    """
+    Processes a PDF file by converting each page to an image,
+    performing OCR on each page, and then parsing the combined
+    text into structured product information.
+    """
+    # Create a temporary directory for the PDF file
+    temp_dir = "./temp_pdf/"
+    os.makedirs(temp_dir, exist_ok=True)
+    pdf_path = os.path.join(temp_dir, os.path.basename(pdf_file.name))
+    if hasattr(pdf_file, "file"):
+        # Gradio upload object: copy the underlying stream
+        with open(pdf_path, "wb") as f:
+            shutil.copyfileobj(pdf_file.file, f)
+    elif hasattr(pdf_file, "name"):
+        # Object that only exposes a path (unlikely in Gradio, but a safeguard)
+        shutil.copy(pdf_file.name, pdf_path)
+    else:
+        raise TypeError("Invalid file input type.")
+
+    # Open the PDF file using PyMuPDF
+    try:
+        doc = fitz.open(pdf_path)
     except Exception as e:
-        return None, f"Error processing PDF: {str(e)}"
-
-
-# Define example PDFs
-pdf_examples = [
-    ["catalogue/flexpocket.pdf"],
-    ["catalogue/ASICS_Catalog.pdf"],
-]
-
-demo = gr.Interface(
-    fn=predict_from_pdf,
-    inputs=gr.File(label="Upload PDF Catalog"),
-    outputs=["json", "text"],
-    examples=pdf_examples,
-    title="Open Source PDF Catalog Parser",
-    description="Efficient PDF catalog processing using fitz and OpenLLM",
-    article="Uses PyMuPDF for layout analysis and Llama-CPP for structured extraction"
-)
-
-if __name__ == "__main__":
-    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
+        raise RuntimeError(f"Cannot open PDF file: {e}")
+
+    combined_text = ""
+    # Iterate over each page and extract text via OCR
+    for page in doc:
+        try:
+            # Render page as image; adjust dpi as needed for quality/speed balance
+            pix = page.get_pixmap(dpi=150)
+            img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
+            page_text = perform_ocr(img)
+            combined_text += page_text + "\n"
+        except Exception as e:
+            print(f"Warning: Failed to process page {page.number + 1}: {e}")
+
+    # Parse the combined OCR text into structured product info
+    product_info = parse_product_info(combined_text)
+    return product_info
+
+# --- Gradio Interface ---
+with gr.Blocks() as interface:
+    gr.Markdown("<h1 style='text-align: center;'>PDF Product Info Extractor</h1>")
+    with gr.Row():
+        pdf_input = gr.File(label="Upload PDF File", file_count="single")
+        extract_btn = gr.Button("Extract Product Info")
+    output_box = gr.JSON(label="Extracted Product Info")
+    extract_btn.click(process_pdf, inputs=pdf_input, outputs=output_box)
+
+interface.launch(debug=True)
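Note: the rewrite adds several imports that the Space image must provide. A minimal requirements.txt sketch for the new dependency set, assuming the usual PyPI names (PyMuPDF provides the fitz module; qwen-vl-utils provides process_vision_info); version pins are not part of this commit:

    # requirements.txt (sketch; pin versions as needed)
    gradio
    PyMuPDF        # imported as fitz
    torch
    Pillow         # PIL.Image
    transformers
    byaldi         # RAGMultiModalModel
    qwen-vl-utils  # process_vision_info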
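Note: both models are pinned to CPU with torch.float32, which runs on free Space hardware but is slow for page-sized images at dpi=150. A device-aware variant of the initialization, as a sketch (the float16-on-GPU choice is an assumption, not part of this commit):

    # Pick a device once; reuse it for model placement and for the inputs.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-2B-Instruct",
        trust_remote_code=True,
        torch_dtype=dtype,
    ).to(device)
    # ...and move inputs to the same device before generate():
    # model_inputs = qwen_processor(...).to(device)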
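Note: parse_product_info brackets the JSON with find('{')/rfind('}'), which can still fail when the model wraps its answer in markdown fences or appends trailing prose. A slightly more defensive helper, as a sketch (extract_json is a hypothetical name, not in this commit):

    import json
    import re

    def extract_json(raw: str) -> dict:
        """Best-effort: parse the first JSON object in model output."""
        cleaned = re.sub(r"```(?:json)?", "", raw)  # strip markdown fences
        start, end = cleaned.find("{"), cleaned.rfind("}") + 1
        if start == -1 or end == 0:
            return {}
        try:
            return json.loads(cleaned[start:end])
        except json.JSONDecodeError:
            return {}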
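Note: process_pdf only needs an object exposing .name and .file, so it can be exercised without launching the UI. A hypothetical smoke test, assuming one of the example PDFs from the previous revision is still checked in at catalogue/flexpocket.pdf:

    # Hypothetical local check; _Upload mimics Gradio's file wrapper.
    class _Upload:
        def __init__(self, path):
            self.name = path              # used to derive the saved filename
            self.file = open(path, "rb")  # stream copied by process_pdf

    info = process_pdf(_Upload("catalogue/flexpocket.pdf"))
    print(json.dumps(info, indent=2))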