---
license: mit
datasets:
- pfb30/multi_woz_v22
language:
- en
pipeline_tag: text-generation
library_name: transformers
tags:
- Sam-2
- text-generation
---

# 🧠 Model Card: Sam‑2.0

## 📌 Model Overview

**Sam‑2.0** is a minimal, modular, decoder‑only Transformer architecture designed for chat‑style reasoning tasks. It emphasizes reproducibility, ablation‑friendly design, and clean benchmarking across input modalities.

- **Architecture**: Decoder‑only Transformer with RMSNorm, SwiGLU feed‑forward, and causal masking
- **Training Objective**: Causal language modeling (CLM) with role‑based label masking
- **Checkpoint**: `sam2-epoch35.safetensors`
- **Final Train Loss**: 1.04
- **Validation Loss**: Not tracked in this run
- **Training Duration**: ~6272 s over 35 epochs
- **Framework**: PyTorch + Hugging Face Transformers (custom model class)

## 🧱 Model Architecture

| Component | Description |
|-------------------|--------------------------------------------------------------|
| Backbone | Decoder‑only Transformer stack |
| Normalization | RMSNorm |
| Attention | Multi‑head self‑attention (causal) |
| Feed‑Forward | SwiGLU activation with dropout |
| Positional Bias | Learned absolute positions (no RoPE in this minimal variant) |
| Head | Tied‑embedding LM head |
| Checkpoint Format | `safetensors` with metadata for reproducibility |
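
For readers who want a concrete picture of the normalization and feed‑forward components listed above, the sketch below shows typical PyTorch implementations of RMSNorm and a SwiGLU block with dropout. This is an illustrative reimplementation, not the exact code in the `Sam2` model class; the layer names, hidden size, and dropout placement are assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class RMSNorm(nn.Module):
    """Root-mean-square normalization: rescale by the RMS of the features, no mean-centering."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        inv_rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
        return self.weight * x * inv_rms

class SwiGLU(nn.Module):
    """SwiGLU feed-forward: SiLU-gated linear unit, down projection, then dropout."""
    def __init__(self, dim: int, hidden_dim: int, dropout: float = 0.1):
        super().__init__()
        self.w_gate = nn.Linear(dim, hidden_dim, bias=False)
        self.w_up = nn.Linear(dim, hidden_dim, bias=False)
        self.w_down = nn.Linear(hidden_dim, dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.w_down(F.silu(self.w_gate(x)) * self.w_up(x)))
```

In a decoder block these would typically wrap the attention and feed‑forward sublayers inside residual connections.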

## 🧪 Training Details

- **Dataset**: [pfb30/multi_woz_v22](https://huggingface.co/datasets/pfb30/multi_woz_v22)
- **Batch Size**: 8
- **Optimizer**: AdamW
- **Learning Rate**: 2 × 10⁻⁴ (constant in this run)
- **Loss Function**: Cross‑entropy over assistant tokens only (see the sketch below)
- **Hardware**: Kaggle GPU runtime
- **Logging**: Step‑wise loss tracking, no validation during training
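
The "cross‑entropy over assistant tokens only" objective is the role‑based label masking mentioned in the overview: label positions belonging to user or system turns are set to the ignore index, so only assistant tokens contribute to the loss. Below is a minimal sketch of how this is commonly implemented; the function name and the boolean `assistant_mask` are illustrative, not part of the released code.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # positions with this label are skipped by cross-entropy

def masked_clm_loss(logits, input_ids, assistant_mask):
    """Causal LM loss computed only over assistant tokens.

    logits:         (batch, seq, vocab) model outputs
    input_ids:      (batch, seq) token ids
    assistant_mask: (batch, seq) bool, True where the token belongs to an assistant turn
    """
    labels = input_ids.clone()
    labels[~assistant_mask] = IGNORE_INDEX           # mask user/system tokens
    # Standard causal shift: predict token t+1 from positions <= t.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=IGNORE_INDEX,
    )
```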

## 📊 Evaluation

| Metric | Value | Notes |
|------------------|-------|---------------------------------|
| Final Train Loss | 1.04 | Achieved at epoch 35/35 |
| Validation Loss | — | Not tracked in this run |
| Inference Speed | Fast | Lightweight architecture |
| Generalisation | TBD | To be compared against Sam‑2.5 |

## 🔧 Intended Use

- **Research**: Benchmarking modular architectures and ablation studies
- **Education**: Reasoning scaffolds and logic quizzes
- **Deployment**: Lightweight agents for chat and dialogue modeling

## 🚫 Limitations

- No validation tracking — generalisation must be inferred via external evaluation harnesses
- Trained on MultiWOZ v2.2 only — may not generalize to other domains without fine‑tuning
- Minimal architecture — no RoPE or multi‑query attention (MQA) in this variant

## 📁 Files

- `sam2-epoch35.safetensors` — final checkpoint
- `config.json` — architecture and training config
- `tokenizer.json` — tokenizer with special tokens
- `README.md` — training logs and setup instructions
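
Since the checkpoint is stored as `safetensors` with embedded metadata, you can inspect it without instantiating the model. A small sketch using the `safetensors` package (which metadata keys are present depends on what was written at save time):

```python
from safetensors import safe_open

# Lazily open the checkpoint: read metadata and tensor names without loading weights.
with safe_open("sam2-epoch35.safetensors", framework="pt", device="cpu") as f:
    print(f.metadata())           # reproducibility metadata, if any was stored
    print(list(f.keys())[:10])    # first few tensor names in the state dict
```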

## 🧩 How to Load

```python
import json

import torch
from transformers import AutoTokenizer
from safetensors.torch import load_file

from sam2 import Sam2, Sam2Config  # your custom model class

# Tokenizer and config
tok = AutoTokenizer.from_pretrained("Smilyai-labs/Sam-2.0")
with open("config.json") as f:
    cfg = Sam2Config(**json.load(f))

# Build the model and load the safetensors checkpoint
model = Sam2(cfg)
state = load_file("sam2-epoch35.safetensors")
model.load_state_dict(state)
model.eval()

# Greedy decoding loop
prompt = "<|user|> Hello! <|eot|>\n<|assistant|>"
ids = tok.encode(prompt, return_tensors="pt")
with torch.no_grad():
    for _ in range(50):
        logits = model(ids)
        next_id = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        ids = torch.cat([ids, next_id], dim=1)
        if next_id.item() == tok.eos_token_id:
            break

print(tok.decode(ids[0]))
```
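
The loop above uses plain greedy decoding (argmax at every step), which is deterministic but can be repetitive in chat settings. For more varied responses, the argmax line can be swapped for temperature sampling; a minimal variant (the temperature of 0.8 is an arbitrary choice):

```python
# Sample from the temperature-scaled distribution instead of taking the argmax.
probs = torch.softmax(logits[:, -1, :] / 0.8, dim=-1)
next_id = torch.multinomial(probs, num_samples=1)
```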