File size: 3,205 Bytes
a47415e
5888c6a
a47415e
 
5888c6a
a47415e
 
5888c6a
 
a47415e
 
 
 
 
 
 
 
 
 
 
 
 
 
5888c6a
 
a47415e
5888c6a
a47415e
 
5888c6a
a47415e
5888c6a
 
 
 
 
a47415e
5888c6a
a47415e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5888c6a
a47415e
 
 
 
5888c6a
a47415e
5888c6a
a47415e
5888c6a
 
a47415e
5888c6a
 
a47415e
 
 
 
 
 
 
 
 
 
 
 
 
5888c6a
a47415e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import cv2
import torch
from PIL import Image
import numpy as np

# EasyOCR
import easyocr

# TrOCR (Transformer-based OCR from Hugging Face)
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# --- One-time TrOCR setup: processor and model are shared by every call ---
# `use_trocr` records whether loading succeeded so the pipeline can skip
# this backend instead of crashing at import time.
try:
    processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
    trocr_model.eval()  # switch the model to evaluation mode
except Exception as load_error:
    print(f"TrOCR load failed: {load_error}")
    use_trocr = False
else:
    use_trocr = True

# --- One-time EasyOCR setup: a single English reader shared by every call ---
# `use_easyocr` records whether the reader could be constructed.
try:
    reader = easyocr.Reader(['en'])
except Exception as load_error:
    print(f"EasyOCR load failed: {load_error}")
    use_easyocr = False
else:
    use_easyocr = True

# --- Preprocess image for EasyOCR ---
def preprocess_image(image):
    """Return a denoised, binarized version of ``image`` for EasyOCR.

    A 3-dimensional array is converted to grayscale with a BGR->gray
    conversion; any other array is used as-is. The grayscale image is then
    denoised with non-local means (``h=10``) and binarized with a Gaussian
    adaptive threshold (max value 255, block size 11, constant 2).
    """
    has_channel_axis = len(image.shape) == 3
    grayscale = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) if has_channel_axis else image

    smoothed = cv2.fastNlMeansDenoising(grayscale, h=10)
    return cv2.adaptiveThreshold(
        smoothed,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        11,
        2,
    )

# --- Preprocess image for TrOCR ---
def preprocess_for_trocr(image_path, target_size=(1280, 720)):
    """Load an image from disk and return it as an RGB PIL image for TrOCR.

    Args:
        image_path: Path of the image file to read.
        target_size: ``(width, height)`` the image is resized to with linear
            interpolation. Pass ``None`` to keep the original size. The
            default matches the previously hard-coded 1280x720.

    Returns:
        A ``PIL.Image.Image`` built from the RGB pixel data.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising;
    # fail here with a clear message instead of an opaque cvtColor error.
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")

    # OpenCV loads pixels as BGR; PIL expects RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if target_size is not None:
        image = cv2.resize(image, target_size, interpolation=cv2.INTER_LINEAR)
    return Image.fromarray(image)

# --- TrOCR extraction ---
def extract_text_with_trocr(image_path):
    """Run the module-level TrOCR model on the image at ``image_path``.

    The image is prepared by ``preprocess_for_trocr``, encoded by the shared
    ``processor``, and decoded from the ids produced by
    ``trocr_model.generate``. Returns the first decoded string with
    surrounding whitespace stripped.
    """
    pil_image = preprocess_for_trocr(image_path)
    encoded = processor(images=pil_image, return_tensors="pt")
    model_input = encoded.pixel_values

    # Inference only, so gradient tracking is disabled.
    with torch.no_grad():
        token_ids = trocr_model.generate(model_input)
        decoded_batch = processor.batch_decode(token_ids, skip_special_tokens=True)

    return decoded_batch[0].strip()

# --- EasyOCR extraction ---
def extract_text_with_easyocr(image_path):
    """Extract text from the image at ``image_path`` using EasyOCR.

    The image is first binarized with ``preprocess_image`` and passed to the
    shared ``reader`` as an in-memory array. If that yields no text, OCR is
    retried on the original, unprocessed file.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The recognized text fragments joined by single spaces, stripped.
        May be an empty string if nothing is recognized in either pass.

    Raises:
        ValueError: If OpenCV cannot read ``image_path``.
    """
    image = cv2.imread(image_path)
    # cv2.imread returns None (no exception) when the file cannot be read.
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")

    # readtext accepts a numpy array, so the processed image is handed over
    # in memory. This avoids writing a temp file next to the input, which
    # could fail on read-only directories, leak if OCR raised, collide
    # between concurrent calls, and degrade the image when re-encoded lossily.
    processed_image = preprocess_image(image)
    results = reader.readtext(processed_image)
    # Each result is indexed as (bbox, text, confidence); keep only the text.
    text = ' '.join(res[1] for res in results).strip()

    # Fallback to the original image if preprocessing wiped out the text.
    if not text:
        results = reader.readtext(image_path)
        text = ' '.join(res[1] for res in results).strip()

    return text

# --- Main unified function ---
def extract_text_from_image(image_path):
    """Extract handwritten text with TrOCR, falling back to EasyOCR.

    TrOCR is tried first when it loaded successfully. Its result is returned
    only if it is longer than 5 characters after stripping. If TrOCR raises
    or returns too little text, EasyOCR is tried next when available.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The extracted text, or the string ``"Text extraction failed."`` when
        no backend is available or the final backend raises. OCR errors are
        printed rather than propagated.
    """
    if use_trocr:
        print("Using TrOCR...")
        try:
            trocr_text = extract_text_with_trocr(image_path)
        except Exception as e:
            # A TrOCR failure must not skip the EasyOCR fallback.
            print(f"TrOCR failed: {e}. Falling back to EasyOCR...")
        else:
            if trocr_text and len(trocr_text.strip()) > 5:
                return trocr_text
            print("TrOCR output too short. Falling back to EasyOCR...")

    if not use_easyocr:
        print("OCR failed: No OCR backend available.")
        return "Text extraction failed."

    print("Using EasyOCR...")
    try:
        return extract_text_with_easyocr(image_path)
    except Exception as e:
        # Top-level boundary: report the error and return the sentinel string.
        print(f"OCR failed: {e}")
        return "Text extraction failed."