---
license: other
datasets:
- Argobell/gek408
- Argobell/gek408-dpo
language:
- en
base_model: google/gemma-3n-E2B-it
pipeline_tag: image-text-to-text
library_name: transformers
tags:
- gemma3n
- sft
- dpo
- unsloth
- instruction-tuning
- text-generation
- multimodal
- education
- reasoning
---
# Model Card for gemma-3n-gek408-dpo
`gemma-3n-gek408-dpo` is a high-performance, fine-tuned version of `google/gemma-3n-E2B-it`, meticulously optimized for educational and scientific reasoning. The model was trained with the Unsloth library for significantly faster training and reduced memory usage.
The training followed a two-stage process:
- Supervised Fine-Tuning (SFT): To teach the model the desired instruction-following behavior on scientific and mathematical tasks.
- Direct Preference Optimization (DPO): To align the model's responses with human preferences for clarity, accuracy, and helpfulness.
This model was developed for the Google - The Gemma 3n Impact Challenge competition.
## Model Details
### Model Description
- Developed by: Argobell
- Shared by: Argobell
- Model type: Multimodal model, capable of processing text, image, and audio inputs.
- Finetuned from: `google/gemma-3n-E2B-it`
- License: This model is subject to the Gemma Terms of Use. Users must agree to and comply with the Gemma Terms of Use and the Gemma Prohibited Use Policy.
- Primary Domain: Education, STEM, Visual Reasoning
### Model Sources
- Repository: Argobell/gemma-3n-gek408-dpo
- Competition: Google - The Gemma 3n Impact Challenge
- Demo: GitHub Demo Link
## Uses
### Direct Use
This model is ideal for:
- Math Tutoring Agents: Guiding students through complex math problems.
- Educational AI Assistants: Answering questions based on educational materials.
- Diagram-based Question Answering: Interpreting charts, graphs, and scientific diagrams.
- Visual Reasoning & Explanation: Explaining logical steps from a visual prompt.
### Downstream Use
This model serves as a strong foundation for:
- Interactive, offline-ready learning experiences for students in low-connectivity regions.
- Advanced multimodal AI systems for educational platforms.
- Domain-specific reasoning tools for science and engineering.
- Interactive learning applications in STEM fields.
## Bias, Risks, and Limitations
This model inherits limitations common to most LLMs and has specific risks related to its application:
- Hallucination: The model can generate incorrect or fabricated information.
- Prompt Sensitivity: The phrasing of a prompt can significantly affect the output quality.
- Inherited Biases: It may reflect biases present in the `gemma-3n-E2B-it` base model and the `gek408` dataset.
- Risk of "Fluent Nonsense": In educational contexts, the model might generate explanations that sound logical and correct but contain subtle mathematical or scientific inaccuracies. Human verification is crucial for factual and educational use cases.
### Recommendations
Always critically evaluate the model's output before use in any real-world application. For educational purposes, outputs should be reviewed by a subject matter expert.
## Getting Started
The model was trained with Unsloth, so using Unsloth for inference is also recommended for maximum performance.
```python
from unsloth import FastModel
import torch
from transformers import TextStreamer
import gc

# Load the model and tokenizer with 4-bit quantization
model, tokenizer = FastModel.from_pretrained(
    model_name = "Argobell/gemma-3n-gek408-dpo",
    max_seq_length = 1024,  # Choose any for long context!
    load_in_4bit = True,    # 4 bit quantization to reduce memory
    # token = "hf_...",     # use one if using gated models
)

# Helper function for inference
def do_gemma_3n_inference(model, messages, max_new_tokens = 128):
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,  # Must add for generation
        tokenize = True,
        return_dict = True,
        return_tensors = "pt",
    ).to("cuda")
    _ = model.generate(
        **inputs,
        max_new_tokens = max_new_tokens,
        temperature = 1.0, top_p = 0.95, top_k = 64,
        streamer = TextStreamer(tokenizer, skip_prompt = True),
    )
    # Cleanup to reduce VRAM usage
    del inputs
    torch.cuda.empty_cache()
    gc.collect()

sloth_link = "https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg"

messages = [{
    "role": "user",
    "content": [
        { "type": "image", "image": sloth_link },
        { "type": "text",  "text": "Which films does this animal feature in?" }
    ]
}]

# You might have to wait 1 minute for Unsloth's auto compiler
do_gemma_3n_inference(model, messages, max_new_tokens = 256)
```
## Training Details
The training was conducted in two distinct phases, using a LoRA-based approach accelerated by Unsloth.
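The original training scripts are not reproduced in this card; as a rough illustration, the LoRA setup looked approximately like the sketch below. This is a minimal, hedged sketch assuming Unsloth's `FastModel.get_peft_model` API; the rank/alpha/dropout values shown are the SFT-stage settings listed further down, and the DPO stage used its own, smaller adapter configuration.

```python
from unsloth import FastModel

# Load the base model; 4-bit loading mirrors the inference example above
# (whether training itself used 4-bit weights is not stated in this card).
model, tokenizer = FastModel.from_pretrained(
    model_name = "google/gemma-3n-E2B-it",
    max_seq_length = 2048,
    load_in_4bit = True,
)

# Attach LoRA adapters to both the vision and language stacks,
# mirroring the --tune_* flags in the stage configurations below.
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = True,  # --tune_vision
    finetune_language_layers   = True,  # --tune_language_layers
    finetune_attention_modules = True,  # --tune_attention_modules
    finetune_mlp_modules       = True,  # --tune_mlp_modules
    r = 16, lora_alpha = 16, lora_dropout = 0.05,  # SFT-stage values
)
```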
### Phase 1: Supervised Fine-Tuning (SFT)
- Goal: To teach the model the fundamental structure of responding to mathematical prompts.
- Dataset: `Argobell/gek408`
- Key Hyperparameters: The following parameters were used to tune both the vision and language components of the model.
```bash
# SFT Stage Configuration
--max_seq_length 2048
--max_steps 320
--learning_rate 2e-4
--lr_scheduler_type "cosine"
--optim "adamw_torch_fused"

# LoRA Configuration
--tune_vision
--tune_language_layers
--tune_attention_modules
--tune_mlp_modules
--r 16
--alpha 16
--lora_dropout 0.05

# Batching & Memory
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--gradient_accumulation_steps 8
--gradient_checkpointing
```
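These flags map roughly onto a TRL `SFTTrainer` run as sketched below. The sketch is illustrative only: `sft_dataset` is a placeholder for a chat-formatted split of `Argobell/gek408`, and multimodal SFT additionally needs a vision-aware data collator, which is omitted here for brevity.

```python
from trl import SFTTrainer, SFTConfig

# Assumes `model` and `tokenizer` from the LoRA setup sketched above.
trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,
    train_dataset = sft_dataset,  # placeholder: prepared Argobell/gek408 split
    args = SFTConfig(
        max_steps = 320,
        learning_rate = 2e-4,
        lr_scheduler_type = "cosine",
        optim = "adamw_torch_fused",
        per_device_train_batch_size = 4,
        per_device_eval_batch_size = 4,
        gradient_accumulation_steps = 8,
        gradient_checkpointing = True,
        output_dir = "outputs/sft",
    ),
)
trainer.train()
```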
### Phase 2: Direct Preference Optimization (DPO)
- Goal: To refine the SFT model by training it to prefer helpful, accurate responses over less desirable ones.
- Dataset: `Argobell/gek408-dpo`
- Key Hyperparameters: Starting from the SFT-tuned model, DPO training was performed with the following settings.
```bash
# DPO Stage Configuration
--max_seq_length 2048
--max_prompt_length 1024
--max_steps 100
--learning_rate 5e-6
--optim "adamw_torch_fused"
--warmup_ratio 0.1
--weight_decay 0.01

# LoRA Configuration
--tune_vision
--tune_language_layers
--tune_attention_modules
--tune_mlp_modules
--r 4
--alpha 4
--lora_dropout 0.1

# Batching & Memory
--per_device_train_batch_size 2
--per_device_eval_batch_size 2
--gradient_accumulation_steps 4
--gradient_checkpointing
```
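As with the SFT stage, these settings translate roughly into a TRL `DPOTrainer` run as sketched below. Again this is illustrative: `dpo_dataset` is a placeholder for a preference split of `Argobell/gek408-dpo` with `prompt`/`chosen`/`rejected` columns, and `model` is assumed to be the SFT checkpoint with a fresh, smaller LoRA adapter (r = 4).

```python
from trl import DPOTrainer, DPOConfig

# With LoRA adapters attached, ref_model can stay None: the frozen base
# weights act as the implicit reference policy.
dpo_trainer = DPOTrainer(
    model = model,
    ref_model = None,
    tokenizer = tokenizer,
    train_dataset = dpo_dataset,  # placeholder: prepared Argobell/gek408-dpo split
    args = DPOConfig(
        max_length = 2048,         # --max_seq_length
        max_prompt_length = 1024,  # --max_prompt_length
        max_steps = 100,
        learning_rate = 5e-6,
        optim = "adamw_torch_fused",
        warmup_ratio = 0.1,
        weight_decay = 0.01,
        per_device_train_batch_size = 2,
        per_device_eval_batch_size = 2,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,
        output_dir = "outputs/dpo",
    ),
)
dpo_trainer.train()
```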
### Infrastructure & Software
- Hardware: 1× NVIDIA RTX 5880 Ada Generation
- Key Software:
  - Unsloth: Used for 2-3x faster training and ~60% less memory usage, enabling more extensive experimentation.
  - Hugging Face TRL: For implementing the SFT and DPO training loops.
  - Hugging Face Transformers & Datasets.
## Technical Specifications

### Architecture
Gemma-3n utilizes a Matryoshka Transformer (MatFormer) architecture, which nests smaller, self-contained models within a larger one.
## Acknowledgements
This work would not have been possible without the foundational models and libraries developed by the open-source community. We would like to extend our gratitude to:
- Google: For developing and releasing the powerful gemma-3n-E2B-it base model.
- The Unsloth AI team: For creating the Unsloth library, which was instrumental in accelerating the training process and reducing computational costs.
- Hugging Face: For providing the transformers, datasets, and TRL libraries that formed the backbone of our training and experimentation pipeline.
## Citation
If you use this model in your work, please cite it as follows:
```bibtex
@misc{gemma3ngek408dpo,
  author       = {Argobell},
  title        = {gemma-3n-gek408-dpo},
  howpublished = {\url{https://huggingface.co/Argobell/gemma-3n-gek408-dpo}},
  year         = {2025}
}
```
## Model Card Authors
- Argobell
## Contact
For questions, feedback, or collaboration, please reach out via email: [email protected]