import cv2
import torch
from PIL import Image

# EasyOCR
import easyocr

# TrOCR (Transformer-based OCR from Hugging Face)
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# --- Load TrOCR model and processor once ---
try:
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model.eval()
    use_trocr = True
except Exception as e:
    print(f"TrOCR load failed: {e}")
    use_trocr = False

# --- Load EasyOCR once ---
try:
    reader = easyocr.Reader(['en'])
    use_easyocr = True
except Exception as e:
    print(f"EasyOCR load failed: {e}")
    use_easyocr = False


# --- Preprocess image for EasyOCR ---
def preprocess_image(image):
    """Grayscale, denoise, and adaptively threshold an image for EasyOCR."""
    if len(image.shape) == 3:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    else:
        gray = image
    denoised = cv2.fastNlMeansDenoising(gray, h=10)
    processed = cv2.adaptiveThreshold(
        denoised, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY, 11, 2
    )
    return processed


# --- Preprocess image for TrOCR ---
def preprocess_for_trocr(image_path):
    """Load an image as RGB and resize it before handing it to the TrOCR processor.

    Note: TrOCR is trained on single text-line crops, so whole-page images
    may yield only a short or partial transcription.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, (1280, 720), interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(image)


# --- TrOCR extraction ---
def extract_text_with_trocr(image_path):
    image = preprocess_for_trocr(image_path)
    pixel_values = processor(images=image, return_tensors="pt").pixel_values
    with torch.no_grad():
        generated_ids = trocr_model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return generated_text.strip()


# --- EasyOCR extraction ---
def extract_text_with_easyocr(image_path):
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # EasyOCR accepts numpy arrays directly, so the processed image is passed
    # in memory (no temporary file needed).
    processed_image = preprocess_image(image)
    results = reader.readtext(processed_image)
    text = ' '.join(res[1] for res in results).strip()
    # Fall back to the original image if nothing was recognized
    if not text:
        results = reader.readtext(image_path)
        text = ' '.join(res[1] for res in results).strip()
    return text


# --- Main unified function ---
def extract_text_from_image(image_path):
    """Try extracting handwritten text using TrOCR; fall back to EasyOCR."""
    try:
        if use_trocr:
            print("Using TrOCR...")
            trocr_text = extract_text_with_trocr(image_path)
            if trocr_text and len(trocr_text) > 5:
                return trocr_text
            print("TrOCR output too short. Falling back to EasyOCR...")
        if use_easyocr:
            print("Using EasyOCR...")
            return extract_text_with_easyocr(image_path)
        raise RuntimeError("No OCR backend available.")
    except Exception as e:
        print(f"OCR failed: {e}")
        return "Text extraction failed."
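

# --- Example usage ---
# A minimal sketch of how the unified entry point might be called;
# "sample_note.png" is a hypothetical path, not part of the original script.
if __name__ == "__main__":
    sample_path = "sample_note.png"  # hypothetical input; replace with a real image path
    print(extract_text_from_image(sample_path))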