---
license: other
datasets:
  - Argobell/gek408
  - Argobell/gek408-dpo
language:
  - en
base_model: google/gemma-3n-E2B-it
pipeline_tag: image-text-to-text
library_name: transformers
tags:
  - gemma3n
  - sft
  - dpo
  - unsloth
  - instruction-tuning
  - text-generation
  - multimodal
  - education
  - reasoning
---

🧠 Model Card for gemma-3n-gek408-dpo

gemma-3n-gek408-dpo is a fine-tuned version of google/gemma-3n-E2B-it, optimized for educational and scientific reasoning. It was trained with the Unsloth library, which provided significantly faster training and reduced memory usage.

The training followed a two-stage process:

  1. Supervised Fine-Tuning (SFT): To teach the model the desired instruction-following behavior on scientific and mathematical tasks.
  2. Direct Preference Optimization (DPO): To align the model's responses with human preferences for clarity, accuracy, and helpfulness.

This model was developed for the Google - The Gemma 3n Impact Challenge competition.

📌 Model Details

🧾 Model Description

  • Developed by: Argobell
  • Shared by: Argobell
  • Model type: Multimodal model capable of processing text, image, and audio inputs.
  • Finetuned from: google/gemma-3n-E2B-it
  • License: This model is subject to the Gemma Terms of Use. Users must agree to and comply with the Gemma Terms of Use and the Gemma Prohibited Use Policy.
  • Primary Domain: Education, STEM, Visual Reasoning

📂 Model Sources

🎯 Uses

✅ Direct Use

This model is ideal for:

  • 🧮 Math Tutoring Agents: Guiding students through complex math problems.
  • 🧑‍🏫 Educational AI Assistants: Answering questions based on educational materials.
  • 📊 Diagram-based Question Answering: Interpreting charts, graphs, and scientific diagrams.
  • 🔍 Visual Reasoning & Explanation: Explaining logical steps from a visual prompt.

🧩 Downstream Use

This model serves as a strong foundation for:

  • Interactive, offline-ready learning experiences for students in low-connectivity regions.
  • Advanced multimodal AI systems for educational platforms.
  • Domain-specific reasoning tools for science and engineering.
  • Interactive learning applications in STEM fields.

⚠️ Bias, Risks, and Limitations

This model inherits limitations common to most LLMs and has specific risks related to its application:

  • Hallucination: The model can generate incorrect or fabricated information.
  • Prompt Sensitivity: The phrasing of a prompt can significantly affect the output quality.
  • Inherited Biases: It may reflect biases present in the gemma-3n-E2B-it base model and the gek408 dataset.
  • Risk of "Fluent Nonsense": In educational contexts, the model might generate explanations that sound logical and correct but contain subtle mathematical or scientific inaccuracies. Human verification is crucial for factual and educational use cases.

💡 Recommendations

Always critically evaluate the model's output before use in any real-world application. For educational purposes, outputs should be reviewed by a subject matter expert.

🚀 Getting Started

The model was trained with Unsloth, so using Unsloth for inference as well is recommended for maximum performance.

from unsloth import FastModel
import torch
from transformers import TextStreamer
import gc

# Load the model and tokenizer with 4-bit quantization
model, tokenizer = FastModel.from_pretrained(
    model_name = "Argobell/gemma-3n-gek408-dpo", 
    max_seq_length = 1024, # Choose any for long context!
    load_in_4bit = True,  # 4 bit quantization to reduce memory
    # token = "hf_...", # use one if using gated models
)

# Helper function for inference
def do_gemma_3n_inference(model, messages, max_new_tokens = 128):
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True, # Must add for generation
        tokenize = True,
        return_dict = True,
        return_tensors = "pt",
    ).to("cuda")
    _ = model.generate(
        **inputs,
        max_new_tokens = max_new_tokens,
        temperature = 1.0, top_p = 0.95, top_k = 64,
        streamer = TextStreamer(tokenizer, skip_prompt = True),
    )
    # Cleanup to reduce VRAM usage
    del inputs
    torch.cuda.empty_cache()
    gc.collect()

sloth_link = "https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg"

messages = [{
    "role" : "user",
    "content": [
        { "type": "image", "image" : sloth_link },
        { "type": "text",  "text" : "Which films does this animal feature in?" }
    ]
}]
# You might have to wait 1 minute for Unsloth's auto compiler
do_gemma_3n_inference(model, messages, max_new_tokens = 256)
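
The same helper also works for text-only prompts in the same message format; the math question below is only an illustrative example, not part of the original script.

# Text-only example: same chat format, just without an image entry
messages = [{
    "role" : "user",
    "content": [
        { "type": "text", "text": "Solve 2x + 6 = 18 for x and explain each step." }
    ]
}]
do_gemma_3n_inference(model, messages, max_new_tokens = 256)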

🛠️ Training Details

The training was conducted in two distinct phases, using a LoRA-based approach accelerated by Unsloth.

📚 Phase 1: Supervised Fine-Tuning (SFT)

  • Goal: To teach the model the fundamental structure of responding to mathematical prompts.
  • Dataset: Argobell/gek408
  • Key Hyperparameters: The following parameters were used to tune both the vision and language components of the model.
# SFT Stage Configuration
--max_seq_length 2048
--max_steps 320
--learning_rate 2e-4
--lr_scheduler_type "cosine"
--optim "adamw_torch_fused"

# LoRA Configuration
--tune_vision                
--tune_language_layers       
--tune_attention_modules     
--tune_mlp_modules           
--r 16                       
--alpha 16                   
--lora_dropout 0.05

# Batching & Memory
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--gradient_accumulation_steps 8 
--gradient_checkpointing
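
For reference, a minimal sketch of how these flags might map onto the Unsloth and TRL APIs. The dataset preparation, output_dir, and exact argument names are assumptions for illustration, not the exact training script.

from unsloth import FastModel
from trl import SFTConfig

# Load the base model in 4-bit to fit a single GPU
model, tokenizer = FastModel.from_pretrained(
    model_name = "google/gemma-3n-E2B-it",
    max_seq_length = 2048,
    load_in_4bit = True,
)

# Attach LoRA adapters to both the vision and language towers,
# mirroring --tune_vision / --tune_language_layers / --r 16 / --alpha 16 above
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = True,
    finetune_language_layers   = True,
    finetune_attention_modules = True,
    finetune_mlp_modules       = True,
    r = 16,
    lora_alpha = 16,
    lora_dropout = 0.05,
)

# Trainer arguments mirroring the flags above; this SFTConfig would be passed,
# together with the chat-formatted Argobell/gek408 dataset, to trl.SFTTrainer
sft_args = SFTConfig(
    max_steps = 320,
    learning_rate = 2e-4,
    lr_scheduler_type = "cosine",
    optim = "adamw_torch_fused",
    per_device_train_batch_size = 4,
    per_device_eval_batch_size = 4,
    gradient_accumulation_steps = 8,
    gradient_checkpointing = True,
    output_dir = "outputs-sft",   # hypothetical output path
)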

📚 Phase 2: Direct Preference Optimization (DPO)

  • Goal: To refine the SFT model by training it to prefer helpful, accurate responses over less desirable ones.
  • Dataset: Argobell/gek408-dpo
  • Key Hyperparameters: Starting from the SFT-tuned model, DPO training was performed with the following settings.
# DPO Stage Configuration
--max_seq_length 2048
--max_prompt_length 1024
--max_steps 100
--learning_rate 5e-6         
--optim "adamw_torch_fused"
--warmup_ratio 0.1
--weight_decay 0.01

# LoRA Configuration
--tune_vision                
--tune_language_layers       
--tune_attention_modules     
--tune_mlp_modules           
--r 4
--alpha 4
--lora_dropout 0.1

# Batching & Memory
--per_device_train_batch_size 2
--per_device_eval_batch_size 2
--gradient_accumulation_steps 4
--gradient_checkpointing
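
And a matching sketch of how the DPO stage might be wired up with TRL. The dpo_dataset placeholder (prompt/chosen/rejected pairs built from Argobell/gek408-dpo), output_dir, and some argument names depend on the TRL version and are assumptions.

from trl import DPOConfig, DPOTrainer

# DPO arguments mirroring the flags above
dpo_args = DPOConfig(
    max_length = 2048,
    max_prompt_length = 1024,
    max_steps = 100,
    learning_rate = 5e-6,
    optim = "adamw_torch_fused",
    warmup_ratio = 0.1,
    weight_decay = 0.01,
    per_device_train_batch_size = 2,
    per_device_eval_batch_size = 2,
    gradient_accumulation_steps = 4,
    gradient_checkpointing = True,
    output_dir = "outputs-dpo",   # hypothetical output path
)

# The SFT-tuned LoRA model from Phase 1 is trained further; with a PEFT model,
# ref_model can be None and the frozen base weights serve as the reference
trainer = DPOTrainer(
    model = model,
    ref_model = None,
    args = dpo_args,
    train_dataset = dpo_dataset,   # placeholder: preference pairs from Argobell/gek408-dpo
    processing_class = tokenizer,
)
trainer.train()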

💻 Infrastructure & Software

  • Hardware: 1× NVIDIA RTX 5880 Ada Generation
  • Key Software:
    • Unsloth: Used for 2-3x faster training and ~60% less memory usage, enabling more extensive experimentation.
    • Hugging Face TRL: For implementing the SFT and DPO training loops.
    • Hugging Face Transformers & Datasets.

🧰 Technical Specifications

Architecture

Gemma-3n utilizes a Matryoshka Transformer (MatFormer) architecture, which nests smaller, self-contained models within a larger one.

🙏 Acknowledgements

This work would not have been possible without the foundational models and libraries developed by the open-source community. We would like to extend our gratitude to:

  • Google: For developing and releasing the powerful gemma-3n-E2B-it base model.
  • The Unsloth AI team: For creating the Unsloth library, which was instrumental in accelerating the training process and reducing computational costs.
  • Hugging Face: For providing the transformers, datasets, and TRL libraries that formed the backbone of our training and experimentation pipeline.

📖 Citation

If you use this model in your work, please cite it as follows:

@misc{gemma3ngek408dpo,
  author = {Argobell},
  title = {gemma-3n-gek408-dpo},
  howpublished = {\url{https://huggingface.co/Argobell/gemma-3n-gek408-dpo}},
  year = {2025}
}

👥 Model Card Authors

  • Argobell

📬 Contact

For questions, feedback, or collaboration, please reach out via email: [email protected]