# DolphinGR00T-N1.5-3B-Zero / init_DolphinGR00T_zero.py
# Author: ehartford
# Last change: "Update init_DolphinGR00T_zero.py" (commit 58d30dd, verified)
#!/usr/bin/env python3
# ─────────────────── make local repo override any wheel ────────────────────
import os
import sys

# Put the current working directory first on the import path so a local
# checkout is imported in preference to an installed package of the same name.
sys.path.insert(0, os.path.abspath("."))
# ─────────────────── Flash-Attention + CUDA stubs ──────────────────────────
import types
import importlib.machinery as im
import torch
import torch.nn.functional as F

# Register an empty, importable "flash_attn" package so modules that import it
# load on machines without the real extension.
flash_pkg = types.ModuleType("flash_attn")
flash_pkg.__spec__ = im.ModuleSpec("flash_attn", loader=None, is_package=True)
flash_pkg.__path__ = []
sys.modules["flash_attn"] = flash_pkg

fa = types.ModuleType("flash_attn.flash_attn_interface")
fa.__spec__ = im.ModuleSpec("flash_attn.flash_attn_interface", loader=None)


def _seq_bounds(cu_seqlens, total):
    """Return cumulative sequence offsets as a list; ``None`` means one sequence."""
    if cu_seqlens is None:
        return [0, total]
    return [int(b) for b in torch.as_tensor(cu_seqlens).tolist()]


def _varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p, softmax_scale, causal):
    """SDPA-based attention over packed variable-length sequences.

    Uses the flash-attn varlen layout: ``q`` is (total_q, nheads, headdim),
    ``k``/``v`` are (total_k, nheads_k, headdim) and the cu_seqlens offsets
    mark where each sequence starts. Tokens attend only within their own
    sequence. Returns a (total_q, nheads, headdim) tensor.
    """
    nheads, headdim = q.shape[1], q.shape[2]
    if k.shape[1] != nheads:
        # MQA/GQA: query head h reads kv head h // (nheads // nheads_k).
        group = nheads // k.shape[1]
        k = k.repeat_interleave(group, dim=1)
        v = v.repeat_interleave(group, dim=1)
    if softmax_scale is not None:
        # SDPA scales logits by 1/sqrt(headdim); fold the requested scale into q.
        q = q * (softmax_scale * headdim ** 0.5)
    bq = _seq_bounds(cu_seqlens_q, q.shape[0])
    bk = _seq_bounds(cu_seqlens_k, k.shape[0])
    outs = []
    for q0, q1, k0, k1 in zip(bq, bq[1:], bk, bk[1:]):
        # SDPA wants (batch, heads, seq, dim); the packed layout is (seq, heads, dim).
        qi = q[q0:q1].transpose(0, 1).unsqueeze(0)
        ki = k[k0:k1].transpose(0, 1).unsqueeze(0)
        vi = v[k0:k1].transpose(0, 1).unsqueeze(0)
        mask = None
        if causal:
            # Bottom-right aligned causal mask (True = may attend); identical to
            # the usual lower-triangular mask when query/key lengths are equal.
            # NOTE(review): rows with no visible key (lq > lk) come out as NaN.
            lq, lk = q1 - q0, k1 - k0
            mask = torch.ones(lq, lk, dtype=torch.bool, device=q.device).tril(lk - lq)
        out = F.scaled_dot_product_attention(qi, ki, vi, attn_mask=mask, dropout_p=dropout_p)
        outs.append(out.squeeze(0).transpose(0, 1))
    return torch.cat(outs, dim=0) if outs else q[:0]


def _sdpa(qkv, cu_seqlens=None, max_seqlen=None, dropout_p=0.0,
          softmax_scale=None, causal=False, *_, **__):
    """Stand-in for flash_attn_{varlen,unpadded}_qkvpacked_func.

    ``qkv`` is (total, 3, nheads, headdim); returns (total, nheads, headdim).
    Extra flash-attn options (window_size, alibi_slopes, ...) are ignored.
    """
    q, k, v = qkv.unbind(1)
    return _varlen_attention(q, k, v, cu_seqlens, cu_seqlens, dropout_p, softmax_scale, causal)


def _sdpa_kv(q, kv, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=None,
             max_seqlen_k=None, dropout_p=0.0, softmax_scale=None, causal=False,
             *_, **__):
    """Stand-in for flash_attn_{varlen,unpadded}_kvpacked_func.

    ``q`` is (total_q, nheads, headdim) and ``kv`` is
    (total_k, 2, nheads_k, headdim); returns (total_q, nheads, headdim).
    """
    k, v = kv.unbind(1)
    return _varlen_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p, softmax_scale, causal)


