Persian_OCR / README.md
farbodpya's picture
Update README.md
e0399f1 verified
|
raw
history blame
7.36 kB
---
license: apache-2.0
language:
  - fa
pipeline_tag: image-to-text
widget:
  - src: >-
      https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/papers/attention.png
    example_title: Persian OCR
---

# Persian-OCR

Persian-OCR is a deep learning model for Optical Character Recognition (OCR) designed specifically for Persian text. The model uses a CNN + Transformer architecture trained with CTC loss to extract text from images.

## Files

- `pytorch_model.bin`: PyTorch model weights
- `vocab.json`: character vocabulary
- `config.json`: model configuration (not read by the example below, which hard-codes the architecture hyperparameters)

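## Installation

```bash
pip install torch torchvision huggingface_hub opencv-python matplotlib pillow
```

In a Jupyter or Colab notebook, prefix the command with `!` (e.g. `!pip install ...`).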



## Usage Example


import json
import torch
import torch.nn as nn
from PIL import Image
import torchvision.transforms.functional as TF
import cv2
import matplotlib.pyplot as plt
from huggingface_hub import hf_hub_download

# -----------------------------
# 1️⃣ Device
# -----------------------------
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# -----------------------------
# 2️⃣ Load vocab
# -----------------------------
vocab_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="vocab.json")
with open(vocab_path, encoding="utf-8") as vocab_file:
    vocab = json.load(vocab_file)
# Character -> class-index mapping, taken as-is from vocab.json.
char_to_idx = vocab["char_to_idx"]
# JSON object keys are always strings; turn them back into ints so the
# decoder can look characters up by predicted class index.
idx_to_char = {int(key): char for key, char in vocab["idx_to_char"].items()}

# -----------------------------
# 3️⃣ Model definition
# -----------------------------
def GN(c, groups=16):
    """Build an ``nn.GroupNorm`` layer for ``c`` channels.

    Uses ``groups`` groups, or ``c`` groups (one channel per group) when
    ``c`` is smaller than ``groups``. Note that ``nn.GroupNorm`` requires
    ``c`` to be divisible by the resulting group count.
    """
    num_groups = min(groups, c)
    return nn.GroupNorm(num_groups, c)

class LightResNetCNN(nn.Module):
    """Six-stage convolutional feature extractor for text-line images.

    Each stage is Conv2d(3x3, stride 1, padding 1) -> GroupNorm -> ReLU.
    The first three stages also end with a 2x2 max-pool, so height and width
    are each floor-halved three times. A final adaptive average pool fixes
    the height at ``adaptive_height`` rows and leaves the width unchanged.
    The output has 128 channels.

    The attribute names ``layer1`` .. ``layer6`` and ``adaptive_pool`` must
    not change, because pretrained ``state_dict`` keys are derived from them.
    """

    def __init__(self, in_channels=1, adaptive_height=8):
        super().__init__()
        self.adaptive_height = adaptive_height

        def stage(c_in, c_out, downsample):
            # Modules are created in the same order for every stage so that
            # parameter initialisation consumes the RNG identically.
            parts = [nn.Conv2d(c_in, c_out, 3, 1, 1), GN(c_out), nn.ReLU()]
            if downsample:
                parts.append(nn.MaxPool2d(2, 2))
            return nn.Sequential(*parts)

        self.layer1 = stage(in_channels, 32, True)
        self.layer2 = stage(32, 64, True)
        self.layer3 = stage(64, 128, True)
        self.layer4 = stage(128, 256, False)
        self.layer5 = stage(256, 256, False)
        self.layer6 = stage(256, 128, False)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((self.adaptive_height, None))

    def forward(self, x):
        """Run all six stages, then pool the height down to ``adaptive_height``."""
        for block in (self.layer1, self.layer2, self.layer3,
                      self.layer4, self.layer5, self.layer6):
            x = block(x)
        return self.adaptive_pool(x)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch of sequences.

    The table is precomputed for ``max_len`` positions and registered as the
    buffer ``pe`` with shape (1, max_len, d_model). Being a buffer, it is
    saved and loaded with ``state_dict`` and moved by ``.to(device)``.
    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)``, with ``w_i = 10000 ** (-2i / d_model)``.

    Args:
        d_model: Feature size of the vectors the encoding is added to. Both
            even and odd values are supported.
        max_len: Longest sequence length ``forward`` accepts.
    """

    def __init__(self, d_model, max_len=2000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-torch.log(torch.tensor(10000.0)) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even column,
        # so trim the frequencies to fit. This is a no-op for even d_model,
        # which keeps the buffer identical to the pretrained checkpoint.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions.

        ``x`` is indexed as (batch, seq_len, d_model).

        Raises:
            ValueError: if ``seq_len`` exceeds ``max_len``.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds PositionalEncoding max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :seq_len, :]

class CNN_Transformer_OCR(nn.Module):
    """CNN backbone plus Transformer encoder that emits per-column log-probabilities.

    The single-channel input image goes through ``LightResNetCNN``. Each
    column of the resulting feature map (128 channels x 8 rows) becomes one
    token. Tokens are projected to ``d_model`` and given sinusoidal position
    encodings, then passed through a stack of Transformer encoder layers.
    A final linear layer scores ``num_classes`` classes per token.

    ``forward`` returns log-softmax scores of shape (B, W, num_classes),
    where W is the CNN output width. Per the model card these scores are
    meant for CTC decoding.

    Submodule attribute names (``cnn``, ``proj``, ``posenc``,
    ``transformer``, ``fc``) determine the pretrained ``state_dict`` keys.
    """

    def __init__(self, num_classes, d_model=1280, nhead=16, num_layers=8, dropout=0.2):
        super().__init__()
        feat_height = 8
        self.cnn = LightResNetCNN(in_channels=1, adaptive_height=feat_height)
        # Each column token is the CNN's 128 channels stacked over feat_height rows.
        self.proj = nn.Linear(128 * feat_height, d_model)
        self.posenc = PositionalEncoding(d_model)
        layer = nn.TransformerEncoderLayer(d_model, nhead, batch_first=True, dropout=dropout)
        self.transformer = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x):
        """Map an image batch to (B, W, num_classes) log-probabilities."""
        feats = self.cnn(x)  # (B, C, H, W)
        # (B, C, H, W) -> (B, W, C, H) -> (B, W, C*H): one token per column.
        tokens = feats.permute(0, 3, 1, 2).flatten(2)
        encoded = self.transformer(self.posenc(self.proj(tokens)))
        scores = self.fc(encoded)
        return torch.log_softmax(scores, dim=2)

# -----------------------------
# 4️⃣ Load model weights
# -----------------------------
model_path = hf_hub_download(repo_id="farbodpya/Persian-OCR", filename="pytorch_model.bin")
# One extra output class beyond the vocabulary: the CTC blank, which
# greedy_decode treats as index 0 and drops.
model = CNN_Transformer_OCR(num_classes=len(idx_to_char)+1).to(device)
# torch.load unpickles the file. weights_only=True limits unpickling to
# tensors and plain containers, so a tampered download cannot execute
# arbitrary code. (Requires torch >= 1.13; it is the default from torch 2.6.)
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference mode: disables dropout in the Transformer layers

# -----------------------------
# 5️⃣ Greedy decoder
# -----------------------------
def greedy_decode(output, idx_to_char):
    """Best-path CTC decoding of a batch of per-timestep class scores.

    For each sequence, take the arg-max class at every timestep. Consecutive
    repeats are collapsed and class 0 (the blank) is dropped. Each remaining
    index is mapped through ``idx_to_char``; indices missing from the mapping
    contribute an empty string.

    Args:
        output: tensor of shape (batch, time, num_classes).
        idx_to_char: dict mapping int class index -> character.

    Returns:
        list[str]: one decoded string per batch item.
    """
    best_paths = output.argmax(2)
    decoded = []
    for path in best_paths:
        tokens = path.tolist()
        # Pair every token with its predecessor; -1 precedes the first step.
        previous = [-1] + tokens[:-1]
        pieces = [
            idx_to_char.get(tok, "")
            for tok, before in zip(tokens, previous)
            if tok != before and tok != 0
        ]
        decoded.append("".join(pieces))
    return decoded

# -----------------------------
# 6️⃣ Transforms
# -----------------------------
class OCRTestTransform:
    """Turn a PIL line image into the fixed-size, normalized tensor the model expects.

    The image is converted to grayscale ("L") and resized to ``img_height``
    pixels tall, keeping its aspect ratio. The width is capped at
    ``max_width``; wider images are squeezed horizontally. The result is
    pasted at the left edge of a white ``max_width`` x ``img_height`` canvas
    and normalized with mean 0.5 / std 0.5.
    """

    def __init__(self, img_height=64, max_width=1600):
        self.img_height = img_height
        self.max_width = max_width

    def __call__(self, img):
        """Return a float tensor of shape (1, img_height, max_width).

        Raises:
            ValueError: if ``img`` has zero width or height.
        """
        img = img.convert("L")
        w, h = img.size
        if w == 0 or h == 0:
            raise ValueError(f"cannot transform an empty image of size {w}x{h}")
        # Keep at least 1 px of width: very tall, narrow crops would otherwise
        # round down to a zero-width resize, which PIL rejects.
        new_w = max(1, min(int(w * self.img_height / h), self.max_width))
        img = img.resize((new_w, self.img_height), Image.BICUBIC)
        canvas = Image.new("L", (self.max_width, self.img_height), 255)
        canvas.paste(img, (0, 0))
        tensor = TF.to_tensor(canvas)
        return TF.normalize(tensor, (0.5,), (0.5,))

transform_test = OCRTestTransform()

# -----------------------------
# 7️⃣ Line segmentation
# -----------------------------
def segment_lines_precise(image_path, min_line_height=12, margin=6, visualize=False):
    """Split a page image into grayscale crops of individual text lines.

    The page is Otsu-binarized with ink as foreground, then dilated with a
    wide, 1-pixel-tall kernel so that the words on one line merge into one
    region. External contours of these regions give the line boxes. Boxes
    shorter than ``min_line_height`` are discarded. The rest are padded
    vertically by ``margin`` pixels (clamped to the image) and cropped from
    the original grayscale page.

    Args:
        image_path: path readable by ``cv2.imread``.
        min_line_height: minimum box height in pixels to keep.
        margin: vertical padding added above and below each box.
        visualize: if True, show each crop with matplotlib.

    Returns:
        list of PIL images, ordered by the top edge of each box.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"could not read image: {image_path}")
    _, binary = cv2.threshold(img, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
    # Kernel is 1/30 of the page width wide and 1 px tall; keep it at least
    # 1 px wide for images narrower than 30 px.
    kernel_w = max(1, img.shape[1] // 30)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (kernel_w, 1))
    morphed = cv2.dilate(binary, kernel, iterations=1)
    # findContours returns (contours, hierarchy) in OpenCV 4 but
    # (image, contours, hierarchy) in OpenCV 3; [-2] selects contours in both.
    contours = cv2.findContours(morphed, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    # Compute each bounding box once, then sort top-to-bottom by its y.
    boxes = sorted((cv2.boundingRect(ctr) for ctr in contours), key=lambda box: box[1])
    lines = []
    for x, y, w, h in boxes:
        if h < min_line_height:
            continue
        y1 = max(0, y - margin)
        y2 = min(img.shape[0], y + h + margin)
        lines.append(Image.fromarray(img[y1:y2, x:x + w]))
    if visualize:
        for i, line_img in enumerate(lines):
            plt.figure(figsize=(12, 2))
            plt.imshow(line_img, cmap='gray')
            plt.axis('off')
            plt.title(f"Line {i+1}")
            plt.show()
    return lines

# -----------------------------
# 8️⃣ OCR function
# -----------------------------
def ocr_page(image_path, visualize=False):
    """Segment a page into lines, recognize each one, and return the text.

    Uses the module-level ``model``, ``device``, ``transform_test`` and
    ``idx_to_char``. Each line is printed as soon as it is decoded. The
    returned string joins the lines with newlines, in the order produced
    by ``segment_lines_precise``.
    """
    recognized = []
    crops = segment_lines_precise(image_path, visualize=visualize)
    for line_no, crop in enumerate(crops, start=1):
        batch = transform_test(crop).unsqueeze(0).to(device)  # batch of one
        with torch.no_grad():
            log_probs = model(batch)
        text = greedy_decode(log_probs, idx_to_char)[0]
        recognized.append(text)
        print(f"Line {line_no}: {text}")
    return "\n".join(recognized)

# -----------------------------
# 9️⃣ Example usage
# -----------------------------
# Image to recognize. "/content/..." is presumably a notebook (e.g. Colab)
# upload location; replace it with the path to your own page or line image.
img_path = "/content/farsi_line.png"  # put your own image path here
# visualize=True shows every segmented line crop before OCR runs.
final_text = ocr_page(img_path, visualize=True)
print("\n=== Final OCR Page ===\n", final_text)