farbodpya committed on
Commit
1aaa7f3
·
verified ·
1 Parent(s): c1ca773

Upload utils.py

Browse files
Files changed (1) hide show
  1. utils.py +99 -0
utils.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import cv2
3
+ import torch
4
+ import torchvision.transforms.functional as TF
5
+ import matplotlib.pyplot as plt
6
+ from PIL import Image
7
+
8
+ # -----------------------------
9
+ # Device
10
+ # -----------------------------
11
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
12
+
13
+ # -----------------------------
14
+ # Load vocab
15
+ # -----------------------------
16
def load_vocab(vocab_path):
    """Read the character vocabulary stored as JSON at ``vocab_path``.

    The file must hold an object with the keys ``"char_to_idx"`` and
    ``"idx_to_char"``.

    Returns:
        A ``(char_to_idx, idx_to_char)`` tuple. ``char_to_idx`` is returned as
        loaded; ``idx_to_char`` is rebuilt with its keys converted to ``int``
        (JSON object keys are always strings).
    """
    with open(vocab_path, mode="r", encoding="utf-8") as handle:
        payload = json.load(handle)

    char_lookup = payload["char_to_idx"]

    index_lookup = {}
    for raw_index, symbol in payload["idx_to_char"].items():
        index_lookup[int(raw_index)] = symbol

    return char_lookup, index_lookup
22
+
23
+ # -----------------------------
24
+ # Greedy decoder
25
+ # -----------------------------
26
def greedy_decode(output, idx_to_char, blank_idx=0):
    """Greedily decode per-step class scores into strings.

    For every sequence the highest-scoring class is taken at each step,
    consecutive repeats are collapsed, and the blank class is dropped.

    Args:
        output: Score tensor; the arg-max is taken over dimension 2 and the
            result is iterated along dimension 0 (one decoded string per
            entry of dimension 0).
        idx_to_char: Mapping from integer class index to character. Indices
            missing from the mapping contribute an empty string.
        blank_idx: Class index treated as the blank symbol (default ``0``).

    Returns:
        A list of decoded strings, one per sequence.
    """
    best_path = output.argmax(2)
    texts = []
    # tolist() yields plain Python ints (and copies off the GPU if needed),
    # so the dict lookup below does not depend on NumPy scalar types.
    for seq in best_path.tolist():
        prev = -1  # no valid class index is negative, so the first step never matches
        chars = []
        for idx in seq:
            if idx != prev and idx != blank_idx:
                chars.append(idx_to_char.get(idx, ""))
            prev = idx
        texts.append("".join(chars))
    return texts
38
+
39
+ # -----------------------------
40
+ # Transforms
41
+ # -----------------------------
42
class OCRTestTransform:
    """Prepare a text-line image for inference.

    The image is converted to greyscale, scaled to ``img_height`` while
    keeping its aspect ratio (width capped at ``max_width``), pasted at the
    top-left of a white ``max_width`` x ``img_height`` canvas, converted to a
    tensor and normalised with mean 0.5 and std 0.5.

    Args:
        img_height: Height every line image is resized to.
        max_width: Width of the padded canvas and upper bound for the
            resized line width.
    """

    def __init__(self, img_height=64, max_width=1600):
        self.img_height = img_height
        self.max_width = max_width

    def _scaled_width(self, w, h):
        """Return the resize width for a ``w`` x ``h`` image, clamped to [1, max_width]."""
        new_w = int(w * self.img_height / h)
        # A very tall, thin crop rounds down to 0, which is not a valid
        # resize target; keep at least one pixel column.
        return max(1, min(new_w, self.max_width))

    def __call__(self, img):
        """Return the padded, normalised tensor for the PIL image ``img``."""
        img = img.convert("L")
        w, h = img.size
        img = img.resize((self._scaled_width(w, h), self.img_height), Image.BICUBIC)
        # White (255) canvas so the padding to the right of the text is blank.
        canvas = Image.new("L", (self.max_width, self.img_height), 255)
        canvas.paste(img, (0, 0))
        tensor = TF.to_tensor(canvas)
        return TF.normalize(tensor, (0.5,), (0.5,))


transform_test = OCRTestTransform()
58
+
59
+ # -----------------------------
60
+ # Line segmentation
61
+ # -----------------------------
62
def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
    """Split a page image into text-line crops ordered from top to bottom.

    The page is binarised with an inverted Otsu threshold and dilated with a
    wide, one-pixel-tall kernel so that the glyphs of a line merge into one
    blob; each external contour's bounding box then yields one line crop.

    Args:
        image_path: Path of the page image, read as greyscale.
        min_line_height: Bounding boxes shorter than this are discarded.
        margin: Extra rows kept above and below each box, clipped to the page.
        visualize: When true, show every crop in its own matplotlib figure.

    Returns:
        A list of PIL images, one per detected line.

    Raises:
        OSError: If OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise OSError(f"Could not read image: {image_path!r}")
    page_h, page_w = img.shape[:2]

    _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Pages narrower than 30 px would otherwise request a zero-width kernel.
    kernel_w = max(1, page_w // 30)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_w, 1))
    morphed = cv2.dilate(binary, kernel, iterations=1)
    contours, _ = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # Compute each bounding box once, then order the boxes by their top edge.
    boxes = sorted((cv2.boundingRect(ctr) for ctr in contours), key=lambda box: box[1])

    lines = []
    for x, y, w, h in boxes:
        if h < min_line_height:
            continue
        y1 = max(0, y - margin)
        y2 = min(page_h, y + h + margin)
        lines.append(Image.fromarray(img[y1:y2, x:x + w]))

    if visualize:
        for i, line_img in enumerate(lines):
            plt.figure(figsize=(12, 2))
            plt.imshow(line_img, cmap='gray')
            plt.axis('off')
            plt.title(f"Line {i+1}")
            plt.show()
    return lines
85
+
86
+ # -----------------------------
87
+ # OCR function
88
+ # -----------------------------
89
def ocr_page(image_path, model, idx_to_char, visualize=False):
    """Recognise the text on a page image, one line at a time.

    The page is split with ``segment_lines_precise``; each line crop is passed
    through ``transform_test``, moved to ``device``, fed to ``model`` and
    decoded with ``greedy_decode``. Every recognised line is printed as it is
    produced.

    Args:
        image_path: Path of the page image.
        model: Callable applied to a single-item batch of line tensors.
        idx_to_char: Mapping from class index to character, used for decoding.
        visualize: Forwarded to ``segment_lines_precise``.

    Returns:
        The recognised lines joined with newline characters.
    """
    line_images = segment_lines_precise(image_path, visualize=visualize)
    recognised = []
    # Gradients are never needed here, so disable them for the whole loop.
    with torch.no_grad():
        for idx, crop in enumerate(line_images, start=1):
            batch = transform_test(crop).unsqueeze(0).to(device)
            decoded = greedy_decode(model(batch), idx_to_char)
            pred_text = decoded[0]
            recognised.append(pred_text)
            print(f"Line {idx}: {pred_text}")
    return "\n".join(recognised)