for _name in ("flash_attn_unpadded_qkvpacked_func", "flash_attn_varlen_qkvpacked_func"):
    setattr(fa, _name, _sdpa)
for _name in ("flash_attn_unpadded_kvpacked_func", "flash_attn_varlen_kvpacked_func"):
    setattr(fa, _name, _sdpa_kv)
sys.modules["flash_attn.flash_attn_interface"] = fa
flash_pkg.flash_attn_interface = fa

# NOTE(review): these padding helpers are import-time placeholders only. They
# perform no (un)padding: pad_input returns (x, None) and unpad_input returns x.
# Confirm that no code path actually calls them.
pad = types.ModuleType("flash_attn.bert_padding")
pad.__spec__ = im.ModuleSpec("flash_attn.bert_padding", loader=None)
pad.pad_input = lambda x, *a, **k: (x, None)
pad.unpad_input = lambda x, *a, **k: x
sys.modules["flash_attn.bert_padding"] = pad
flash_pkg.bert_padding = pad
if not torch.cuda.is_available():
    # CPU-only host: pin these torch.cuda device queries to fixed "no GPU"
    # answers, presumably so code that probes the GPU while importing/building
    # the model does not fail -- TODO confirm which caller needs each one.
    _cpu_only_cuda = {
        "is_available": lambda: False,
        "get_device_capability": lambda dev=None: (0, 0),
        "current_device": lambda: 0,
        "get_device_properties": lambda dev=None: types.SimpleNamespace(major=0, minor=0),
    }
    for _attr, _stub in _cpu_only_cuda.items():
        setattr(torch.cuda, _attr, _stub)
import importlib.metadata as _im
if "flash_attn" not in _im.packages_distributions():
rv, rd = _im.version, _im.distribution
_im.version = lambda p:"0.0.0" if p=="flash_attn" else rv(p)
_im.distribution = lambda p:types.SimpleNamespace(version="0.0.0") if p=="flash_attn" else rd(p)
# ─────────────────── std imports ───────────────────────────────────────────
from pathlib import Path
import argparse, json, shutil
from huggingface_hub import hf_hub_download
from transformers import AutoConfig
from gr00t.model.gr00t_n1 import GR00T_N1_5
# ─────────────────── helpers ───────────────────────────────────────────────
def patched_cfg(repo_id="nvidia/GR00T-N1.5-3B"):
    """Return a path to ``repo_id``'s config.json with ``model_type`` "gr00t_n1_5".

    Downloads ``config.json`` from the Hub. If its ``model_type`` already is
    "gr00t_n1_5", the downloaded path is returned unchanged. Otherwise a patched
    copy is written next to it as ``config_patched.json`` and that path is
    returned as a string. Presumably the rename lets AutoConfig resolve the
    config class registered for "gr00t_n1_5" -- confirm.

    Args:
        repo_id: Hub repository providing the reference config.
    """
    p = hf_hub_download(repo_id, "config.json")
    # Close the file deterministically instead of leaking the handle.
    with open(p, encoding="utf-8") as fh:
        d = json.load(fh)
    if d.get("model_type") == "gr00t_n1_5":
        return p
    d["model_type"] = "gr00t_n1_5"
    patched = Path(p).with_name("config_patched.json")
    patched.write_text(json.dumps(d), encoding="utf-8")
    return str(patched)
def build_blank():
    """Instantiate ``GR00T_N1_5`` with random weights from the upstream config.

    The config is loaded from ``patched_cfg()`` (local files only). In the
    backbone config, ``tune_llm`` is set to True and the ``checkpoint_path`` and
    ``use_pretrained`` keys are dropped. ``checkpoint_path`` is also dropped
    from the action-head config. The torch RNG is seeded with 0 right before
    construction, so the random initialisation is repeatable.
    """
    config_path = patched_cfg()
    cfg = AutoConfig.from_pretrained(
        config_path, trust_remote_code=True, local_files_only=True
    )
    backbone_cfg = cfg.backbone_cfg
    backbone_cfg.update({"tune_llm": True})  # enable L-tower
    for stale_key in ("checkpoint_path", "use_pretrained"):
        backbone_cfg.pop(stale_key, None)
    cfg.action_head_cfg.pop("checkpoint_path", None)
    torch.manual_seed(0)
    return GR00T_N1_5(cfg, local_model_path="")  # random weights
