DETR ResNet-50 DC5 for Dhivehi Layout-Aware Document Parsing

A fine-tuned DETR (DEtection TRansformer) model based on facebook/detr-resnet-50-dc5, trained on a custom COCO-style dataset for layout-aware understanding of Dhivehi and similar documents. The model detects key structural elements such as headings, author information, paragraphs, and text lines, with awareness of the document's reading direction (LTR/RTL).

Model Summary

  • Base Model: facebook/detr-resnet-50-dc5
  • Dataset: Custom COCO-format document layout dataset (coco-dv-layout)
  • Categories:
    • layout-analysis-QvA6, author, caption, columns, date, footnote, heading, paragraph, picture, textline
  • Reading Direction Support: Left-to-Right (LTR) and Right-to-Left (RTL) documents
  • Backbone: ResNet-50 DC5

Usage

Inference Script

from transformers import pipeline
from PIL import Image
import torch

# Load the page image that should be analysed.
image = Image.open("ocr.png")

# Use the first CUDA GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Options for the object-detection pipeline built around the layout model.
detector_options = {
    "model": "alakxender/detr-resnet-50-dc5-dv-layout-sm1",
    "device": device,
    "use_fast": True,
}
obj_detector = pipeline("object-detection", **detector_options)

# Run detection and print the raw pipeline output.
results = obj_detector(image)
print(results)

Test Script:

import requests
from transformers import pipeline
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import torch
import argparse
import json
import re

# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument("--threshold", type=float, default=0.6)

# The three direction flags share one destination (``args.rtl``):
#   --rtl             -> True   (also the value when no flag is given)
#   --ltr             -> False
#   --auto-direction  -> None   (the rest of the script treats None as
#                                "guess the direction from the detections")
# Previously only ``--rtl`` existed, as ``store_true`` with ``default=True``,
# so the value could never be anything but True.
direction_group = parser.add_mutually_exclusive_group()
direction_group.add_argument(
    "--rtl", dest="rtl", action="store_true", default=True,
    help="Process as right-to-left language document",
)
direction_group.add_argument(
    "--ltr", dest="rtl", action="store_false", default=True,
    help="Process as left-to-right language document",
)
direction_group.add_argument(
    "--auto-direction", dest="rtl", action="store_const", const=None, default=True,
    help="Guess the reading direction from the detected layout",
)
args = parser.parse_args()

threshold = args.threshold
is_rtl = args.rtl  # True, False, or None (auto-detect)

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device set to use {device}")
if is_rtl is None:
    print("Document direction: auto-detect")
else:
    print(f"Document direction: {'Right-to-Left' if is_rtl else 'Left-to-Right'}")

image = Image.open("ocr-bill.jpeg")

obj_detector = pipeline(
    "object-detection",
    model="alakxender/detr-resnet-50-dc5-dv-layout-sm1",
    device=device,
    use_fast=True  # Set use_fast=True to avoid slow processor warning
)

# Raw detections; the script later reads 'score', 'label' and 'box' from each.
results = obj_detector(image)
print(results)

# RGB colour used for the outline and label text of each detected category.
category_colors = dict(
    author=(0, 255, 0),        # Green
    caption=(0, 0, 255),       # Blue
    columns=(255, 255, 0),     # Yellow
    date=(255, 0, 255),        # Magenta
    footnote=(0, 255, 255),    # Cyan
    heading=(128, 0, 0),       # Dark Red
    paragraph=(0, 128, 0),     # Dark Green
    picture=(0, 0, 128),       # Dark Blue
    textline=(128, 128, 0),    # Olive
)

# Document element hierarchy: labels are listed from highest to lowest
# priority and numbered from 1, so a lower value means a higher priority.
_labels_by_priority = (
    "heading",
    "author",
    "date",
    "columns",
    "paragraph",
    "textline",
    "picture",
    "caption",
    "footnote",
)
element_priority = {
    label: rank for rank, label in enumerate(_labels_by_priority, start=1)
}

def detect_text_direction(results, threshold=0.6, image_width=None):
    """
    Guess whether a document is right-to-left from where its text boxes sit.

    This is a heuristic approach - for production use, consider using
    language detection. Only 'textline', 'paragraph' and 'heading' detections
    whose score is above ``threshold`` are considered. Their average
    horizontal centre is divided by the page width; a value above 0.55
    (text leaning to the right) is reported as RTL.

    Args:
        results: Detections, each a dict with 'score', 'label' and 'box'.
            A box is read through its 'xmin'/'xmax' keys when present;
            otherwise its four values are taken positionally as
            (xmin, ymin, xmax, ymax).
        threshold: Minimum score (exclusive) for a detection to be used.
        image_width: Width of the page in pixels. When omitted, the largest
            right edge among the considered boxes is used as an estimate.

    Returns:
        True if the layout looks right-to-left, otherwise False. False is
        also returned when there is nothing usable to measure (no text
        elements, or a non-positive width).
    """
    text_labels = ('textline', 'paragraph', 'heading')

    # Collect the (left, right) extent of every confident text element.
    spans = []
    for r in results:
        if r['score'] <= threshold or r['label'] not in text_labels:
            continue
        box = r['box']
        if 'xmin' in box and 'xmax' in box:
            # Prefer named keys so the result does not depend on dict order.
            x1, x2 = box['xmin'], box['xmax']
        else:
            values = list(box.values())
            if len(values) != 4:
                continue
            x1, _, x2, _ = values
        spans.append((x1, x2))

    if not spans:
        return False  # Default to LTR if no text elements

    if image_width is None:
        # Without the real page width, approximate it by the rightmost box.
        image_width = max(x2 for _, x2 in spans)
    if image_width <= 0:
        return False  # Degenerate boxes: avoid dividing by zero, default to LTR

    # Calculate the average center position relative to image width
    avg_center_position = sum((x1 + x2) / 2 for x1, x2 in spans) / len(spans)
    relative_position = avg_center_position / image_width

    # If elements tend to be more on the right side, it might be RTL
    is_rtl_detected = relative_position > 0.55  # Slight bias to right side suggests RTL

    print(f"Auto-detected document direction: {'Right-to-Left' if is_rtl_detected else 'Left-to-Right'}")
    print(f"Average element center position: {relative_position:.2f} of document width")

    return is_rtl_detected

def get_reading_order(results, threshold=0.6, rtl=is_rtl):
    """
    Return the detections scoring above ``threshold`` in reading order.

    Ordering rules:
    1. 'heading', 'author' and 'date' elements come first, sorted by
       ``element_priority`` and then by their top edge.
    2. All other elements are grouped into rows: an element joins the first
       existing row whose mean vertical centre is within 20 pixels of its own
       centre, otherwise it starts a new row.
    3. Rows are ordered top to bottom. Inside a row, elements are ordered by
       their leading edge: right edge descending for RTL, left edge ascending
       for LTR.

    Args:
        results: Detections, each a dict with 'score', 'label' and 'box'.
            Boxes keyed as xmin/ymin/xmax/ymax or left/top/width/height are
            read by key; any other box is read positionally as
            (xmin, ymin, xmax, ymax).
        threshold: Minimum score (exclusive) for a detection to be kept.
        rtl: True for right-to-left, False for left-to-right, None to guess
            the direction with ``detect_text_direction``.

    Returns:
        A new list containing the kept detection dicts in reading order.
    """
    # Filter by confidence threshold
    filtered_results = [r for r in results if r['score'] > threshold]

    # If no manual RTL flag is set, try to auto-detect
    if rtl is None:
        rtl = detect_text_direction(results, threshold)

    # Elements whose vertical centres are within ~20 pixels share a row
    y_tolerance = 20

    # Inspect the first box to learn which coordinate keys are in use
    if filtered_results and 'box' in filtered_results[0]:
        box_keys = filtered_results[0]['box'].keys()
        print(f"Box structure keys: {box_keys}")

        if 'ymin' in box_keys:
            y_key, height_key = 'ymin', None
            x_key = 'xmin'
        elif 'top' in box_keys:
            y_key, height_key = 'top', 'height'
            x_key = 'left'
        else:
            print("Unknown box format, defaulting to list unpacking")
            y_key, x_key, height_key = None, None, None
    else:
        print("No box format detected, defaulting to list unpacking")
        y_key, x_key, height_key = None, None, None

    def get_y(element):
        """Top edge of the element's box."""
        if y_key:
            return element['box'][y_key]
        # Unknown format: assume values() yields [xmin, ymin, xmax, ymax]
        return list(element['box'].values())[1]

    def get_x(element):
        """Left edge of the element's box."""
        if x_key:
            return element['box'][x_key]
        # Unknown format: assume values() yields [xmin, ymin, xmax, ymax]
        return list(element['box'].values())[0]

    def get_x_right(element):
        """Right edge of the element's box (where RTL reading starts)."""
        box = element['box']
        if 'xmax' in box:
            return box['xmax']
        if 'left' in box and 'width' in box:
            return box['left'] + box['width']
        box_values = list(box.values())
        if len(box_values) >= 4:
            return box_values[2]  # positional xmax
        return get_x(element)  # fallback

    def get_y_center(element):
        """Vertical centre of the element's box."""
        box = element['box']
        if y_key and height_key:
            return box[y_key] + (box[height_key] / 2)
        if 'ymin' in box and 'ymax' in box:
            # Read by key so the result does not depend on dict order
            return (box['ymin'] + box['ymax']) / 2
        box_values = list(box.values())
        return (box_values[1] + box_values[3]) / 2  # (ymin + ymax) / 2

    # Separate the leading structural elements from ordinary content
    structural_labels = ("heading", "author", "date")
    structural_elements = [r for r in filtered_results if r['label'] in structural_labels]
    content_elements = [r for r in filtered_results if r['label'] not in structural_labels]

    # Sort structural elements by priority first, then by y position
    sorted_structural = sorted(
        structural_elements,
        key=lambda x: (
            element_priority.get(x['label'], 999),
            get_y(x)
        )
    )

    # Group content elements into rows. Each row keeps the running sum of its
    # members' centres so its mean does not have to be recomputed from scratch.
    rows = []
    for element in content_elements:
        y_center = get_y_center(element)
        for row in rows:
            row_y_center = row['center_sum'] / len(row['elements'])
            if abs(y_center - row_y_center) < y_tolerance:
                row['elements'].append(element)
                row['center_sum'] += y_center
                break
        else:
            # No existing row is close enough: start a new one
            rows.append({'elements': [element], 'center_sum': y_center})

    # Order each row by the edge where reading starts
    for row in rows:
        if rtl:
            # RTL: rightmost right edge first
            row['elements'].sort(key=get_x_right, reverse=True)
        else:
            # LTR: leftmost left edge first
            row['elements'].sort(key=get_x)

    # Sort rows by y position (top to bottom)
    rows.sort(key=lambda row: row['center_sum'] / len(row['elements']))

    # Flatten the rows into a single list
    sorted_content = [element for row in rows for element in row['elements']]

    # Combine structural and content elements
    return sorted_structural + sorted_content

def plot_results(image, results, threshold=threshold, save_path='output.jpg', rtl=is_rtl, json_path='results.json'):
    """
    Draw the detections on the image in reading order and export them as JSON.

    Each kept detection gets a coloured rectangle, its reading-order number,
    its label (with an "(RTL)" suffix and a right-to-left arrow for text
    elements in RTL mode) and its score. The annotated image is written to
    ``save_path`` and a JSON summary to ``json_path``.

    Args:
        image: A PIL image, or array data accepted by ``Image.fromarray``.
            A PIL image is drawn on directly, so the caller's object is
            modified.
        results: Detections, each a dict with 'score', 'label' and 'box'.
        threshold: Minimum score (exclusive) for a detection to be drawn.
        save_path: Destination of the annotated image.
        rtl: True for right-to-left, False for left-to-right, None to guess
            the direction with ``detect_text_direction``.
        json_path: Destination of the JSON summary.

    Returns:
        The annotated PIL image.
    """
    # Convert image to appropriate format if it's not already a PIL Image
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.uint8(image))

    draw = ImageDraw.Draw(image)

    # If rtl is None (not explicitly specified), try to auto-detect
    if rtl is None:
        rtl = detect_text_direction(results, threshold)

    # Get results in reading order
    ordered_results = get_reading_order(results, threshold, rtl)

    # Labels that carry running text and therefore a reading direction
    text_labels = ('textline', 'paragraph', 'heading')

    formatted_results = []

    for i, result in enumerate(ordered_results):
        label = result['label']
        box = list(result['box'].values())
        score = result['score']

        # Make sure box has exactly 4 values
        if len(box) != 4:
            print(f"Warning: Unexpected box format for {label}: {box}")
            continue
        x1, y1, x2, y2 = box

        color = category_colors.get(label, (255, 255, 255))  # Default to white if label not found
        is_rtl_text = rtl and label in text_labels

        # Draw bounding box
        draw.rectangle((x1, y1, x2, y2), outline=color, width=2)

        # Add order number to visualize the reading sequence
        draw.text((x1 + 5, y1 - 20), f'#{i+1}', fill=(255, 255, 255))

        if is_rtl_text:
            draw.text((x1 + 5, y1 - 10), f'{label} (RTL)', fill=color)
            # Draw arrow showing reading direction (right to left)
            arrow_y = y1 - 5
            draw.line([(x2 - 20, arrow_y), (x1 + 20, arrow_y)], fill=color, width=1)
            draw.polygon([(x1 + 20, arrow_y - 3), (x1 + 20, arrow_y + 3), (x1 + 15, arrow_y)], fill=color)
        else:
            draw.text((x1 + 5, y1 - 10), label, fill=color)

        draw.text((x1 + 5, y1 + 10), f'{score:.2f}', fill='green' if score > 0.7 else 'red')

        # Add result to formatted list with order index
        formatted_results.append({
            "order_index": i,
            "label": label,
            "is_rtl": rtl if label in text_labels else False,
            "score": float(score),
            "bbox": {
                "x1": float(x1),
                "y1": float(y1),
                "x2": float(x2),
                "y2": float(y2)
            }
        })

    image.save(save_path)

    # Save results to JSON file with RTL information
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump({
            "document_direction": "rtl" if rtl else "ltr",
            "elements": formatted_results
        }, f, indent=2)

    return image

if results:  # Only plot if there are results
    # Fall back to auto-detection when no reading direction was supplied
    # (``args.rtl`` missing or None). Pass the user's threshold so the guess
    # is based on the same detections that will be plotted.
    if getattr(args, 'rtl', None) is None:
        is_rtl = detect_text_direction(results, threshold)

    plot_results(image, results, rtl=is_rtl)
    print(f"Processing complete. Document interpreted as {'RTL' if is_rtl else 'LTR'}")
else:
    print("No objects detected in the image")

Output Example

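  • Visual Output: Annotated image (saved as output.jpg by default) showing each bounding box with its label, confidence score, and reading-order number
  • JSON Output (written to results.json):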
{
  "document_direction": "rtl",
  "elements": [
    {
      "order_index": 0,
      "label": "heading",
      "is_rtl": true,
      "score": 0.97,
      "bbox": {
        "x1": 120.5,
        "y1": 65.2,
        "x2": 620.4,
        "y2": 120.7
      }
    }
  ]
}

Training Summary

  • Training script: Uses the Hugging Face Trainer API
  • Eval Strategy: Evaluated at step intervals, using MeanAveragePrecision from torchmetrics

Downloads last month
37
Safetensors
Model size
41.6M params
Tensor type
F32
·
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support

Model tree for alakxender/detr-resnet-50-dc5-dv-layout-sm1

Finetuned
(41)
this model

Dataset used to train alakxender/detr-resnet-50-dc5-dv-layout-sm1