# DETR ResNet-50 DC5 for Dhivehi Layout-Aware Document Parsing

A fine-tuned DETR (DEtection TRansformer) model based on `facebook/detr-resnet-50-dc5`, trained on a custom COCO-style dataset for layout-aware document understanding in Dhivehi and similar documents. The model detects key structural elements such as headings, authorship, paragraphs, and text lines, with awareness of document reading direction (LTR/RTL).
## Model Summary

- Base Model: `facebook/detr-resnet-50-dc5`
- Dataset: Custom COCO-format document layout dataset (`coco-dv-layout`)
- Categories: `layout-analysis-QvA6`, `author`, `caption`, `columns`, `date`, `footnote`, `heading`, `paragraph`, `picture`, `textline`
- Reading Direction Support: Left-to-Right (LTR) and Right-to-Left (RTL) documents
- Backbone: ResNet-50 DC5
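
The category list can also be read straight off the checkpoint config. A minimal sketch, assuming the standard `transformers` loading path (the `id2label` mapping ships with the fine-tuned checkpoint):

```python
from transformers import DetrForObjectDetection

# Inspect the label mapping baked into the fine-tuned checkpoint
model = DetrForObjectDetection.from_pretrained("alakxender/detr-resnet-50-dc5-dv-layout-sm1")
print(model.config.id2label)
```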
## Usage

### Inference Script

```python
from transformers import pipeline
from PIL import Image
import torch

image = Image.open("ocr.png")

obj_detector = pipeline(
    "object-detection",
    model="alakxender/detr-resnet-50-dc5-dv-layout-sm1",
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    use_fast=True,
)

results = obj_detector(image)
print(results)
```
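
The pipeline returns a list of detections, each with a `score`, a `label`, and a `box` in absolute pixel coordinates (`xmin`/`ymin`/`xmax`/`ymax`). Illustrative output shape (the values here are made up):

```python
[{"score": 0.97, "label": "heading",
  "box": {"xmin": 120, "ymin": 65, "xmax": 620, "ymax": 121}},
 {"score": 0.91, "label": "textline",
  "box": {"xmin": 140, "ymin": 150, "xmax": 600, "ymax": 175}}]
```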
### Test Script

```python
import argparse
import json

import numpy as np
import torch
from PIL import Image, ImageDraw
from transformers import pipeline

parser = argparse.ArgumentParser()
parser.add_argument("--threshold", type=float, default=0.6)
# default=None so that omitting the flag triggers auto-detection below
parser.add_argument("--rtl", action="store_true", default=None,
                    help="Process as right-to-left language document (auto-detected if omitted)")
args = parser.parse_args()

threshold = args.threshold
is_rtl = args.rtl  # True if --rtl was passed, otherwise None (auto-detect)

# Set device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Device set to use {device}")
print(f"Document direction: {'Right-to-Left' if is_rtl else 'auto-detect'}")

image = Image.open("ocr-bill.jpeg")

obj_detector = pipeline(
    "object-detection",
    model="alakxender/detr-resnet-50-dc5-dv-layout-sm1",
    device=device,
    use_fast=True,  # avoid the slow-processor warning
)

results = obj_detector(image)
print(results)

# Define colors for different labels
category_colors = {
    "author": (0, 255, 0),      # Green
    "caption": (0, 0, 255),     # Blue
    "columns": (255, 255, 0),   # Yellow
    "date": (255, 0, 255),      # Magenta
    "footnote": (0, 255, 255),  # Cyan
    "heading": (128, 0, 0),     # Dark Red
    "paragraph": (0, 128, 0),   # Dark Green
    "picture": (0, 0, 128),     # Dark Blue
    "textline": (128, 128, 0),  # Olive
}

# Define document element hierarchy (lower value = higher priority)
element_priority = {
    "heading": 1,
    "author": 2,
    "date": 3,
    "columns": 4,
    "paragraph": 5,
    "textline": 6,
    "picture": 7,
    "caption": 8,
    "footnote": 9,
}


def detect_text_direction(results, threshold=0.6):
    """
    Attempt to automatically detect whether the document is RTL based on
    detected text elements. This is a heuristic approach; for production
    use, consider language detection instead.
    """
    # Filter by confidence threshold
    filtered_results = [r for r in results if r["score"] > threshold]

    # Focus on text elements (textline, paragraph, heading)
    text_elements = [r for r in filtered_results
                     if r["label"] in ("textline", "paragraph", "heading")]
    if not text_elements:
        return False  # Default to LTR if no text elements

    # Collect horizontal position info for each text element
    coordinates = []
    for r in text_elements:
        box = list(r["box"].values())
        if len(box) == 4:
            x1, y1, x2, y2 = box
            coordinates.append({
                "xmin": x1,
                "xmax": x2,
                "width": x2 - x1,
                "x_center": (x1 + x2) / 2,
            })

    if not coordinates:
        return False  # Default to LTR

    # Analyze the horizontal distribution of elements
    image_width = max(c["xmax"] for c in coordinates)

    # Average center position relative to image width
    avg_center_position = sum(c["x_center"] for c in coordinates) / len(coordinates)
    relative_position = avg_center_position / image_width

    # If elements tend to sit on the right side, the document is likely RTL.
    # This is a simple heuristic; a more sophisticated approach would use
    # OCR or language detection.
    is_rtl_detected = relative_position > 0.55

    print(f"Auto-detected document direction: {'Right-to-Left' if is_rtl_detected else 'Left-to-Right'}")
    print(f"Average element center position: {relative_position:.2f} of document width")
    return is_rtl_detected


def get_reading_order(results, threshold=0.6, rtl=None):
    """
    Sort detection results in natural reading order for both LTR and RTL documents:
    1. First by element priority (headings first)
    2. Then by vertical position (top to bottom)
    3. For elements with similar y-values, by horizontal position according
       to the text direction
    """
    # Filter by confidence threshold
    filtered_results = [r for r in results if r["score"] > threshold]

    # If no manual RTL flag is set, try to auto-detect
    if rtl is None:
        rtl = detect_text_direction(results, threshold)

    # Text lines within ~20 pixels vertically are considered on the same line
    y_tolerance = 20

    # Inspect the box structure to understand its keys
    if filtered_results and "box" in filtered_results[0]:
        box_keys = filtered_results[0]["box"].keys()
        print(f"Box structure keys: {box_keys}")
        # Assuming box format is {'xmin', 'ymin', 'xmax', 'ymax'} or similar
        if "ymin" in box_keys:
            y_key, height_key = "ymin", None
            x_key = "xmin"
        elif "top" in box_keys:
            y_key, height_key = "top", "height"
            x_key = "left"
        else:
            print("Unknown box format, defaulting to list unpacking")
            y_key, x_key, height_key = None, None, None
    else:
        print("No box format detected, defaulting to list unpacking")
        y_key, x_key, height_key = None, None, None

    # Separate structural (heading/author/date) and content elements
    structural_elements = []
    content_elements = []
    for r in filtered_results:
        if r["label"] in ("heading", "author", "date"):
            structural_elements.append(r)
        else:
            content_elements.append(r)

    # Coordinate accessors based on the detected box format
    def get_y(element):
        if y_key:
            return element["box"][y_key]
        # Assume box values() yields [xmin, ymin, xmax, ymax]
        return list(element["box"].values())[1]

    def get_x(element):
        if x_key:
            return element["box"][x_key]
        return list(element["box"].values())[0]

    def get_x_max(element):
        box_values = list(element["box"].values())
        if len(box_values) >= 4:
            return box_values[2]  # xmax is typically the third value
        return get_x(element)  # fallback

    def get_y_center(element):
        if y_key and height_key:
            return element["box"][y_key] + element["box"][height_key] / 2
        box_values = list(element["box"].values())
        return (box_values[1] + box_values[3]) / 2  # (ymin + ymax) / 2

    # Sort structural elements by priority first, then by y position
    sorted_structural = sorted(
        structural_elements,
        key=lambda x: (element_priority.get(x["label"], 999), get_y(x)),
    )

    # Group content elements that may share a row (similar y-coordinate)
    rows = []
    for element in content_elements:
        y_center = get_y_center(element)
        found_row = False
        for row in rows:
            row_y_centers = [get_y_center(e) for e in row]
            row_y_center = sum(row_y_centers) / len(row_y_centers)
            if abs(y_center - row_y_center) < y_tolerance:
                row.append(element)
                found_row = True
                break
        # If not found in any existing row, create a new row
        if not found_row:
            rows.append([element])

    # Sort elements within each row according to reading direction:
    # descending x for RTL, ascending x for LTR
    for row in rows:
        row.sort(key=get_x, reverse=rtl)

    # Sort rows by y position (top to bottom)
    rows.sort(key=lambda row: sum(get_y_center(e) for e in row) / len(row))

    # Flatten the rows into a single list
    sorted_content = [element for row in rows for element in row]

    # Structural elements first, then content in reading order
    return sorted_structural + sorted_content


def plot_results(image, results, threshold=threshold, save_path="output.jpg", rtl=None):
    # Convert to a PIL Image if it is not one already
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.uint8(image))
    draw = ImageDraw.Draw(image)

    # If rtl is None (not explicitly specified), try to auto-detect
    if rtl is None:
        rtl = detect_text_direction(results, threshold)

    # Get results in reading order
    ordered_results = get_reading_order(results, threshold, rtl)

    # Collect formatted results; the loop index records the reading sequence
    formatted_results = []
    for i, result in enumerate(ordered_results):
        label = result["label"]
        box = list(result["box"].values())
        score = result["score"]

        # Make sure box has exactly 4 values
        if len(box) == 4:
            x1, y1, x2, y2 = box
        else:
            print(f"Warning: Unexpected box format for {label}: {box}")
            continue

        color = category_colors.get(label, (255, 255, 255))  # White if label unknown

        # Draw bounding box, then the order number of the reading sequence
        draw.rectangle((x1, y1, x2, y2), outline=color, width=2)
        draw.text((x1 + 5, y1 - 20), f"#{i + 1}", fill=(255, 255, 255))

        # For RTL text elements, draw a right-to-left reading-direction arrow
        if rtl and label in ("textline", "paragraph", "heading"):
            draw.text((x1 + 5, y1 - 10), f"{label} (RTL)", fill=color)
            arrow_y = y1 - 5
            draw.line([(x2 - 20, arrow_y), (x1 + 20, arrow_y)], fill=color, width=1)
            draw.polygon([(x1 + 20, arrow_y - 3), (x1 + 20, arrow_y + 3), (x1 + 15, arrow_y)], fill=color)
        else:
            draw.text((x1 + 5, y1 - 10), label, fill=color)
        draw.text((x1 + 5, y1 + 10), f"{score:.2f}", fill="green" if score > 0.7 else "red")

        # Record the result with its reading-order index
        formatted_results.append({
            "order_index": i,
            "label": label,
            "is_rtl": rtl if label in ("textline", "paragraph", "heading") else False,
            "score": float(score),
            "bbox": {
                "x1": float(x1),
                "y1": float(y1),
                "x2": float(x2),
                "y2": float(y2),
            },
        })

    image.save(save_path)

    # Save results to a JSON file with RTL information
    with open("results.json", "w") as f:
        json.dump(
            {"document_direction": "rtl" if rtl else "ltr", "elements": formatted_results},
            f,
            indent=2,
        )
    return image


if len(results) > 0:  # Only plot if there are results
    # If the RTL flag was not set on the command line, auto-detect
    if is_rtl is None:
        is_rtl = detect_text_direction(results)
    plot_results(image, results, rtl=is_rtl)
    print(f"Processing complete. Document interpreted as {'RTL' if is_rtl else 'LTR'}")
else:
    print("No objects detected in the image")
```
## Output Example

- Visual Output: bounding boxes with labels and reading-order numbers
- JSON Output:

```json
{
  "document_direction": "rtl",
  "elements": [
    {
      "order_index": 0,
      "label": "heading",
      "is_rtl": true,
      "score": 0.97,
      "bbox": {
        "x1": 120.5,
        "y1": 65.2,
        "x2": 620.4,
        "y2": 120.7
      }
    }
  ]
}
```
## Training Summary

- Training script: uses the Hugging Face `Trainer` API
- Eval strategy: `steps`, with `MeanAveragePrecision` via `torchmetrics`
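
The training script itself is not reproduced here; the following is a minimal sketch of the setup the summary describes, assuming a COCO-style dataset already wrapped as `torch.utils.data.Dataset` objects (`train_ds`, `eval_ds`, and `collate_fn` are placeholders, and the hyperparameter values are illustrative, not the ones actually used):

```python
from torchmetrics.detection import MeanAveragePrecision
from transformers import (DetrForObjectDetection, DetrImageProcessor,
                          Trainer, TrainingArguments)

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50-dc5")
model = DetrForObjectDetection.from_pretrained(
    "facebook/detr-resnet-50-dc5",
    num_labels=10,                 # matches the 10 categories listed above
    ignore_mismatched_sizes=True,  # re-initialize the classification head
)

training_args = TrainingArguments(
    output_dir="detr-dv-layout",
    eval_strategy="steps",   # "evaluation_strategy" on older transformers
    eval_steps=500,
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    num_train_epochs=30,
)

# COCO-style mAP; in the real script this is fed predictions and targets
# during evaluation (e.g. via a compute_metrics function or callback)
metric = MeanAveragePrecision()

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,    # placeholder dataset
    eval_dataset=eval_ds,      # placeholder dataset
    data_collator=collate_fn,  # placeholder: batches pixel_values + labels
)
trainer.train()
```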