minar09 commited on
Commit
f8daace
·
verified ·
1 Parent(s): a5b6e19

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +37 -137
  2. main.py +199 -0
app.py CHANGED
@@ -1,148 +1,48 @@
1
  import os
2
  import gradio as gr
3
- import fitz # PyMuPDF
4
  import shutil
5
- import json
6
- import torch
7
- from PIL import Image
8
- import re
9
 
10
- # Import multimodal and Qwen2-VL models and processor from your dependencies.
11
- from byaldi import RAGMultiModalModel
12
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
13
- from qwen_vl_utils import process_vision_info
14
 
15
- # --- Model Initialization ---
 
 
 
16
 
17
- def initialize_models():
18
- """
19
- Loads and returns the RAG multimodal and Qwen2-VL models along with the processor.
20
- """
21
- multimodal_rag = RAGMultiModalModel.from_pretrained("vidore/colpali")
22
- qwen_model = Qwen2VLForConditionalGeneration.from_pretrained(
23
- "Qwen/Qwen2-VL-2B-Instruct",
24
- trust_remote_code=True,
25
- torch_dtype=torch.float32
26
- )
27
- qwen_processor = AutoProcessor.from_pretrained(
28
- "Qwen/Qwen2-VL-2B-Instruct",
29
- trust_remote_code=True
30
- )
31
- return multimodal_rag, qwen_model, qwen_processor
32
 
