fsommers committed
Commit 2493114
1 Parent(s): fcceed0

Initial version

Files changed (3)
  1. app.py +132 -0
  2. packages.txt +2 -0
  3. requirements.txt +8 -0
app.py ADDED
@@ -0,0 +1,132 @@
+ import io
+ from typing import List
+
+ import numpy as np
+ import pandas as pd
+ import plotly.express as px
+ import pytesseract
+ import streamlit as st
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from transformers import LayoutLMv3FeatureExtractor, LayoutLMv3TokenizerFast, LayoutLMv3Processor, LayoutLMv3ForSequenceClassification
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ TOKENIZER = "microsoft/layoutlmv3-base"
+ MODEL_NAME = "fsommers/layoutlmv3-autofinance-classification-us-v01"
+
+ TESS_OPTIONS = "--psm 3"  # Automatic page segmentation for Tesseract
+
+ @st.cache_resource
+ def create_ocr_reader():
+     def scale_bounding_box(box: List[int], w_scale: float = 1.0, h_scale: float = 1.0):
+         # Scale an [x0, y0, x1, y1] box into LayoutLMv3's 0-1000 coordinate space.
+         return [
+             int(box[0] * w_scale),
+             int(box[1] * h_scale),
+             int(box[2] * w_scale),
+             int(box[3] * h_scale)
+         ]
+
+     def ocr_page(image) -> dict:
+         """
+         OCR a given image. Return a dictionary with the recognized words
+         and a bounding box for each word.
+         """
+         ocr_df = pytesseract.image_to_data(image, output_type='data.frame', config=TESS_OPTIONS)
+         ocr_df = ocr_df.dropna().reset_index(drop=True)
+         float_cols = ocr_df.select_dtypes('float').columns
+         ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
+         # Drop rows whose recognized text is empty or whitespace-only.
+         ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
+         ocr_df = ocr_df.dropna().reset_index(drop=True)
+
+         words = [str(w) for w in ocr_df.text]
+
+         # Tesseract reports (left, top, width, height); convert to corner coordinates.
+         coordinates = ocr_df[['left', 'top', 'width', 'height']]
+         boxes = []
+         for _, row in coordinates.iterrows():
+             x, y, w, h = tuple(row)
+             boxes.append([x, y, x + w, y + h])
+
+         assert len(words) == len(boxes)
+         return {"bbox": boxes, "words": words}
+
+     def prepare_image(image):
+         ocr_data = ocr_page(image)
+         # LayoutLMv3 expects box coordinates normalized to the 0-1000 range.
+         width, height = image.size
+         width_scale = 1000 / width
+         height_scale = 1000 / height
+         words = []
+         boxes = []
+         for w, b in zip(ocr_data["words"], ocr_data["bbox"]):
+             words.append(w)
+             boxes.append(scale_bounding_box(b, width_scale, height_scale))
+
+         assert len(words) == len(boxes)
+         for box in boxes:
+             for coord in box:
+                 if coord > 1000:
+                     raise ValueError(f"Bounding box coordinate {coord} exceeds 1000")
+         return words, boxes
+
+     return prepare_image
+
+ @st.cache_resource
+ def create_model():
+     model = LayoutLMv3ForSequenceClassification.from_pretrained(MODEL_NAME)
+     return model.eval().to(DEVICE)
+
+ @st.cache_resource
+ def create_processor():
+     # OCR is handled by Tesseract above, so disable the processor's built-in OCR.
+     feature_extractor = LayoutLMv3FeatureExtractor(apply_ocr=False)
+     tokenizer = LayoutLMv3TokenizerFast.from_pretrained(TOKENIZER)
+     return LayoutLMv3Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)
+
+ def predict(image, reader, processor: LayoutLMv3Processor, model: LayoutLMv3ForSequenceClassification):
+     words, boxes = reader(image)
+     encoding = processor(
+         image,
+         words,
+         boxes=boxes,
+         max_length=512,
+         padding="max_length",
+         truncation=True,
+         return_tensors="pt"
+     )
+     with torch.inference_mode():
+         output = model(
+             input_ids=encoding["input_ids"].to(DEVICE),
+             attention_mask=encoding["attention_mask"].to(DEVICE),
+             bbox=encoding["bbox"].to(DEVICE),
+             pixel_values=encoding["pixel_values"].to(DEVICE)
+         )
+     logits = output.logits
+     predicted_class = logits.argmax(dim=-1).item()
+     probabilities = F.softmax(logits, dim=-1).flatten().tolist()
+     return predicted_class, probabilities
+
+ reader = create_ocr_reader()
+ processor = create_processor()
+ model = create_model()
+
+ uploaded_file = st.file_uploader("Choose a JPG or PNG file", ["jpg", "png"])
+ if uploaded_file is not None:
+     bytes_data = io.BytesIO(uploaded_file.read())
+     # Convert to RGB so PNGs with an alpha channel work with the feature extractor.
+     image = Image.open(bytes_data).convert("RGB")
+     st.image(image, caption="Uploaded Image", use_column_width=True)
+     predicted, probabilities = predict(image, reader, processor, model)
+     predicted_label = model.config.id2label[predicted]
+     st.markdown(f"Predicted Label: {predicted_label}")
+
+     df = pd.DataFrame({
+         "Label": list(model.config.id2label.values()),
+         "Probability": probabilities
+     })
+     fig = px.bar(df, x="Label", y="Probability")
+     st.plotly_chart(fig, use_container_width=True)
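For context, the pytesseract.image_to_data call in ocr_page returns one row per item Tesseract detects, with pixel coordinates in left/top/width/height columns; app.py keeps the word-level rows and converts each to [x0, y0, x1, y1] corners before scaling into LayoutLMv3's 0-1000 box space. A minimal sketch of inspecting that frame (the sample.jpg path is a hypothetical placeholder, not part of the commit):

import pytesseract
from PIL import Image

image = Image.open("sample.jpg")  # hypothetical sample document image
df = pytesseract.image_to_data(image, output_type="data.frame", config="--psm 3")
# Word-level rows carry the text plus left/top/width/height in pixels and a
# confidence score; rows with NaN text are page/block/line records, which
# app.py drops via dropna().
print(df[["text", "left", "top", "width", "height", "conf"]].dropna().head())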
packages.txt ADDED
@@ -0,0 +1,2 @@
+ tesseract-ocr
+ tesseract-ocr-eng-best
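These are system packages the Space installs at build time: tesseract-ocr provides the tesseract binary that pytesseract shells out to, and tesseract-ocr-eng-best, going by its name, supplies the higher-accuracy English language data.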
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ pandas==2.2.2
+ huggingface-hub==0.23.0
+ Pillow==10.3.0
+ plotly-express==0.4.1
+ PyMuPDF==1.24.3
+ pytesseract==0.3.10
+ torch==2.2.2
+ transformers==4.40.2
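As a smoke test outside Streamlit, the same checkpoints can be exercised directly. The sketch below is not part of the commit and, unlike app.py, lets the processor run its own built-in Tesseract OCR instead of the manual pytesseract pass; sample.jpg is a hypothetical test image, and the Tesseract packages above must be installed:

import torch
from PIL import Image
from transformers import LayoutLMv3Processor, LayoutLMv3ForSequenceClassification

# Same checkpoints that app.py pins.
processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base")
model = LayoutLMv3ForSequenceClassification.from_pretrained(
    "fsommers/layoutlmv3-autofinance-classification-us-v01"
).eval()

image = Image.open("sample.jpg").convert("RGB")  # hypothetical test image
# With the processor's default apply_ocr=True, words and boxes are extracted
# internally, so only the image needs to be passed.
encoding = processor(image, max_length=512, padding="max_length",
                     truncation=True, return_tensors="pt")
with torch.inference_mode():
    logits = model(**encoding).logits
print(model.config.id2label[logits.argmax(-1).item()])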