# File: answer-evaluation-app/utils/image_processor.py
# Commit a47415e by yeswanthvarma: "Initial FastAPI app for answer evaluation"
# File size: 3.21 kB
import os
import cv2
import torch
from PIL import Image
import numpy as np
# EasyOCR
import easyocr
# TrOCR (Transformer-based OCR from Hugging Face)
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
# --- Load the TrOCR processor and model once, when this module is imported ---
# If anything goes wrong the error is printed and `use_trocr` is set to False,
# which lets the rest of the module skip this backend.
try:
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model = VisionEncoderDecoderModel.from_pretrained(
        "microsoft/trocr-base-handwritten"
    )
    # Switch the model to evaluation mode.
    trocr_model.eval()
except Exception as load_error:
    print(f"TrOCR load failed: {load_error}")
    use_trocr = False
else:
    use_trocr = True
# --- Create the EasyOCR reader once, when this module is imported ---
# The reader is built with the ['en'] language list. If construction fails the
# error is printed and `use_easyocr` is set to False.
try:
    reader = easyocr.Reader(['en'])
except Exception as load_error:
    print(f"EasyOCR load failed: {load_error}")
    use_easyocr = False
else:
    use_easyocr = True
# --- Preprocess image for EasyOCR ---
def preprocess_image(image):
    """Return a denoised, adaptively thresholded grayscale copy of *image*.

    Args:
        image: Image array. A 3-dimensional array is converted to grayscale
            with ``cv2.COLOR_BGR2GRAY``; any other array is used unchanged.

    Returns:
        The output of ``cv2.adaptiveThreshold`` applied to the denoised
        grayscale image.
    """
    has_channel_axis = len(image.shape) == 3
    grayscale = (
        cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if has_channel_axis else image
    )
    smoothed = cv2.fastNlMeansDenoising(grayscale, h=10)
    return cv2.adaptiveThreshold(
        smoothed,
        255,  # value given to pixels that pass the threshold
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        11,  # size of the pixel neighbourhood used for the local threshold
        2,  # constant subtracted from the weighted local mean
    )
# --- Preprocess image for TrOCR ---
def preprocess_for_trocr(image_path, target_size=(1280, 720)):
    """Load an image file and return it as a resized RGB PIL image.

    Args:
        image_path: Path of the image file, loaded with ``cv2.imread``.
        target_size: ``(width, height)`` handed to ``cv2.resize``. Defaults to
            ``(1280, 720)``. The image is stretched to this size; the aspect
            ratio is not preserved.

    Returns:
        A ``PIL.Image.Image`` created from the resized RGB pixel array.

    Raises:
        ValueError: If ``cv2.imread`` could not read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising, which
    # would otherwise surface as an opaque cv2.error inside cvtColor below.
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")
    # OpenCV loads pixels in BGR order; PIL expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = cv2.resize(image, tuple(target_size), interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(image)
# --- TrOCR extraction ---
def extract_text_with_trocr(image_path):
    """Run the module-level TrOCR model on an image file and return its text.

    Args:
        image_path: Path of the image file; it is loaded and prepared by
            ``preprocess_for_trocr``.

    Returns:
        The first decoded sequence, with special tokens skipped and
        surrounding whitespace stripped.
    """
    pil_image = preprocess_for_trocr(image_path)
    encoded = processor(images=pil_image, return_tensors="pt")
    pixel_values = encoded.pixel_values
    # Generation runs with gradient tracking disabled.
    with torch.no_grad():
        token_ids = trocr_model.generate(pixel_values)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0].strip()
# --- EasyOCR extraction ---
def _join_ocr_text(results):
    """Join the text field (index 1) of each EasyOCR result with spaces."""
    return ' '.join(res[1] for res in results).strip()


def extract_text_with_easyocr(image_path):
    """Extract text from an image file with the module-level EasyOCR reader.

    OCR is first run on the output of ``preprocess_image``. If that pass
    yields no text, OCR is retried on the original, unprocessed file.

    Args:
        image_path: Path of the image file, loaded with ``cv2.imread``.

    Returns:
        The recognised text fragments joined by single spaces and stripped.
        This is an empty string when neither pass recognises any text.

    Raises:
        ValueError: If ``cv2.imread`` could not read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None on failure; fail with a clear message instead of
    # an AttributeError from inside preprocess_image.
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")
    processed_image = preprocess_image(image)

    # readtext accepts a numpy array directly, so the processed image is not
    # written to a temporary file next to the source. That avoids leaking the
    # file when OCR raises, name clashes between concurrent calls on the same
    # filename, and failures when the source directory is not writable.
    text = _join_ocr_text(reader.readtext(processed_image))

    # Fallback to original if empty
    if not text:
        text = _join_ocr_text(reader.readtext(image_path))
    return text
# --- Main unified function ---
def extract_text_from_image(image_path):
    """
    Try extracting handwritten text using TrOCR. Fallback to EasyOCR.

    TrOCR is used when it loaded successfully and its stripped output is
    longer than 5 characters. If TrOCR is unavailable, raises an error, or
    returns a shorter result, EasyOCR is tried next when it is available.

    Args:
        image_path: Path of the image file to read text from.

    Returns:
        The extracted text, or the string "Text extraction failed." when no
        backend is available or the last backend tried raised an error. This
        function does not propagate exceptions derived from ``Exception``.
    """
    try:
        if use_trocr:
            print("Using TrOCR...")
            try:
                trocr_text = extract_text_with_trocr(image_path)
            except Exception as trocr_error:
                # A TrOCR runtime error must not skip the EasyOCR fallback.
                print(f"TrOCR failed: {trocr_error}. Falling back to EasyOCR...")
            else:
                if trocr_text and len(trocr_text.strip()) > 5:
                    return trocr_text
                print("TrOCR output too short. Falling back to EasyOCR...")
        if use_easyocr:
            print("Using EasyOCR...")
            return extract_text_with_easyocr(image_path)
        raise RuntimeError("No OCR backend available.")
    except Exception as e:
        print(f"OCR failed: {e}")
        return "Text extraction failed."