🧠 Model Card for gemma-3n-gek408-dpo
gemma-3n-gek408-dpo is a fine-tuned version of google/gemma-3n-E2B-it, optimized for educational and scientific reasoning. It was trained with the Unsloth library for significantly faster training and reduced memory usage.
The training followed a two-stage process:
- Supervised Fine-Tuning (SFT): To teach the model the desired instruction-following behavior on scientific and mathematical tasks.
- Direct Preference Optimization (DPO): To align the model's responses with human preferences for clarity, accuracy, and helpfulness.
This model was developed for the Google - The Gemma 3n Impact Challenge competition.
📌 Model Details
🧾 Model Description
- Developed by: Argobell
- Shared by: Argobell
- Model type: Multimodal model capable of processing text, image, and audio inputs.
- Finetuned from: google/gemma-3n-E2B-it
- License: This model is subject to the Gemma Terms of Use. Users must agree to and comply with the Gemma Terms of Use and the Gemma Prohibited Use Policy.
- Primary Domain: Education, STEM, Visual Reasoning
📂 Model Sources
- Repository: Argobell/gemma-3n-gek408-dpo
- Competition: Google - The Gemma 3n Impact Challenge
- Demo: GitHub Demo Link
🎯 Uses
✅ Direct Use
This model is ideal for:
- 🧮 Math Tutoring Agents: Guiding students through complex math problems.
- 🧑🏫 Educational AI Assistants: Answering questions based on educational materials.
- 📊 Diagram-based Question Answering: Interpreting charts, graphs, and scientific diagrams.
- 🔍 Visual Reasoning & Explanation: Explaining logical steps from a visual prompt.
🧩 Downstream Use
This model serves as a strong foundation for:
- Interactive, offline-ready learning experiences for students in low-connectivity regions.
- Advanced multimodal AI systems for educational platforms.
- Domain-specific reasoning tools for science and engineering.
- Interactive learning applications in STEM fields.
⚠️ Bias, Risks, and Limitations
This model inherits limitations common to most LLMs and has specific risks related to its application:
- Hallucination: The model can generate incorrect or fabricated information.
- Prompt Sensitivity: The phrasing of a prompt can significantly affect the output quality.
- Inherited Biases: It may reflect biases present in the gemma-3n-E2B-it base model and the gek408 dataset.
- Risk of "Fluent Nonsense": In educational contexts, the model might generate explanations that sound logical and correct but contain subtle mathematical or scientific inaccuracies. Human verification is crucial for factual and educational use cases.
💡 Recommendations
Always critically evaluate the model's output before use in any real-world application. For educational purposes, outputs should be reviewed by a subject matter expert.
🚀 Getting Started
Because the model was trained with Unsloth, using Unsloth for inference is also recommended for maximum performance.
```python
from unsloth import FastModel
import torch
from transformers import TextStreamer
import gc

# Load the model and tokenizer with 4-bit quantization
model, tokenizer = FastModel.from_pretrained(
    model_name = "Argobell/gemma-3n-gek408-dpo",
    max_seq_length = 1024,  # Choose any for long context!
    load_in_4bit = True,    # 4-bit quantization to reduce memory
    # token = "hf_...",     # use one if using gated models
)

# Helper function for inference
def do_gemma_3n_inference(model, messages, max_new_tokens = 128):
    inputs = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt = True,  # Must add for generation
        tokenize = True,
        return_dict = True,
        return_tensors = "pt",
    ).to("cuda")
    _ = model.generate(
        **inputs,
        max_new_tokens = max_new_tokens,
        temperature = 1.0, top_p = 0.95, top_k = 64,
        streamer = TextStreamer(tokenizer, skip_prompt = True),
    )
    # Cleanup to reduce VRAM usage
    del inputs
    torch.cuda.empty_cache()
    gc.collect()

sloth_link = "https://files.worldwildlife.org/wwfcmsprod/images/Sloth_Sitting_iStock_3_12_2014/story_full_width/8l7pbjmj29_iStock_000011145477Large_mini__1_.jpg"

messages = [{
    "role": "user",
    "content": [
        { "type": "image", "image": sloth_link },
        { "type": "text",  "text": "Which films does this animal feature in?" }
    ]
}]

# You might have to wait 1 minute for Unsloth's auto compiler
do_gemma_3n_inference(model, messages, max_new_tokens = 256)
```
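The same helper also works for text-only prompts, which matches the math-tutoring use case described above; the prompt below is purely illustrative.

```python
# Text-only math tutoring example (illustrative prompt)
messages = [{
    "role": "user",
    "content": [
        { "type": "text", "text": "Explain, step by step, how to solve 3x + 7 = 22." }
    ]
}]
do_gemma_3n_inference(model, messages, max_new_tokens = 512)
```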
🛠️ Training Details
The training was conducted in two distinct phases, using a LoRA-based approach accelerated by Unsloth.
📚 Phase 1: Supervised Fine-Tuning (SFT)
- Goal: To teach the model the fundamental structure of responding to mathematical prompts.
- Dataset:
Argobell/gek408
- Key Hyperparameters: The following parameters were used to tune both the vision and language components of the model.
```bash
# SFT Stage Configuration
--max_seq_length 2048
--max_steps 320
--learning_rate 2e-4
--lr_scheduler_type "cosine"
--optim "adamw_torch_fused"

# LoRA Configuration
--tune_vision
--tune_language_layers
--tune_attention_modules
--tune_mlp_modules
--r 16
--alpha 16
--lora_dropout 0.05

# Batching & Memory
--per_device_train_batch_size 4
--per_device_eval_batch_size 4
--gradient_accumulation_steps 8
--gradient_checkpointing
```
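For orientation, these flags map roughly onto an Unsloth + TRL script. The sketch below is a minimal illustration under stated assumptions, not the exact training code: the dataset split name, output directory, and the omission of any vision data collator are assumptions, and newer TRL versions expect the tokenizer to be passed as `processing_class` rather than `tokenizer`.

```python
# Minimal SFT sketch (illustrative; assumes default data collation)
from unsloth import FastModel
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset

model, tokenizer = FastModel.from_pretrained(
    model_name = "google/gemma-3n-E2B-it",
    max_seq_length = 2048,
    load_in_4bit = True,
)

# LoRA adapters on vision and language layers (attention + MLP), r = 16, alpha = 16
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = True,
    finetune_language_layers   = True,
    finetune_attention_modules = True,
    finetune_mlp_modules       = True,
    r = 16,
    lora_alpha = 16,
    lora_dropout = 0.05,
)

dataset = load_dataset("Argobell/gek408", split = "train")  # split name is an assumption

trainer = SFTTrainer(
    model = model,
    tokenizer = tokenizer,  # use processing_class= on newer TRL versions
    train_dataset = dataset,
    args = SFTConfig(
        max_seq_length = 2048,
        max_steps = 320,
        learning_rate = 2e-4,
        lr_scheduler_type = "cosine",
        optim = "adamw_torch_fused",
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 8,
        gradient_checkpointing = True,
        output_dir = "outputs/sft",  # assumed output path
    ),
)
trainer.train()
```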
📚 Phase 2: Direct Preference Optimization (DPO)
- Goal: To refine the SFT model by training it to prefer helpful, accurate responses over less desirable ones.
- Dataset:
Argobell/gek408-dpo
- Key Hyperparameters: Starting from the SFT-tuned model, DPO training was performed with the following settings.
```bash
# DPO Stage Configuration
--max_seq_length 2048
--max_prompt_length 1024
--max_steps 100
--learning_rate 5e-6
--optim "adamw_torch_fused"
--warmup_ratio 0.1
--weight_decay 0.01

# LoRA Configuration
--tune_vision
--tune_language_layers
--tune_attention_modules
--tune_mlp_modules
--r 4
--alpha 4
--lora_dropout 0.1

# Batching & Memory
--per_device_train_batch_size 2
--per_device_eval_batch_size 2
--gradient_accumulation_steps 4
--gradient_checkpointing
```
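The DPO stage can be outlined the same way with TRL's DPOTrainer. This is again a hedged sketch rather than the exact script: it assumes the gek408-dpo dataset exposes standard prompt/chosen/rejected columns, treats the preference pairs as text, and resumes from an assumed SFT checkpoint path.

```python
# Minimal DPO sketch (illustrative; assumes a prompt/chosen/rejected preference dataset)
from unsloth import FastModel, PatchDPOTrainer
from trl import DPOTrainer, DPOConfig
from datasets import load_dataset

PatchDPOTrainer()  # Unsloth's memory-efficient DPO patch

# Resume from the SFT-tuned checkpoint (path is an assumption)
model, tokenizer = FastModel.from_pretrained(
    model_name = "outputs/sft",
    max_seq_length = 2048,
    load_in_4bit = True,
)

# Smaller LoRA adapters for the preference-tuning stage, r = 4, alpha = 4
model = FastModel.get_peft_model(
    model,
    finetune_vision_layers     = True,
    finetune_language_layers   = True,
    finetune_attention_modules = True,
    finetune_mlp_modules       = True,
    r = 4,
    lora_alpha = 4,
    lora_dropout = 0.1,
)

dataset = load_dataset("Argobell/gek408-dpo", split = "train")  # split name is an assumption

trainer = DPOTrainer(
    model = model,
    ref_model = None,        # reference model comes from disabling the LoRA adapters
    tokenizer = tokenizer,   # use processing_class= on newer TRL versions
    train_dataset = dataset,
    args = DPOConfig(
        max_length = 2048,
        max_prompt_length = 1024,
        max_steps = 100,
        learning_rate = 5e-6,
        optim = "adamw_torch_fused",
        warmup_ratio = 0.1,
        weight_decay = 0.01,
        per_device_train_batch_size = 2,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,
        output_dir = "outputs/dpo",  # assumed output path
    ),
)
trainer.train()
```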
💻 Infrastructure & Software
- Hardware: 1× NVIDIA RTX 5880 Ada Generation
- Key Software:
- Unsloth: Used for 2-3x faster training and ~60% less memory usage, enabling more extensive experimentation.
- Hugging Face TRL: For implementing the SFT and DPO training loops.
- Hugging Face Transformers & Datasets.
🧰 Technical Specifications
Architecture
Gemma-3n utilizes a Matryoshka Transformer (MatFormer) architecture, which nests smaller, self-contained models within a larger one.
🙏 Acknowledgements
This work would not have been possible without the foundational models and libraries developed by the open-source community. We would like to extend our gratitude to:
- Google: For developing and releasing the powerful gemma-3n-E2B-it base model.
- The Unsloth AI team: For creating the Unsloth library, which was instrumental in accelerating the training process and reducing computational costs.
- Hugging Face: For providing the transformers, datasets, and TRL libraries that formed the backbone of our training and experimentation pipeline.
📖 Citation
If you use this model in your work, please cite it as follows:
@misc{gemma3ngek408dpo,
author = {Argobell},
title = {gemma-3n-gek408-dpo},
howpublished = {\url{https://huggingface.co/Argobell/gemma-3n-gek408-dpo}},
year = {2025}
}
👥 Model Card Authors
- Argobell
📬 Contact
For questions, feedback, or collaboration, please reach out via email: [email protected]