def maybe_add_lm_head(model):
    """Attach a freshly initialised ``lm_head`` to the backbone's language model.

    The head is always (re)created, even when an ``lm_head`` attribute already
    exists. It is sized from ``embed_tokens`` as
    ``Linear(hidden_size, vocab_size, bias=False)``, has weights drawn from
    N(0, 0.02), and is stored in bfloat16. It is a new parameter, separate from
    ``embed_tokens``.

    Args:
        model: object exposing ``backbone.eagle_model.language_model`` whose
            ``model.embed_tokens`` is an ``nn.Embedding``-like module.
    """
    lm = model.backbone.eagle_model.language_model
    embed_tokens = lm.model.embed_tokens
    vocab_size = embed_tokens.num_embeddings
    hidden_size = embed_tokens.embedding_dim
    print(f"Embedding dimensions: vocab_size={vocab_size}, hidden_size={hidden_size}")
    # Expected shape based on architecture: [151680, 2048]
    if vocab_size != 151680 or hidden_size != 2048:
        print("⚠️ Warning: Unexpected dimensions. Expected vocab=151680, hidden=2048")
    if hasattr(lm, "lm_head"):
        print(f"lm_head attribute exists: {lm.lm_head is not None}")
        # An existing head may not be properly initialised; replace it anyway.
        print("Creating new lm_head with proper initialization...")
    else:
        print("lm_head attribute missing, creating...")
    # nn.Linear takes (in_features, out_features), so the weight is
    # (vocab_size, hidden_size). Build it on the embedding's device.
    new_lm_head = torch.nn.Linear(
        hidden_size, vocab_size, bias=False, device=embed_tokens.weight.device
    )
    # Draw N(0, 0.02) weights in float32, then cast to bfloat16 to match the backbone.
    torch.nn.init.normal_(new_lm_head.weight, mean=0.0, std=0.02)
    new_lm_head.to(torch.bfloat16)
    lm.lm_head = new_lm_head
    weight = lm.lm_head.weight
    print(f"✓ Created lm_head: Linear({hidden_size}, {vocab_size}, bias=False)")
    print(f" Weight shape: {weight.shape}")
    print(f" Weight dtype: {weight.dtype}")
    print(f" Parameters: {weight.numel() / 1e6:.1f}M")
def set_mixed(model):
    """Cast parameters in place to the mixed-precision layout.

    Parameters whose name starts with ``backbone.`` or contains ``lm_head``
    become bfloat16; every other parameter (e.g. the action head) becomes
    float32.
    """
    for param_name, tensor in model.named_parameters():
        wants_bf16 = param_name.startswith("backbone.") or "lm_head" in param_name
        target_dtype = torch.bfloat16 if wants_bf16 else torch.float32
        tensor.data = tensor.data.to(target_dtype)
def copy_tokenizer(out, repo_id="nvidia/GR00T-N1.5-3B"):
    """Best-effort copy of tokenizer files from ``repo_id`` into ``out``.

    A file that cannot be fetched or copied (e.g. it does not exist in the
    repo) is skipped with a short notice instead of aborting the export.

    Args:
        out: destination directory (``Path`` or ``str``); it is expected to exist.
        repo_id: Hub repository to download the files from.
    """
    out = Path(out)
    for f in ("tokenizer.json", "tokenizer_config.json", "vocab.txt", "special_tokens_map.json"):
        try:
            shutil.copy(hf_hub_download(repo_id, f), out / f)
        except Exception as e:  # best-effort: a missing file must not abort the export
            print(f"  (skipped {f}: {type(e).__name__})")
def diagnose_model(model):
    """Print parameter totals and the state of the language model's ``lm_head``.

    Reports the total parameter count and every registered parameter whose name
    contains "lm_head". Separately, it prints the size, shape and dtype of
    ``backbone.eagle_model.language_model.lm_head.weight`` whenever that
    attribute exists. This makes a head that exists but is not registered as a
    parameter visible too.

    Args:
        model: torch module exposing ``backbone.eagle_model.language_model``.
    """
    print("\nModel diagnostics:")
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Total params: {total_params/1e6:,.0f}M")
    # Check for key components
    lm_head_named = [(name, p) for name, p in model.named_parameters() if "lm_head" in name]
    print(f" Has lm_head: {'✓' if lm_head_named else '✗'}")
    if lm_head_named:
        lm_head_params = sum(p.numel() for _, p in lm_head_named)
        print(f" lm_head params: {lm_head_params/1e6:,.0f}M")
        print(f" lm_head location: {', '.join(name for name, _ in lm_head_named)}")
    # Check if the params are actually counted in the total
    lm = model.backbone.eagle_model.language_model
    lm_head_weight = getattr(getattr(lm, "lm_head", None), "weight", None)
    if lm_head_weight is not None:
        print(f" lm_head actual params: {lm_head_weight.numel()/1e6:,.0f}M")
        print(f" lm_head weight shape: {lm_head_weight.shape}")
        print(f" lm_head weight dtype: {lm_head_weight.dtype}")
