PaliGemma2 VRD Dhivehi OCR Model

Model Description

This is a fine-tuned version of the PaliGemma2 model optimized for Optical Character Recognition (OCR) of Dhivehi text in images. It is based on the google/paligemma2-3b-pt-224 checkpoint and was fine-tuned with QLoRA to read and transcribe Dhivehi text from images.

Model Details

  • Model type: Vision-Language Model
  • Base model: google/paligemma2-3b-pt-224
  • Fine-tuning approach: QLoRA
  • Input format: Images with text
  • Output format: Text transcription
  • Supported languages: Primarily Dhivehi

How to Use

Option 1: Direct Loading

from transformers.image_utils import load_image
import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor

# Print GPU information
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"Current GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory allocated: {torch.cuda.memory_allocated(0) / 1024**2:.2f} MB")
    print(f"GPU memory cached: {torch.cuda.memory_reserved(0) / 1024**2:.2f} MB")

model_id = "alakxender/paligemma2-qlora-vrd-dhivehi-ocr-224-sm"
print("Loading model...")
# Load in bfloat16 so the weights match the bfloat16 inputs prepared below
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")
processor = AutoProcessor.from_pretrained(model_id)

print("Loading image...")
image = load_image("ocr1.png")

print("Processing image...")
prompt = "What text is written in this image?"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16).to("cuda")
input_len = model_inputs["input_ids"].shape[-1]

print("Model inputs device:", model_inputs["input_ids"].device)
print("Model device:", model.device)
print("Generating output...")
with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)

print("Done!")

Option 2: Memory-Efficient PEFT Loading

from transformers.image_utils import load_image
import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor
from peft import PeftModel, PeftConfig

# Define model ID
model_id = "alakxender/paligemma2-qlora-vrd-dhivehi-ocr-224-sm"

# Load the PEFT configuration to get the base model path
print("Loading PEFT configuration...")
peft_config = PeftConfig.from_pretrained(model_id)

# Load the base model
print(f"Loading base model: {peft_config.base_model_name_or_path}...")
base_model = PaliGemmaForConditionalGeneration.from_pretrained(
    peft_config.base_model_name_or_path,
    device_map="auto",
    torch_dtype=torch.bfloat16
)

# Load the adapter on top of the base model
print(f"Loading PEFT adapter: {model_id}...")
model = PeftModel.from_pretrained(base_model, model_id)

# Load the processor from the base model
processor = AutoProcessor.from_pretrained(peft_config.base_model_name_or_path)

print("Loading image...")
image = load_image("ocr1.png")

print("Processing image...")
prompt = "What text is written in this image?"
model_inputs = processor(text=prompt, images=image, return_tensors="pt").to(torch.bfloat16)

# Move inputs to the same device as the model
if hasattr(model, 'device'):
    device = model.device
else:
    # If device isn't directly accessible, infer from model parameters
    device = next(model.parameters()).device
    
model_inputs = {k: v.to(device) for k, v in model_inputs.items()}

input_len = model_inputs["input_ids"].shape[-1]

print("Generating output...")
with torch.inference_mode():
    generation = model.generate(**model_inputs, max_new_tokens=100, do_sample=False)
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)
    print(decoded)

print("Done!")

Training Details

  • Base Model: google/paligemma2-3b-pt-224

  • Dataset: alakxender/dhivehi-vrd-b1-img-questions

  • Training Configuration (see the configuration sketch after this list):

    • Batch size: 2 per device
    • Gradient accumulation steps: 8
    • Effective batch size: 16
    • Learning rate: 2e-5
    • Weight decay: 1e-6
    • Adam β2: 0.999
    • Warmup steps: 2
    • Training steps: 20,000
    • Epochs: 1
    • Mixed precision: bfloat16
  • QLoRA Configuration:

    • Quantization: 4-bit NF4
    • LoRA rank (r): 8
    • Target modules:
      • q_proj
      • k_proj
      • v_proj
      • o_proj
      • gate_proj
      • up_proj
      • down_proj
    • Task type: CAUSAL_LM
    • Optimizer: paged_adamw_8bit
  • Data Processing:

    • Image resize method: LANCZOS
    • Input format: RGB images
    • Text prompt format: "answer [question]"
  • Training Metrics:

    • Initial loss: ~15
    • Final loss: ~2
    • Learning rate: Decreasing from 1.5e-5 to 5e-6
    • Gradient norm: Stabilized around 20-60
    • Model checkpointing: Every 1000 steps
    • Logging frequency: Every 100 steps
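
For reference, the hyperparameters above roughly correspond to the following transformers/peft/bitsandbytes configuration objects. This is a reconstructed sketch, not the original training script, and the output directory name is illustrative:

import torch
from transformers import BitsAndBytesConfig, TrainingArguments
from peft import LoraConfig

# 4-bit NF4 quantization used for QLoRA (bfloat16 compute dtype assumed)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# LoRA adapter over the attention and MLP projections listed above
lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)

# Trainer hyperparameters matching the values listed above
training_args = TrainingArguments(
    output_dir="paligemma2-dhivehi-ocr",  # illustrative path
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,        # effective batch size 16
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    warmup_steps=2,
    max_steps=20_000,                     # the card lists 20,000 steps and 1 epoch; max_steps takes precedence
    bf16=True,
    optim="paged_adamw_8bit",
    save_steps=1000,
    logging_steps=100,
    report_to="wandb",
)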

Performance

The model showed consistent improvement during training:

  • Loss decreased significantly in the first 5k steps and stabilized afterwards
  • Gradient norms remained in a healthy range throughout training
  • Learning rate was automatically adjusted following a linear decay schedule
  • Training completed successfully, with the loss converging
  • Training progress was monitored using Weights & Biases

Model Architecture

This model uses Parameter-Efficient Fine-Tuning (PEFT) with QLoRA:

  • Quantization: 4-bit quantization for memory efficiency
  • LoRA Adaptation: Low-rank adaptation of key transformer components
  • Memory Optimization: Uses paged optimizer for efficient memory usage
  • Mixed Precision: bfloat16 for training stability and speed
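
If GPU memory is tight, the same 4-bit setup can be reused at inference time: quantize the frozen base model with bitsandbytes and attach the adapter on top, similar to Option 2 above. A sketch, assuming bitsandbytes is installed:

import torch
from transformers import PaliGemmaForConditionalGeneration, AutoProcessor, BitsAndBytesConfig
from peft import PeftModel

model_id = "alakxender/paligemma2-qlora-vrd-dhivehi-ocr-224-sm"
base_id = "google/paligemma2-3b-pt-224"

# Quantize the frozen base model to 4-bit NF4, then attach the LoRA adapter
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = PaliGemmaForConditionalGeneration.from_pretrained(
    base_id, quantization_config=bnb_config, device_map="auto"
)
model = PeftModel.from_pretrained(base_model, model_id)
processor = AutoProcessor.from_pretrained(base_id)

# Shows how few parameters the LoRA adapter adds on top of the quantized base
model.print_trainable_parameters()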

Limitations

  • Primarily optimized for Dhivehi text
  • Performance may vary with different image qualities and text styles
  • Performance on handwritten text is untested and may be limited

Dataset

  • Source Dataset: alakxender/dhivehi-vrd-images (VRD Batch 1)

  • Processed Dataset: alakxender/dhivehi-vrd-b1-img-questions

  • Dataset Size:

    • Total: 474,169 samples
    • Training set: 379,335 samples (80%)
    • Validation set: 94,834 samples (20%)
  • Question Types: The dataset uses a variety of question prompts for OCR tasks, including:

- "What text is written in this image?"
- "Can you read and transcribe the Dhivehi text shown in this image?"
- "What is the Dhivehi text visible in this image?"
- "Please read out the text content from this image"
- "What Dhivehi text can you see in this image?"
- "Is there any text visible in this image? If so, what does it say?"
- "Could you transcribe the Dhivehi text shown in this image?"
- "What does the text in this image say?"
- "Can you read the Dhivehi text in this image? What does it say?"
- "Please identify and transcribe any text visible in this image"
- "What Dhivehi text is present in this image?"
  • Dataset Format:
    • Features:
      • image: Image containing Dhivehi text
      • question: Randomly selected question from the question pool
      • answer: Ground truth Dhivehi text transcription
    • Processing: Memory-efficient chunked processing (10,000 samples per chunk)
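
A quick way to inspect the processed dataset is sketched below. The split name and column access are assumptions based on the feature list above; check the dataset card for the exact configuration:

from datasets import load_dataset

# Load the processed OCR dataset (the split name "train" is an assumption)
ds = load_dataset("alakxender/dhivehi-vrd-b1-img-questions", split="train")

sample = ds[0]
print(sample["question"])  # one of the question prompts listed above
print(sample["answer"])    # ground-truth Dhivehi transcription
sample["image"].save("sample.png")  # PIL image containing the Dhivehi text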