33
- multimodal_rag, qwen_model, qwen_processor = initialize_models()
34
-
35
- # --- OCR Function ---
36
- def perform_ocr(image: Image.Image) -> str:
37
- """
38
- Extracts text from an image using the Qwen2-VL model.
39
- """
40
- query = "Extract text from the image in its original language."
41
- user_input = [
42
- {
43
- "role": "user",
44
- "content": [
45
- {"type": "image", "image": image},
46
- {"type": "text", "text": query}
47
- ]
48
- }
49
- ]
50
- input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
51
- image_inputs, video_inputs = process_vision_info(user_input)
52
- model_inputs = qwen_processor(
53
- text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
54
- ).to("cpu") # Use CPU for inference
55
- with torch.no_grad():
56
- generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=2000)
57
- # Remove the prompt tokens from the generated output
58
- trimmed_ids = [output[len(model_inputs.input_ids):] for model_inputs.input_ids, output in zip(model_inputs.input_ids, generated_ids)]
59
- ocr_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
60
- return ocr_result
61
-
62
- # --- Product Parsing Function ---
63
- def parse_product_info(text: str) -> dict:
64
- """
65
- Parses the combined OCR text into structured product information using Qwen2-VL.
66
- """
67
- prompt = f"""Extract product specifications from the following text. If no product information is found, return an empty JSON object with keys.
68
-
69
- Text:
70
- {text}
71
-
72
- Return JSON format exactly as:
73
- {{
74
- "name": "product name",
75
- "description": "product description",
76
- "price": numeric_price,
77
- "attributes": {{"key": "value"}}
78
- }}"""
79
- user_input = [{"role": "user", "content": prompt}]
80
- input_text = qwen_processor.apply_chat_template(user_input, tokenize=False, add_generation_prompt=True)
81
- image_inputs, video_inputs = process_vision_info(user_input)
82
- model_inputs = qwen_processor(
83
- text=[input_text], images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt"
84
- ).to("cpu")
85
- with torch.no_grad():
86
- generated_ids = qwen_model.generate(**model_inputs, max_new_tokens=512)
87
- trimmed_ids = [output[len(model_inputs.input_ids):] for model_inputs.input_ids, output in zip(model_inputs.input_ids, generated_ids)]
88
- parsed_result = qwen_processor.batch_decode(trimmed_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
89
  try:
90
- json_start = parsed_result.find('{')
91
- json_end = parsed_result.rfind('}') + 1
92
- data = json.loads(parsed_result[json_start:json_end])
93
- except Exception as e:
94
- data = {}
95
- return data
96
 
97
- # --- PDF Processing Function ---
98
- def process_pdf(pdf_file) -> dict:
99
- """
100
- Processes a PDF file by converting each page to an image,
101
- performing OCR on each page, and then parsing the combined
102
- text into structured product information.
103
- """
104
- # Create a temporary directory for the PDF file
105
- temp_dir = "./temp_pdf/"
106
- os.makedirs(temp_dir, exist_ok=True)
107
- pdf_path = os.path.join(temp_dir, pdf_file.name)
108
- with open(pdf_path, "wb") as f:
109
- if hasattr(pdf_file, "file"):
110
- shutil.copyfileobj(pdf_file.file, f)
111
- elif hasattr(pdf_file, "name"):
112
- # In case pdf_file is a path string (unlikely in Gradio, but safe-guard)
113
- shutil.copy(pdf_file.name, pdf_path)
114
- else:
115
- raise TypeError("Invalid file input type.")
116
-
117
- # Open the PDF file using PyMuPDF
118
- try:
119
- doc = fitz.open(pdf_path)
120
- except Exception as e:
121
- raise RuntimeError(f"Cannot open PDF file: {e}")
122
-
123
- combined_text = ""
124
- # Iterate over each page and extract text via OCR
125
- for page in doc:
126
- try:
127
- # Render page as image; adjust dpi as needed for quality/speed balance
128
- pix = page.get_pixmap(dpi=150)
129
- img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
130
- page_text = perform_ocr(img)
131
- combined_text += page_text + "\n"
132
- except Exception as e:
133
- print(f"Warning: Failed to process page {page.number + 1}: {e}")
134
-
135
- # Parse the combined OCR text into structured product info
136
- product_info = parse_product_info(combined_text)
137
- return product_info
138
 
139
- # --- Gradio Interface ---
140
- with gr.Blocks() as interface:
141
- gr.Markdown("<h1 style='text-align: center;'>PDF Product Info Extractor</h1>")
142
- with gr.Row():
143
- pdf_input = gr.File(label="Upload PDF File", file_count="single")
144
- extract_btn = gr.Button("Extract Product Info")
145
- output_box = gr.JSON(label="Extracted Product Info")
146
- extract_btn.click(process_pdf, inputs=pdf_input, outputs=output_box)
147
 
148
- interface.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import gradio as gr
3
+ import main
4
  import shutil
 
 
 
 
5
 
 
 
 
 
6
 
7
def predict_from_pdf(pdf_file):
    """Persist an uploaded PDF into ./catalogue/ and run the catalog extractor.

    Parameters
    ----------
    pdf_file : gradio file object or str
        Upload from ``gr.File``. Gradio may hand over a tempfile wrapper
        exposing the path via ``.name`` or a plain path string.

    Returns
    -------
    tuple
        ``(product_info_dict, status_message)`` on success,
        ``(None, error_message)`` on failure.
    """
    # Guard: Gradio passes None when the user clicks without uploading.
    if pdf_file is None:
        return None, "Error: no PDF file was provided."

    # Create a persistent directory for file uploads.
    upload_dir = "./catalogue/"
    os.makedirs(upload_dir, exist_ok=True)

    # Resolve the source path whether we got a file object or a raw path.
    src_path = getattr(pdf_file, "name", pdf_file)
    dest_file_path = os.path.join(upload_dir, os.path.basename(src_path))

    try:
        # Copy only when source and destination differ; shutil.copy raises
        # SameFileError when they are the same file.
        if os.path.abspath(src_path) != os.path.abspath(dest_file_path):
            shutil.copy(src_path, dest_file_path)

        # Check that the file landed where we expect before processing.
        if not os.path.exists(dest_file_path):
            return None, f"Error: The file {dest_file_path} could not be found or opened."

        # Process the PDF and retrieve the structured product info.
        df, response = main.process_pdf_catalog(dest_file_path)
        return df, response

    except Exception as e:
        # Surface the failure to the UI instead of crashing the app.
        return None, f"Error processing PDF: {str(e)}"
29
+
30
+
31
# Example catalogs offered in the UI (paths relative to the app root).
pdf_examples = [
    [path]
    for path in ("catalogue/flexpocket.pdf", "catalogue/ASICS_Catalog.pdf")
]

# Single-upload, dual-output (JSON + status text) front end for the parser.
demo = gr.Interface(
    fn=predict_from_pdf,
    inputs=gr.File(label="Upload PDF Catalog"),
    outputs=["json", "text"],
    examples=pdf_examples,
    title="Open Source PDF Catalog Parser",
    description="Efficient PDF catalog processing using fitz and OpenLLM",
    article="Uses PyMuPDF for layout analysis and Llama-CPP for structured extraction",
)

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from inside a container.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860, share=True)
main.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import time
4
+ import logging
5
+ from pathlib import Path
6
+ from typing import List, Dict, Optional
7
+ from dataclasses import dataclass
8
+ from fastapi.encoders import jsonable_encoder
9
+ import fitz # PyMuPDF
10
+ from sentence_transformers import SentenceTransformer
11
+ from llama_cpp import Llama
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
@dataclass
class ProductSpec:
    """Structured product record extracted from one catalog text block.

    Attributes:
        name: Product display name; empty string when unknown.
        description: Free-text description, if the LLM produced one.
        price: Numeric price, if found.
        attributes: Key/value spec pairs (e.g. color, size).
        tables: Table payloads attached after extraction (page, cells, ...).
    """

    name: str
    description: Optional[str] = None
    price: Optional[float] = None
    # default_factory gives each instance its own container instead of the
    # `= None` placeholder the previous version used (and avoids the shared
    # mutable-default pitfall). Callers that passed explicit values are
    # unaffected.
    attributes: Dict[str, str] = field(default_factory=dict)
    tables: List[Dict] = field(default_factory=list)

    def to_dict(self):
        """Return a JSON-serializable dict view of this record."""
        return jsonable_encoder(self)