# Key parameters of the GR00T-N1.5-3B layout and their expected shapes, keyed by
# the full names reported by ``named_parameters()`` (with ``.weight``/``.bias``).
_EXPECTED_SHAPES = {
    "backbone.eagle_model.language_model.lm_head.weight": (151680, 2048),
    "backbone.eagle_model.language_model.model.embed_tokens.weight": (151680, 2048),
    "backbone.eagle_model.language_model.model.norm.weight": (2048,),
    "backbone.eagle_model.mlp1.0.weight": (2048, 1152),
    "backbone.eagle_model.mlp1.0.bias": (2048,),
    "action_head.position_embedding.weight": (1024, 1536),
    "action_head.vlln.weight": (2048,),
    "action_head.vlln.bias": (2048,),
}
# Expected total parameter count in millions, based on the NVIDIA model.
_EXPECTED_TOTAL_M = 2724
# Number of warnings printed before the rest are summarised as "... and N more".
_MAX_SHOWN_WARNINGS = 5


def _check_shapes(param_dict):
    """Compare key parameter shapes with ``_EXPECTED_SHAPES``.

    Prints a ✓ line per match and returns one error string per missing or
    mis-shaped parameter.
    """
    errors = []
    for name, expected_shape in _EXPECTED_SHAPES.items():
        if name not in param_dict:
            errors.append(f"Missing parameter: {name}")
            continue
        actual_shape = tuple(param_dict[name].shape)
        if actual_shape != expected_shape:
            errors.append(f"Shape mismatch for {name}: expected {expected_shape}, got {actual_shape}")
        else:
            print(f"✓ {name}: {actual_shape}")
    return errors


def _check_dtypes(param_dict):
    """Return a message per ``backbone.*`` param not in bf16 / ``action_head.*`` param not in fp32."""
    issues = []
    for name, param in param_dict.items():
        if name.startswith("backbone.") and param.dtype != torch.bfloat16:
            issues.append(f"{name}: expected bfloat16, got {param.dtype}")
        elif name.startswith("action_head.") and param.dtype != torch.float32:
            issues.append(f"{name}: expected float32, got {param.dtype}")
    return issues


def _count_by_component(param_dict):
    """Return parameter counts summed under "backbone", "action_head" and "other"."""
    counts = {"backbone": 0, "action_head": 0, "other": 0}
    for name, param in param_dict.items():
        if name.startswith("backbone."):
            key = "backbone"
        elif name.startswith("action_head."):
            key = "action_head"
        else:
            key = "other"
        counts[key] += param.numel()
    return counts


