import json
from itertools import groupby

import cv2
import matplotlib.pyplot as plt
import torch
import torchvision.transforms.functional as TF
from PIL import Image

# -----------------------------
# Device
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# -----------------------------
# Load vocab
# -----------------------------
def load_vocab(vocab_path):
    """Load the character vocabulary from a UTF-8 JSON file.

    The JSON object must contain the keys ``"char_to_idx"`` (returned as-is)
    and ``"idx_to_char"``.

    Args:
        vocab_path: Path to the vocabulary JSON file.

    Returns:
        A ``(char_to_idx, idx_to_char)`` tuple. ``idx_to_char`` has ``int``
        keys.
    """
    with open(vocab_path, "r", encoding="utf-8") as f:
        vocab = json.load(f)
    char_to_idx = vocab["char_to_idx"]
    # JSON object keys are always strings; the decoder looks up integer indices.
    idx_to_char = {int(k): v for k, v in vocab["idx_to_char"].items()}
    return char_to_idx, idx_to_char


# -----------------------------
# Greedy decoder
# -----------------------------
def greedy_decode(output, idx_to_char, blank=0):
    """Greedy (best-path) decode a batch of per-timestep class scores.

    The function takes the arg-max class at every step and collapses runs of
    identical indices. It then drops the ``blank`` index and maps the remaining
    indices to characters.

    Args:
        output: Score tensor. The arg-max is taken over dim 2 and the result is
            iterated over dim 0, i.e. it is treated as (batch, time, classes).
            NOTE(review): assumes the model emits batch-first output; a
            time-first (T, N, C) tensor would need ``permute(1, 0, 2)`` first.
        idx_to_char: Mapping from class index to character. Indices missing
            from the mapping decode to ``""``.
        blank: Class index of the blank symbol (default ``0``).

    Returns:
        A list with one decoded string per element of dim 0.
    """
    # Transfer the index tensor once; .tolist() yields plain Python ints for lookups.
    best_paths = output.argmax(2).cpu().tolist()
    # groupby collapses consecutive repeats. A blank between two equal indices
    # separates them, so both copies are kept.
    return [
        "".join(idx_to_char.get(idx, "") for idx, _ in groupby(seq) if idx != blank)
        for seq in best_paths
    ]


# -----------------------------
# Transforms
# -----------------------------
class OCRTestTransform:
    """Turn a text-line image into a fixed-size, normalized tensor.

    Steps:
    1. Convert the image to grayscale.
    2. Scale it to ``img_height``, preserving aspect ratio. Wider results are
       squeezed to ``max_width``.
    3. Paste it left-aligned onto a white ``max_width x img_height`` canvas.
    4. Convert to a tensor and normalize from [0, 1] to [-1, 1].
    """

    def __init__(self, img_height=64, max_width=1600):
        self.img_height = img_height
        self.max_width = max_width

    def __call__(self, img):
        """Return the normalized tensor for the PIL image *img*.

        Raises:
            ValueError: If *img* has zero width or height.
        """
        img = img.convert("L")
        w, h = img.size
        if w <= 0 or h <= 0:
            raise ValueError(f"cannot transform an empty image of size {img.size}")
        new_w = int(w * self.img_height / h)
        # Clamp to [1, max_width]: PIL rejects zero-width resizes, and the
        # canvas is only max_width wide.
        new_w = max(1, min(new_w, self.max_width))
        img = img.resize((new_w, self.img_height), Image.BICUBIC)
        canvas = Image.new("L", (self.max_width, self.img_height), 255)
        canvas.paste(img, (0, 0))
        tensor = TF.to_tensor(canvas)
        return TF.normalize(tensor, (0.5,), (0.5,))


transform_test = OCRTestTransform()


# -----------------------------
# Line segmentation
# -----------------------------
def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
    """Split a page image into grayscale text-line crops, ordered top to bottom.

    The page is binarized with Otsu thresholding. It is then dilated with a
    wide, one-pixel-tall kernel so ink on the same row merges into one blob.
    Each external contour's bounding box becomes one crop.

    Args:
        image_path: Path of the page image to read.
        min_line_height: Boxes shorter than this (in pixels) are skipped.
        margin: Extra rows added above and below each box, clipped to the
            image bounds.
        visualize: If true, show each crop with matplotlib.

    Returns:
        A list of PIL images, sorted by the top edge of their bounding box.

    Raises:
        OSError: If OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise OSError(f"could not read image (missing or unsupported): {image_path}")

    # THRESH_BINARY_INV makes ink the non-zero foreground that dilation grows.
    _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Horizontal kernel ~1/30 of the page width; at least 1 px so narrow images work.
    kernel_w = max(1, img.shape[1] // 30)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_w, 1))
    morphed = cv2.dilate(binary, kernel, iterations=1)
    # [-2] is the contour list in both OpenCV 3.x (3 return values) and 4.x (2).
    contours = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    boxes = sorted((cv2.boundingRect(ctr) for ctr in contours), key=lambda box: box[1])

    lines = []
    for x, y, w, h in boxes:
        if h < min_line_height:
            continue
        # Pad vertically by `margin` so glyph edges near the box are not cut off.
        y1 = max(0, y - margin)
        y2 = min(img.shape[0], y + h + margin)
        lines.append(Image.fromarray(img[y1:y2, x:x + w]))

    if visualize:
        for i, line_img in enumerate(lines):
            plt.figure(figsize=(12, 2))
            plt.imshow(line_img, cmap='gray')
            plt.axis('off')
            plt.title(f"Line {i+1}")
            plt.show()

    return lines


# -----------------------------
# OCR function
# -----------------------------
def ocr_page(image_path, model, idx_to_char, visualize=False):
    """OCR a full page image line by line.

    The page is split with ``segment_lines_precise`` and each line goes through
    ``transform_test``. The model runs on each line separately (batch of 1),
    and the output is decoded with ``greedy_decode``. Each prediction is
    printed as it is produced.

    Args:
        image_path: Path of the page image.
        model: Callable mapping a (1, 1, H, W) tensor on ``device`` to
            per-timestep class scores. NOTE(review): the model is assumed to
            already live on ``device``.
        idx_to_char: Index-to-character mapping passed to ``greedy_decode``.
        visualize: Forwarded to ``segment_lines_precise``.

    Returns:
        The decoded lines joined with newlines (``""`` if no lines were found).
    """
    lines = segment_lines_precise(image_path, visualize=visualize)
    all_texts = []

    # Inference must run in eval mode (dropout off, BatchNorm running stats).
    # Switch only if needed, and restore the caller's mode afterwards.
    was_training = getattr(model, "training", False)
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            for idx, line_img in enumerate(lines, 1):
                img_tensor = transform_test(line_img).unsqueeze(0).to(device)
                outputs = model(img_tensor)
                pred_text = greedy_decode(outputs, idx_to_char)[0]
                all_texts.append(pred_text)
                print(f"Line {idx}: {pred_text}")
    finally:
        if was_training:
            model.train()

    return "\n".join(all_texts)