import os
import cv2
import torch
from PIL import Image
import numpy as np
# EasyOCR
import easyocr
# TrOCR (Transformer-based OCR from Hugging Face)
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# --- Load TrOCR model and processor once ---
try:
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model.eval()
    use_trocr = True
except Exception as e:
    print(f"TrOCR load failed: {e}")
    use_trocr = False

# --- Load EasyOCR once ---
try:
    reader = easyocr.Reader(['en'])
    use_easyocr = True
except Exception as e:
    print(f"EasyOCR load failed: {e}")
    use_easyocr = False
# --- Preprocess image for EasyOCR ---
def preprocess_image(image):
    # Convert to grayscale if the input is a colour (BGR) image
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image
    # Denoise, then binarise with an adaptive Gaussian threshold
    denoised = cv2.fastNlMeansDenoising(gray, h=10)
    processed = cv2.adaptiveThreshold(denoised, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                                      cv2.THRESH_BINARY, 11, 2)
    return processed
# --- Preprocess image for TrOCR ---
def preprocess_for_trocr(image_path):
    # Read with OpenCV (BGR), convert to RGB, resize, and return a PIL image for the processor
    image = cv2.imread(image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (1280, 720), interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(image)
# --- TrOCR extraction ---
def extract_text_with_trocr(image_path):
    image = preprocess_for_trocr(image_path)
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    with torch.no_grad():
        generated_ids = trocr_model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text.strip()
# --- EasyOCR extraction ---
def extract_text_with_easyocr(image_path):
    image = cv2.imread(image_path)
    processed_image = preprocess_image(image)
    # Run EasyOCR on the preprocessed image via a temporary file
    temp_path = os.path.join(os.path.dirname(image_path), f"temp_{os.path.basename(image_path)}")
    cv2.imwrite(temp_path, processed_image)
    results = reader.readtext(temp_path)
    os.remove(temp_path)
    text = ' '.join([res[1] for res in results]).strip()
    # Fall back to the original image if preprocessing yielded no text
    if not text:
        results = reader.readtext(image_path)
        text = ' '.join([res[1] for res in results]).strip()
    return text
# --- Main unified function ---
def extract_text_from_image(image_path):
    """
    Try extracting handwritten text with TrOCR; fall back to EasyOCR if TrOCR is
    unavailable or returns too little text.
    """
    try:
        if use_trocr:
            print("Using TrOCR...")
            trocr_text = extract_text_with_trocr(image_path)
            if trocr_text and len(trocr_text) > 5:
                return trocr_text
            print("TrOCR output too short. Falling back to EasyOCR...")
        if use_easyocr:
            print("Using EasyOCR...")
            return extract_text_with_easyocr(image_path)
        raise Exception("No OCR backend available.")
    except Exception as e:
        print(f"OCR failed: {e}")
        return "Text extraction failed."