27
+
28
+
29
class PDFProcessor:
    """Extracts structured product data from PDF catalogs.

    Pipeline: PyMuPDF (fitz) pulls text blocks and tables per page, then a
    local llama-cpp model turns each text block into a ProductSpec via a
    JSON-extraction prompt.
    """

    def __init__(self):
        # NOTE(review): emb_model is initialized but not referenced by any
        # method in this file — confirm it is used elsewhere or intended.
        self.emb_model = self._initialize_emb_model("all-MiniLM-L6-v2")
        # Choose the appropriate model filename below; adjust if needed.
        # self.llm = self._initialize_llm("deepseek-llm-7b-base.Q2_K.gguf")
        self.llm = self._initialize_llm("llama-2-7b.Q2_K.gguf")
        # All derived artifacts are written under ./output.
        self.output_dir = Path("./output")
        self.output_dir.mkdir(exist_ok=True)

    def _initialize_emb_model(self, model_name):
        """Load the sentence-embedding model, falling back to raw transformers.

        Returns a SentenceTransformer when available; otherwise a bare
        transformers AutoModel (note: the fallback returns only the model,
        not the tokenizer).
        """
        try:
            # Use SentenceTransformer if available
            return SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        except Exception as e:
            logger.warning(f"SentenceTransformer failed: {e}. Falling back to transformers model.")
            from transformers import AutoTokenizer, AutoModel
            # NOTE(review): tokenizer is created but discarded — the caller
            # only receives the model; confirm that is sufficient downstream.
            tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/" + model_name)
            model = AutoModel.from_pretrained("sentence-transformers/" + model_name)
            return model

    def _initialize_llm(self, model_name):
        """Initialize LLM with automatic download if needed"""
        # Here we use from_pretrained so that if the model is missing locally it downloads it.
        model_path = os.path.join("models/", model_name)
        if os.path.exists(model_path):
            return Llama(
                model_path=model_path,
                n_ctx=1024,
                n_gpu_layers=-1,  # offload all layers to GPU when built with GPU support
                # NOTE(review): os.cpu_count() may return None on some
                # platforms, which would make this subtraction raise — confirm.
                n_threads=os.cpu_count() - 1,
                verbose=False
            )
        else:
            # No local copy: pull a hosted GGUF from the Hugging Face Hub.
            # NOTE(review): the downloaded filename differs from model_name —
            # confirm this substitution is intentional.
            return Llama.from_pretrained(
                repo_id="Tien203/llama.cpp",
                filename="Llama-2-7b-hf-q4_0.gguf",
            )

    def process_pdf(self, pdf_path: str) -> Dict:
        """Process PDF using PyMuPDF.

        Returns a dict with keys "products" (list of ProductSpec dicts) and
        "tables" (list of per-page table payloads). Raises RuntimeError when
        the PDF cannot be opened.
        """
        start_time = time.time()

        # Open PDF
        try:
            doc = fitz.open(pdf_path)
        except Exception as e:
            logger.error(f"Failed to open PDF: {e}")
            raise RuntimeError("Cannot open PDF file.") from e

        text_blocks = []
        tables = []

        # Extract text and tables from each page
        for page_num, page in enumerate(doc):
            # Extract text blocks from page and filter out very short blocks (noise)
            blocks = self._extract_text_blocks(page)
            filtered = [block for block in blocks if len(block.strip()) >= 10]
            logger.debug(f"Page {page_num + 1}: Extracted {len(blocks)} blocks, {len(filtered)} kept after filtering.")
            text_blocks.extend(filtered)

            # Extract tables (if any)
            tables.extend(self._extract_tables(page, page_num))

        # Process text blocks with LLM to extract product information
        products = []
        for idx, block in enumerate(text_blocks):
            # Log the text block for debugging
            logger.debug(f"Processing text block {idx}: {block[:100]}...")
            product = self._process_text_block(block)
            if product:
                # NOTE(review): every product receives the full document-wide
                # table list, not just its own page's tables — confirm intended.
                product.tables = tables
                # Only add if at least one key (like name) is non-empty
                if product.name or product.description or product.price or (
                        product.attributes and len(product.attributes) > 0):
                    products.append(product.to_dict())
                else:
                    logger.debug(f"LLM returned empty product for block {idx}.")
            else:
                logger.debug(f"No product extracted from block {idx}.")

        logger.info(f"Processed {len(products)} products in {time.time() - start_time:.2f}s")
        return {"products": products, "tables": tables}

    def _extract_text_blocks(self, page) -> List[str]:
        """Extract text blocks from a PDF page using PyMuPDF's blocks method."""
        blocks = []
        for block in page.get_text("blocks"):
            # block[4] contains the text content
            text = block[4].strip()
            if text:
                blocks.append(text)
        return blocks

    def _extract_tables(self, page, page_num: int) -> List[Dict]:
        """Extract tables from a PDF page using PyMuPDF's table extraction (if available).

        Failures are logged and swallowed so one bad page cannot abort the
        document; returns a possibly-empty list.
        """
        tables = []
        try:
            tab = page.find_tables()
            if tab and hasattr(tab, 'tables') and tab.tables:
                for table in tab.tables:
                    table_data = table.extract()
                    if table_data:
                        tables.append({
                            "page": page_num + 1,  # 1-based page number for display
                            # NOTE(review): "cells" and "content" carry the
                            # same payload — confirm both keys are consumed.
                            "cells": table_data,
                            "header": table.header.names if table.header else [],
                            "content": table_data
                        })
        except Exception as e:
            logger.warning(f"Error extracting tables from page {page_num + 1}: {e}")
        return tables

    def _process_text_block(self, text: str) -> Optional[ProductSpec]:
        """Process a text block with LLM to extract product specifications.

        Returns None on any LLM or parsing failure (best-effort per block).
        """
        prompt = self._generate_query_prompt(text)
        logger.debug(f"Generated prompt: {prompt[:200]}...")
        try:
            response = self.llm.create_chat_completion(
                messages=[{"role": "user", "content": prompt}],
                temperature=0.1,  # near-deterministic output for JSON extraction
                max_tokens=512
            )
            # Debug: log raw response
            logger.debug(f"LLM raw response: {response}")
            return self._parse_response(response['choices'][0]['message']['content'])
        except Exception as e:
            logger.warning(f"Error processing text block: {e}")
            return None

    def _generate_query_prompt(self, text: str) -> str:
        """Generate a prompt instructing the LLM to extract product information."""
        return f"""Extract product specifications from the following text. If no product is found, return an empty JSON object with keys.\n\nText:\n{text}\n\nReturn JSON format exactly as:\n{{\n \"name\": \"product name\",\n \"description\": \"product description\",\n \"price\": numeric_price,\n \"attributes\": {{ \"key\": \"value\" }}\n}}"""

    def _parse_response(self, response: str) -> Optional[ProductSpec]:
        """Parse the LLM's response to extract a product specification.

        Slices out the outermost {...} span, decodes it, and builds a
        ProductSpec; returns None when no usable JSON (or an all-empty
        object) is found.
        """
        try:
            json_start = response.find('{')
            json_end = response.rfind('}') + 1
            json_str = response[json_start:json_end].strip()
            if not json_str:
                raise ValueError("No JSON content found in response.")
            data = json.loads(json_str)
            # If the returned JSON is essentially empty, return None
            if all(not data.get(key) for key in ['name', 'description', 'price', 'attributes']):
                return None
            return ProductSpec(
                name=data.get('name', ''),
                description=data.get('description'),
                price=data.get('price'),
                attributes=data.get('attributes', {})
            )
        except (json.JSONDecodeError, KeyError, ValueError) as e:
            logger.warning(f"Parse error: {e} in response: {response}")
            return None
183
+
184
+
185
def process_pdf_catalog(pdf_path: str):
    """Run the full catalog-extraction pipeline on a single PDF.

    Args:
        pdf_path: Filesystem path to the PDF catalog.

    Returns:
        tuple: ``(result_dict, status_message)``. On success the dict holds
        "products" and "tables"; on failure an empty dict is returned with a
        message that now includes the underlying error (previously the
        detail was logged but dropped from the UI-facing message).
    """
    try:
        # Construct inside the try so model-initialization failures follow
        # the same (result, message) contract instead of propagating.
        processor = PDFProcessor()
        result = processor.process_pdf(pdf_path)
        return result, "Processing completed successfully!"
    except Exception as e:
        logger.error(f"Processing failed: {e}")
        return {}, f"Error processing PDF: {e}"
193
+
194
+
195
if __name__ == "__main__":
    # Smoke-test entry point: replace the placeholder with a real PDF path
    # (or call process_pdf_catalog directly from other code).
    sample_pdf = "path/to/your/pdf_file.pdf"
    extraction, status = process_pdf_catalog(sample_pdf)
    print(extraction, status)