def validate_model_architecture(model):
    """Check ``model`` against the GR00T-N1.5-3B layout and print a report.

    The report covers:
    - the key parameter shapes in ``_EXPECTED_SHAPES``;
    - the dtype policy (backbone bf16, action head fp32);
    - per-component parameter counts;
    - the lm_head size;
    - the total parameter count compared with ``_EXPECTED_TOTAL_M``.

    Returns:
        True when every key parameter is present with the expected shape.
        Dtype warnings and total-count differences are reported but do not
        change the return value. They do affect the printed verdict.
    """
    import math

    print("\n" + "="*60)
    print("ARCHITECTURE VALIDATION")
    print("="*60)
    param_dict = dict(model.named_parameters())

    # Debug: print actual action_head parameter names to see the pattern
    position_names = [name for name in param_dict if name.startswith("action_head.position")]
    if position_names:
        print("\nFound position embedding parameters:")
        for name in position_names[:5]:
            print(f" {name}: {param_dict[name].shape}")

    errors = _check_shapes(param_dict)
    # Keep every dtype issue (display is capped below) so the counts are accurate.
    warning_msgs = _check_dtypes(param_dict)
    component_params = _count_by_component(param_dict)
    lm_head_sizes = [p.numel() for name, p in param_dict.items() if "lm_head" in name]
    lm_head_found = bool(lm_head_sizes)
    lm_head_params = sum(lm_head_sizes)

    print("\nValidation Results:")
    print(f" Errors: {len(errors)}")
    print(f" Warnings: {len(warning_msgs)}")
    if errors:
        print("\n❌ ERRORS:")
        for error in errors:
            print(f" - {error}")
    if warning_msgs:
        print(f"\n⚠️ WARNINGS (showing first {_MAX_SHOWN_WARNINGS}):")
        for warning in warning_msgs[:_MAX_SHOWN_WARNINGS]:
            print(f" - {warning}")
        if len(warning_msgs) > _MAX_SHOWN_WARNINGS:
            print(f" ... and {len(warning_msgs) - _MAX_SHOWN_WARNINGS} more")

    print("\n📊 Parameter Summary:")
    total = sum(component_params.values())
    print(f" Total: {total/1e6:,.1f}M")
    print(f" Backbone: {component_params['backbone']/1e6:,.1f}M")
    print(f" Action Head: {component_params['action_head']/1e6:,.1f}M")
    if component_params['other'] > 0:
        print(f" Other: {component_params['other']/1e6:,.1f}M")
    print(f"\n lm_head found: {'✓' if lm_head_found else '✗'}")
    if lm_head_found:
        # Derive the expectation from the spec (151680 x 2048 = 310.6M).
        expected_lm_head = math.prod(_EXPECTED_SHAPES["backbone.eagle_model.language_model.lm_head.weight"])
        print(f" lm_head params: {lm_head_params/1e6:.1f}M (expected: {expected_lm_head/1e6:.1f}M)")

    actual_total = total / 1e6
    diff = actual_total - _EXPECTED_TOTAL_M
    print(f"\n Expected total: {_EXPECTED_TOTAL_M}M")
    print(f" Actual total: {actual_total:.1f}M")
    print(f" Difference: {diff:+.1f}M")
    # Only claim a match when the shape checks passed too, so the printed
    # verdict agrees with the return value.
    if not errors and abs(diff) < 1:  # within 1M params
        print("\n✅ Model architecture matches expected specification!")
    else:
        print("\n❌ Model architecture does NOT match specification!")
    return not errors
# ─────────────────── main ──────────────────────────────────────────────────
def main(device: str, out_dir: str):
    """Build, check and save a randomly initialised GR00T-N1.5-3B checkpoint.

    Steps:
    1. Build the blank model and (re)create ``lm_head``, printing diagnostics
       before and after.
    2. Cast to the mixed bf16/fp32 layout and move to ``device``.
    3. Validate the architecture.
    4. Save the weights (2GB shards), the tokenizer files and a one-line
       README into ``out_dir``.

    Args:
        device: torch device string the model is moved to before saving.
        out_dir: output directory; ``~`` is expanded and it is created if needed.
    """
    print("="*60)
    print("Creating blank GR00T-N1.5-3B model")
    print("="*60)
    model = build_blank()

    print("\nBefore adding lm_head:")
    diagnose_model(model)
    maybe_add_lm_head(model)
    print("\nAfter adding lm_head:")
    diagnose_model(model)

    set_mixed(model)
    model = model.to(device)
    # Informational only: the model is saved even if validation reports errors.
    validate_model_architecture(model)

    out = Path(out_dir).expanduser()
    out.mkdir(parents=True, exist_ok=True)
    print(f"\nSaving model to {out}...")
    model.save_pretrained(out, max_shard_size="2GB")
    copy_tokenizer(out)
    (out / "README.md").write_text(
        "Random GR00T-N1.5-3B | backbone bf16 | action_head fp32 | Apache-2.0\n",
        encoding="utf-8",
    )

    total_params = sum(p.numel() for p in model.parameters())
    lm_head_params = model.backbone.eagle_model.language_model.lm_head.weight.numel()
    print("\n" + "="*60)
    print("FINAL SUMMARY")
    print("="*60)
    print(f"✅ Saved blank model ({total_params/1e6:,.0f}M params) → {out}")
    print(f"✅ Model has lm_head with {lm_head_params/1e6:.1f}M params")
    print("✅ Ready for training with Apache-2.0 license")
# ─────────────────── CLI ───────────────────────────────────────────────────
if __name__ == "__main__":
    # Command-line entry: both options have defaults, so a bare run writes a
    # CPU-built model into ./DolphinGR00T-N1.5-3B-Zero.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--out_dir", default="DolphinGR00T-N1.5-3B-Zero")
    cli = parser.parse_args()
    main(cli.device, cli.out_dir)