|
import os |
|
import cv2 |
|
import torch |
|
from PIL import Image |
|
import numpy as np |
|
|
|
|
|
import easyocr |
|
|
|
|
|
from transformers import TrOCRProcessor, VisionEncoderDecoderModel |
|
|
|
|
|
# Load the handwritten-text TrOCR checkpoint once, at import time. A failure is
# reported and recorded in ``use_trocr`` so later code can skip this backend
# instead of crashing on import.
try:
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    # Switch the model to evaluation mode; it is only used for inference here.
    trocr_model.eval()
except Exception as load_error:
    print(f"TrOCR load failed: {load_error}")
    use_trocr = False
else:
    use_trocr = True
|
|
|
|
|
# Build the English EasyOCR reader once, at import time. A failure is reported
# and recorded in ``use_easyocr`` so later code can skip this backend.
try:
    reader = easyocr.Reader(['en'])
except Exception as load_error:
    print(f"EasyOCR load failed: {load_error}")
    use_easyocr = False
else:
    use_easyocr = True
|
|
|
|
|
def preprocess_image(image):
    """Denoise and binarize an image.

    Args:
        image: Image array. A three-dimensional array is converted to
            grayscale with ``cv2.COLOR_BGR2GRAY``; any other array is used
            unchanged.

    Returns:
        The result of Gaussian adaptive thresholding applied to the denoised
        grayscale image.
    """
    is_multichannel = len(image.shape) == 3
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if is_multichannel else image

    # Non-local-means denoising with filter strength h=10.
    smoothed = cv2.fastNlMeansDenoising(grayscale, h=10)

    return cv2.adaptiveThreshold(
        smoothed,
        255,  # value assigned to pixels that pass the threshold
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        11,  # neighbourhood (block) size
        2,  # constant subtracted from the weighted mean
    )
|
|
|
|
|
def preprocess_for_trocr(image_path):
    """Load an image from disk and return it as an RGB PIL image for TrOCR.

    Args:
        image_path: Path of the image file to load.

    Returns:
        A PIL image in RGB channel order, resized to 1280x720.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None instead of raising; without
    # this check the colour conversion below fails with an opaque OpenCV error.
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    # NOTE(review): the fixed 1280x720 target ignores the input's aspect
    # ratio - confirm this distortion is intended.
    image = cv2.resize(image, (1280, 720), interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(image)
|
|
|
|
|
def extract_text_with_trocr(image_path):
    """Run TrOCR on the image at ``image_path`` and return the decoded text.

    The image is loaded via ``preprocess_for_trocr``, encoded with the
    module-level ``processor``, and decoded from the ids generated by
    ``trocr_model``. Surrounding whitespace is stripped from the result.
    """
    pil_image = preprocess_for_trocr(image_path)
    encoded = processor(images=pil_image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        token_ids = trocr_model.generate(encoded.pixel_values)
        decoded_batch = processor.batch_decode(token_ids, skip_special_tokens=True)

    # A single image was submitted, so the first entry is its transcription.
    return decoded_batch[0].strip()
|
|
|
|
|
def extract_text_with_easyocr(image_path):
    """Extract text from an image with EasyOCR.

    The image is first denoised and binarized with ``preprocess_image``. If
    that pass yields no text, OCR is retried on the original file.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The recognized text fragments joined by single spaces, stripped.
        Empty string if neither pass recognizes anything.

    Raises:
        ValueError: If OpenCV cannot read or decode the file.
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None on failure rather than raising.
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")

    # readtext accepts an in-memory array, so the preprocessed image is passed
    # directly. This avoids writing a temp file next to the input, which
    # failed in read-only directories, could collide with an existing file,
    # and was left behind whenever OCR raised.
    results = reader.readtext(preprocess_image(image))
    # Each result is indexed at [1] for its recognized text.
    text = ' '.join(res[1] for res in results).strip()

    # Thresholding can erase faint strokes; retry on the untouched original.
    if not text:
        results = reader.readtext(image_path)
        text = ' '.join(res[1] for res in results).strip()

    return text
|
|
|
|
|
def extract_text_from_image(image_path):
    """Extract handwritten text, preferring TrOCR and falling back to EasyOCR.

    TrOCR's result is returned when it is longer than five characters.
    Otherwise, or when TrOCR is unavailable or raises, EasyOCR is used.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The extracted text, or the literal string "Text extraction failed."
        when no backend is available or the fallback backend raises. This
        function reports errors via ``print`` and does not raise them.
    """
    if use_trocr:
        print("Using TrOCR...")
        try:
            trocr_text = extract_text_with_trocr(image_path)
        except Exception as e:
            # A TrOCR error must not prevent the EasyOCR fallback; previously
            # it was caught by the outer handler and ended extraction.
            print(f"TrOCR failed: {e}. Falling back to EasyOCR...")
        else:
            if trocr_text and len(trocr_text.strip()) > 5:
                return trocr_text
            print("TrOCR output too short. Falling back to EasyOCR...")

    if not use_easyocr:
        print("OCR failed: No OCR backend available.")
        return "Text extraction failed."

    print("Using EasyOCR...")
    try:
        return extract_text_with_easyocr(image_path)
    except Exception as e:
        print(f"OCR failed: {e}")
        return "Text extraction failed."
|
|