Ayush kumar committed on
Commit 7fd6367 · 1 Parent(s): 36bd59a

Initial PDF heading extractor app

Files changed (2)
  1. app.py +265 -4
  2. requirements.txt +8 -0
app.py CHANGED
@@ -1,7 +1,268 @@
 import gradio as gr
-
-def greet(name):
-    return "Hello " + name + "!!"
-
-demo = gr.Interface(fn=greet, inputs="text", outputs="text")
-demo.launch()
+import torch
+import json
+from PIL import Image, ImageDraw
+import numpy as np
+from transformers import (
+    LayoutLMv3FeatureExtractor,
+    LayoutLMv3TokenizerFast,
+    LayoutLMv3ForTokenClassification,
+    LayoutLMv3Config
+)
+import pytesseract
+from datasets import load_dataset
+import os
+
+# Set up device
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+print(f"Using device: {device}")
+
+# Constants
+NUM_LABELS = 5  # 0: regular text, 1: title, 2: H1, 3: H2, 4: H3
+
+def create_student_model(num_labels=5):
+    """Create a distilled version of LayoutLMv3"""
+    student_config = LayoutLMv3Config(
+        hidden_size=384,         # vs 768 original
+        num_attention_heads=6,   # vs 12 original
+        intermediate_size=1536,  # vs 3072 original
+        num_hidden_layers=8,     # vs 12 original
+        num_labels=num_labels
+    )
+
+    model = LayoutLMv3ForTokenClassification(student_config)
+    return model
+
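+# Optional sanity check: inspect how much smaller the student is than
+# microsoft/layoutlmv3-base, e.g.
+#   print(sum(p.numel() for p in create_student_model().parameters()), "parameters")
+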
+def load_model():
+    """Load the model and components"""
+    print("Creating model components...")
+
+    # Create feature extractor
+    feature_extractor = LayoutLMv3FeatureExtractor(
+        do_resize=True,
+        size=224,
+        apply_ocr=False,
+        image_mean=[0.5, 0.5, 0.5],
+        image_std=[0.5, 0.5, 0.5]
+    )
+
+    # Create tokenizer (the fast variant, which supports word_ids())
+    tokenizer = LayoutLMv3TokenizerFast.from_pretrained("microsoft/layoutlmv3-base")
+
+    # Create student model
+    model = create_student_model(num_labels=NUM_LABELS)
+    model.to(device)
+
+    # For demo purposes, we'll use random weights
+    # In production, you would load your trained weights here
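+    # e.g. (hypothetical checkpoint path):
+    #   model.load_state_dict(torch.load("student_model.pt", map_location=device))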
+    print("Model components created successfully!")
+
+    return model, feature_extractor, tokenizer
+
+def perform_ocr(image):
+    """Extract text and bounding boxes from image using OCR"""
+    try:
+        # Convert PIL image to numpy array
+        img_array = np.array(image)
+
+        # Get OCR data
+        ocr_data = pytesseract.image_to_data(img_array, output_type=pytesseract.Output.DICT)
+
+        words = []
+        boxes = []
+        confidences = ocr_data['conf']
+
+        for i in range(len(ocr_data['text'])):
+            if int(confidences[i]) > 30:  # Filter low confidence
+                word = ocr_data['text'][i].strip()
+                if word:  # Only add non-empty words
+                    x, y, w, h = (ocr_data['left'][i], ocr_data['top'][i],
+                                  ocr_data['width'][i], ocr_data['height'][i])
+
+                    # Normalize coordinates
+                    img_width, img_height = image.size
+                    normalized_box = [
+                        x / img_width,
+                        y / img_height,
+                        (x + w) / img_width,
+                        (y + h) / img_height
+                    ]
+
+                    words.append(word)
+                    boxes.append(normalized_box)
+
+        return words, boxes
+
+    except Exception as e:
+        print(f"OCR failed: {e}")
+        return ["sample", "text"], [[0, 0, 0.5, 0.1], [0.5, 0, 1.0, 0.1]]
+
+def extract_headings_from_image(image, model, feature_extractor, tokenizer):
+    """Extract headings from uploaded image using the model"""
+    try:
+        # Perform OCR to get words and boxes
+        words, boxes = perform_ocr(image)
+
+        if not words:
+            return {"ERROR": ["No text found in image"]}
+
+        # Prepare inputs for the model
+        # Process image
+        pixel_values = feature_extractor(image, return_tensors="pt")["pixel_values"]
+        pixel_values = pixel_values.to(device)
+
+        # Process text and boxes (limit to first 512 tokens)
+        max_words = min(len(words), 500)  # Leave room for special tokens
+        words = words[:max_words]
+        boxes = boxes[:max_words]
+
+        # Convert boxes to the format expected by LayoutLMv3 (0-1000 scale)
+        scaled_boxes = []
+        for box in boxes:
+            scaled_box = [
+                int(box[0] * 1000),
+                int(box[1] * 1000),
+                int(box[2] * 1000),
+                int(box[3] * 1000)
+            ]
+            scaled_boxes.append(scaled_box)
+
+        # Tokenize
+        encoding = tokenizer(
+            words,
+            boxes=scaled_boxes,
+            max_length=512,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt"
+        )
+
+        # Move to device
+        input_ids = encoding["input_ids"].to(device)
+        attention_mask = encoding["attention_mask"].to(device)
+        bbox = encoding["bbox"].to(device)
+
+        # Run inference
+        with torch.no_grad():
+            outputs = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                bbox=bbox,
+                pixel_values=pixel_values
+            )
+
+        # Get predictions
+        predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()[0]
+
+        # Map predictions back to words
+        word_ids = encoding.word_ids(batch_index=0)
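+        # word_ids() maps each sub-token back to its source word index; it is
+        # only available on fast tokenizers, hence LayoutLMv3TokenizerFast above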
+
+        # Extract headings by label
+        headings = {"TITLE": [], "H1": [], "H2": [], "H3": []}
+        label_map = {0: "TEXT", 1: "TITLE", 2: "H1", 3: "H2", 4: "H3"}
+
+        current_heading = {"text": "", "level": None}
+        last_word_id = None
+
+        for word_id, pred in zip(word_ids, predictions):
+            # Only act on the first sub-token of each word so that words split
+            # into several sub-tokens are not appended more than once
+            if word_id is not None and word_id != last_word_id and word_id < len(words):
+                predicted_label = label_map.get(pred, "TEXT")
+
+                if predicted_label != "TEXT":
+                    if current_heading["level"] == predicted_label:
+                        # Continue building current heading
+                        current_heading["text"] += " " + words[word_id]
+                    else:
+                        # Save previous heading if it exists
+                        if current_heading["text"] and current_heading["level"]:
+                            headings[current_heading["level"]].append(current_heading["text"].strip())
+
+                        # Start new heading
+                        current_heading = {"text": words[word_id], "level": predicted_label}
+                else:
+                    # Save current heading when we hit regular text
+                    if current_heading["text"] and current_heading["level"]:
+                        headings[current_heading["level"]].append(current_heading["text"].strip())
+                    current_heading = {"text": "", "level": None}
+            last_word_id = word_id
+
+        # Save final heading
+        if current_heading["text"] and current_heading["level"]:
+            headings[current_heading["level"]].append(current_heading["text"].strip())
+
+        # Remove empty lists and return
+        headings = {k: v for k, v in headings.items() if v}
+
+        if not headings:
+            return {"INFO": ["No headings detected - this might be a model training issue"]}
+
+        return headings
+
+    except Exception as e:
+        return {"ERROR": [f"Processing failed: {str(e)}"]}
+
+# Load model (this will happen when the Space starts)
+print("Loading model...")
+model, feature_extractor, tokenizer = load_model()
+print("Model loaded successfully!")
+
+def process_document(image):
+    """Main function to process uploaded document"""
+    if image is None:
+        return "Please upload an image"
+
+    print("Processing uploaded image...")
+
+    # Extract headings
+    headings = extract_headings_from_image(image, model, feature_extractor, tokenizer)
+
+    # Format output
+    result = "## Extracted Document Structure:\n\n"
+
+    if "ERROR" in headings:
+        result += f"❌ **Error:** {headings['ERROR'][0]}\n"
+        return result
+
+    if "INFO" in headings:
+        result += f"ℹ️ **Info:** {headings['INFO'][0]}\n"
+        return result
+
+    # Display found headings
+    for level, texts in headings.items():
+        result += f"**{level}:**\n"
+        for text in texts:
+            if level == "TITLE":
+                result += f"# {text}\n"
+            elif level == "H1":
+                result += f"## {text}\n"
+            elif level == "H2":
+                result += f"### {text}\n"
+            elif level == "H3":
+                result += f"#### {text}\n"
+        result += "\n"
+
+    if not any(headings.values()):
+        result += "⚠️ No headings were detected in this image.\n\n"
+        result += "**Possible reasons:**\n"
+        result += "- The model needs training on actual data\n"
+        result += "- The image quality is too low\n"
+        result += "- The document doesn't contain clear headings\n"
+
+    return result
+
+# Create Gradio interface
+demo = gr.Interface(
+    fn=process_document,
+    inputs=gr.Image(type="pil", label="Upload Document Image"),
+    outputs=gr.Markdown(label="Extracted Headings"),
+    title="📄 PDF Heading Extractor",
+    description="""
+    Upload an image of a document to extract its heading hierarchy.
+
+    **Note:** This is a demo version using an untrained model.
+    The actual model would need to be trained on DocLayNet data for accurate results.
+    """,
+    examples=None,
+    allow_flagging="never"
+)
+
+if __name__ == "__main__":
+    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
+transformers==4.36.0
+torch>=1.9.0
+torchvision
+datasets
+gradio
+pillow
+numpy
+scikit-learn
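+# Note: app.py also imports pytesseract, which is missing above; it additionally
+# needs the tesseract-ocr system binary (e.g. via packages.txt on Spaces):
+pytesseract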