|
|
|
|
|
import sys, os; sys.path.insert(0, os.path.abspath(".")) |
|
|
|
|
|
import types, torch, torch.nn.functional as F, importlib.machinery as im |
|
# --- flash_attn stand-in modules ---------------------------------------------
# Register minimal `flash_attn`, `flash_attn.flash_attn_interface` and
# `flash_attn.bert_padding` modules in sys.modules (presumably so the `gr00t`
# import further down succeeds without the real package). The attention entry
# points are emulated with torch's scaled_dot_product_attention.
flash_pkg = types.ModuleType("flash_attn")
flash_pkg.__spec__ = im.ModuleSpec("flash_attn", loader=None, is_package=True)
flash_pkg.__path__ = []
sys.modules["flash_attn"] = flash_pkg

fa = types.ModuleType("flash_attn.flash_attn_interface")
fa.__spec__ = im.ModuleSpec("flash_attn.flash_attn_interface", loader=None)


def _attend(q, k, v, cu_seqlens_q, cu_seqlens_k, causal, softmax_scale):
    """Attention over flash-attn style packed tensors using PyTorch SDPA.

    q is (total_q, nheads, headdim); k and v are (total_k, nheads_k, headdim).
    Each packed sequence delimited by ``cu_seqlens_*`` (cumulative offsets, or
    ``None`` for a single sequence) is attended independently. Returns a
    (total_q, nheads, headdim) tensor.
    """
    if k.shape[1] != q.shape[1]:
        # GQA/MQA: query head i reads kv head i // (nheads // nheads_k).
        rep = q.shape[1] // k.shape[1]
        k = k.repeat_interleave(rep, dim=1)
        v = v.repeat_interleave(rep, dim=1)
    bq = [0, q.shape[0]] if cu_seqlens_q is None else [int(x) for x in cu_seqlens_q]
    bk = [0, k.shape[0]] if cu_seqlens_k is None else [int(x) for x in cu_seqlens_k]
    # Only forward `scale` when given, so older torch without that kwarg still works.
    extra = {} if softmax_scale is None else {"scale": softmax_scale}
    outs = []
    for q0, q1, k0, k1 in zip(bq, bq[1:], bk, bk[1:]):
        # SDPA expects (batch, heads, seq, dim); the packed layout is (seq, heads, dim).
        qs = q[q0:q1].transpose(0, 1).unsqueeze(0)
        ks = k[k0:k1].transpose(0, 1).unsqueeze(0)
        vs = v[k0:k1].transpose(0, 1).unsqueeze(0)
        lq, lk = q1 - q0, k1 - k0
        mask = None
        if causal and lq != lk:
            # Bottom-right aligned causal mask (query i sees keys j <= i + lk - lq).
            rows = torch.arange(lq, device=q.device).unsqueeze(1)
            cols = torch.arange(lk, device=q.device).unsqueeze(0)
            mask = cols <= rows + (lk - lq)
        o = F.scaled_dot_product_attention(
            qs, ks, vs, attn_mask=mask, is_causal=causal and mask is None, **extra
        )
        outs.append(o.squeeze(0).transpose(0, 1))
    if not outs:
        return q.new_zeros((q.shape[0], q.shape[1], v.shape[-1]))
    return torch.cat(outs, dim=0)


def _qkvpacked(qkv, cu_seqlens=None, max_seqlen=None, dropout_p=0.0,
               softmax_scale=None, causal=False, *_, **__):
    """Stand-in for flash_attn_{varlen,unpadded}_qkvpacked_func.

    ``qkv`` is (total, 3, nheads, headdim). ``dropout_p`` and any extra
    arguments are accepted but ignored.
    """
    q, k, v = qkv.unbind(1)
    return _attend(q, k, v, cu_seqlens, cu_seqlens, causal, softmax_scale)


def _kvpacked(q, kv, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=None,
              max_seqlen_k=None, dropout_p=0.0, softmax_scale=None, causal=False,
              *_, **__):
    """Stand-in for flash_attn_{varlen,unpadded}_kvpacked_func.

    ``q`` is (total_q, nheads, headdim), ``kv`` is (total_k, 2, nheads_k, headdim).
    ``dropout_p`` and any extra arguments are accepted but ignored.
    """
    k, v = kv.unbind(1)
    return _attend(q, k, v, cu_seqlens_q, cu_seqlens_k, causal, softmax_scale)


_sdpa = _qkvpacked  # keep the original shim name available

for _name in ("flash_attn_unpadded_qkvpacked_func", "flash_attn_varlen_qkvpacked_func"):
    setattr(fa, _name, _qkvpacked)
for _name in ("flash_attn_unpadded_kvpacked_func", "flash_attn_varlen_kvpacked_func"):
    setattr(fa, _name, _kvpacked)
sys.modules["flash_attn.flash_attn_interface"] = fa
flash_pkg.flash_attn_interface = fa

pad = types.ModuleType("flash_attn.bert_padding")
pad.__spec__ = im.ModuleSpec("flash_attn.bert_padding", loader=None)


def _unpad_input(hidden_states, attention_mask, *_, **__):
    """Drop masked-out tokens: (batch, seqlen, ...) -> (total_tokens, ...).

    Returns (unpadded, indices, cu_seqlens, max_seqlen, seqlens). The 5-tuple
    follows newer flash-attn releases — NOTE(review): confirm callers expect 5
    values (older flash-attn returned 4).
    """
    seqlens = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    cu_seqlens = F.pad(torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0))
    flat = hidden_states.reshape(-1, *hidden_states.shape[2:])
    return flat[indices], indices, cu_seqlens, int(seqlens.max().item()), seqlens


def _pad_input(hidden_states, indices, batch, seqlen, *_, **__):
    """Scatter (total_tokens, ...) back to a zero-padded (batch, seqlen, ...) tensor."""
    out = hidden_states.new_zeros((batch * seqlen, *hidden_states.shape[1:]))
    out[indices] = hidden_states
    return out.view(batch, seqlen, *hidden_states.shape[1:])


pad.unpad_input = _unpad_input
pad.pad_input = _pad_input
pad.index_first_axis = lambda x, indices: x[indices]
sys.modules["flash_attn.bert_padding"] = pad
flash_pkg.bert_padding = pad
|
|
|
# CPU-only host: swap the CUDA device-query helpers for fixed placeholder
# answers so later callers get (0, 0) / device 0 instead of touching a
# CUDA runtime that is not there.
if not torch.cuda.is_available():
    _cpu_stubs = {
        "is_available": lambda: False,
        "get_device_capability": lambda dev=None: (0, 0),
        "current_device": lambda: 0,
        "get_device_properties": lambda dev=None: types.SimpleNamespace(major=0, minor=0),
    }
    for _attr, _stub in _cpu_stubs.items():
        setattr(torch.cuda, _attr, _stub)
    del _attr, _stub, _cpu_stubs
|
|
|
import importlib.metadata as _im |
|
if "flash_attn" not in _im.packages_distributions(): |
|
rv, rd = _im.version, _im.distribution |
|
_im.version = lambda p:"0.0.0" if p=="flash_attn" else rv(p) |
|
_im.distribution = lambda p:types.SimpleNamespace(version="0.0.0") if p=="flash_attn" else rd(p) |
|
|
|
|
|
from pathlib import Path |
|
import argparse, json, shutil |
|
from huggingface_hub import hf_hub_download |
|
from transformers import AutoConfig |
|
from gr00t.model.gr00t_n1 import GR00T_N1_5 |
|
|
|
|
|
def patched_cfg(repo_id="nvidia/GR00T-N1.5-3B"):
    """Return a path to a ``config.json`` whose ``model_type`` is ``gr00t_n1_5``.

    Downloads ``config.json`` from *repo_id* with ``hf_hub_download``. If its
    ``model_type`` already equals ``gr00t_n1_5`` the downloaded path is returned
    unchanged. Otherwise a corrected copy is written next to it as
    ``config_patched.json`` and that path is returned. The rewrite is presumably
    there so AutoConfig resolves the class registered under ``gr00t_n1_5``.

    Args:
        repo_id: Hugging Face repo to read the config from.

    Returns:
        str: path of the config file to load.
    """
    p = hf_hub_download(repo_id, "config.json")
    with open(p, encoding="utf-8") as fh:  # close the handle deterministically
        d = json.load(fh)
    if d.get("model_type") == "gr00t_n1_5":
        return p
    d["model_type"] = "gr00t_n1_5"
    patched = Path(p).with_name("config_patched.json")
    patched.write_text(json.dumps(d), encoding="utf-8")
    return str(patched)
|
|
|
def build_blank():
    """Construct a randomly initialised GR00T_N1_5 from the patched hub config.

    Before construction the backbone sub-config gets ``tune_llm=True``. Its
    ``checkpoint_path`` and ``use_pretrained`` entries are dropped, as is the
    action head's ``checkpoint_path`` (presumably so no pretrained weights are
    pulled in). torch's global RNG is seeded with 0 right before the model is
    created.
    """
    config_path = patched_cfg()
    cfg = AutoConfig.from_pretrained(
        config_path,
        trust_remote_code=True,
        local_files_only=True,
    )
    backbone_cfg = cfg.backbone_cfg
    backbone_cfg.update({"tune_llm": True})
    for key in ("checkpoint_path", "use_pretrained"):
        backbone_cfg.pop(key, None)
    cfg.action_head_cfg.pop("checkpoint_path", None)
    torch.manual_seed(0)
    return GR00T_N1_5(cfg, local_model_path="")
|
|
|
def maybe_add_lm_head(model):
    """Install a freshly initialised ``lm_head`` on the backbone's language model.

    The head is a bias-free ``Linear(hidden_size, vocab_size)``, sized from
    ``language_model.model.embed_tokens``. Its weights are drawn from
    N(0, 0.02) in float32, then cast to bfloat16 and placed on the embedding's
    device. Any existing ``lm_head`` attribute is replaced. Progress is
    reported with ``print``.

    Args:
        model: object exposing ``backbone.eagle_model.language_model`` (e.g. the
            model returned by ``build_blank``).
    """
    lm = model.backbone.eagle_model.language_model
    embed_tokens = lm.model.embed_tokens
    vocab_size = embed_tokens.num_embeddings
    hidden_size = embed_tokens.embedding_dim

    print(f"Embedding dimensions: vocab_size={vocab_size}, hidden_size={hidden_size}")

    if vocab_size != 151680 or hidden_size != 2048:
        # Only a warning: the head is still built from the actual embedding size.
        print("⚠️ Warning: Unexpected dimensions. Expected vocab=151680, hidden=2048")

    if hasattr(lm, "lm_head"):
        print(f"lm_head attribute exists: {lm.lm_head is not None}")
        print("Creating new lm_head with proper initialization...")
    else:
        print("lm_head attribute missing, creating...")

    new_lm_head = torch.nn.Linear(hidden_size, vocab_size, bias=False)
    torch.nn.init.normal_(new_lm_head.weight, mean=0.0, std=0.02)
    # Draw in float32 and then round to bfloat16. Put the head next to the
    # embeddings so it also works if the model was already moved to a device.
    new_lm_head = new_lm_head.to(device=embed_tokens.weight.device, dtype=torch.bfloat16)
    lm.lm_head = new_lm_head

    print(f"✓ Created lm_head: Linear({hidden_size}, {vocab_size}, bias=False)")
    print(f" Weight shape: {lm.lm_head.weight.shape}")
    print(f" Weight dtype: {lm.lm_head.weight.dtype}")
    print(f" Parameters: {lm.lm_head.weight.numel() / 1e6:.1f}M")
|
|
|
def set_mixed(model):
    """Cast every parameter of *model* in place to its mixed-precision dtype.

    A parameter whose name starts with ``backbone.`` or contains ``lm_head``
    becomes bfloat16; every other parameter becomes float32. Only parameters
    are visited (``named_parameters``); buffers keep their dtype.
    """
    for name, param in model.named_parameters():
        wants_bf16 = name.startswith("backbone.") or "lm_head" in name
        target_dtype = torch.bfloat16 if wants_bf16 else torch.float32
        param.data = param.data.to(target_dtype)
|
|
|
def copy_tokenizer(out, repo_id="nvidia/GR00T-N1.5-3B"):
    """Best-effort copy of tokenizer files from *repo_id* into directory *out*.

    Each candidate file is downloaded with ``hf_hub_download`` and copied into
    *out*. A file that fails to download or copy is skipped with a printed
    notice rather than aborting the export, since not every candidate may exist
    in the repo.

    Args:
        out: destination directory (``str`` or ``Path``); it must already exist.
        repo_id: Hugging Face repo to fetch the files from.

    Returns:
        list[str]: names of the files that were copied.
    """
    out = Path(out)
    copied = []
    for fname in ("tokenizer.json", "tokenizer_config.json", "vocab.txt", "special_tokens_map.json"):
        try:
            shutil.copy(hf_hub_download(repo_id, fname), out / fname)
        except Exception as exc:  # deliberate best-effort, but no longer silent
            print(f" Skipped {fname}: {type(exc).__name__}: {exc}")
            continue
        copied.append(fname)
    return copied
|
|
|
def diagnose_model(model):
    """Print a short diagnostic summary of *model*'s parameters.

    Reports the following:
    - the total parameter count;
    - whether any parameter name contains ``lm_head``, and if so their combined
      size and their names;
    - when ``backbone.eagle_model.language_model.lm_head`` exists and is not
      None, its weight's size, shape and dtype.
    """
    print("\nModel diagnostics:")
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Total params: {total_params/1e6:,.0f}M")

    lm_head_entries = [(name, p.numel()) for name, p in model.named_parameters() if "lm_head" in name]
    has_lm_head = bool(lm_head_entries)
    print(f" Has lm_head: {'✓' if has_lm_head else '✗'}")
    if has_lm_head:
        lm_head_params = sum(count for _, count in lm_head_entries)
        print(f" lm_head params: {lm_head_params/1e6:,.0f}M")
        # List every matching name (previously only the last one was shown).
        print(f" lm_head location: {', '.join(name for name, _ in lm_head_entries)}")

    lm = model.backbone.eagle_model.language_model
    head = getattr(lm, "lm_head", None)
    if head is not None:
        print(f" lm_head actual params: {head.weight.numel()/1e6:,.0f}M")
        print(f" lm_head weight shape: {head.weight.shape}")
        print(f" lm_head weight dtype: {head.weight.dtype}")
|
|
|
# Reference shapes for a handful of GR00T-N1.5-3B parameters.
_EXPECTED_SHAPES = {
    "backbone.eagle_model.language_model.lm_head.weight": (151680, 2048),
    "backbone.eagle_model.language_model.model.embed_tokens.weight": (151680, 2048),
    "backbone.eagle_model.language_model.model.norm.weight": (2048,),
    "backbone.eagle_model.mlp1.0.weight": (2048, 1152),
    "backbone.eagle_model.mlp1.0.bias": (2048,),
    "action_head.position_embedding.weight": (1024, 1536),
    "action_head.vlln.weight": (2048,),
    "action_head.vlln.bias": (2048,),
}
_LM_HEAD_KEY = "backbone.eagle_model.language_model.lm_head.weight"
_EXPECTED_TOTAL_M = 2724  # expected total parameter count, in millions


def _check_shapes(param_dict, expected_shapes):
    """Compare shapes against *expected_shapes*; print matches, return error strings."""
    errors = []
    for name, expected in expected_shapes.items():
        if name not in param_dict:
            errors.append(f"Missing parameter: {name}")
            continue
        actual = tuple(param_dict[name].shape)
        if actual != expected:
            errors.append(f"Shape mismatch for {name}: expected {expected}, got {actual}")
        else:
            print(f"✓ {name}: {actual}")
    return errors


def _check_dtypes(param_dict):
    """Return one message per parameter violating the bf16 backbone / fp32 action-head layout."""
    issues = []
    for name, param in param_dict.items():
        if name.startswith("backbone."):
            want, label = torch.bfloat16, "bfloat16"
        elif name.startswith("action_head."):
            want, label = torch.float32, "float32"
        else:
            continue
        if param.dtype != want:
            issues.append(f"{name}: expected {label}, got {param.dtype}")
    return issues


def _count_by_component(param_dict):
    """Sum parameter counts for the backbone, the action head and everything else."""
    counts = {"backbone": 0, "action_head": 0, "other": 0}
    for name, param in param_dict.items():
        if name.startswith("backbone."):
            counts["backbone"] += param.numel()
        elif name.startswith("action_head."):
            counts["action_head"] += param.numel()
        else:
            counts["other"] += param.numel()
    return counts


def validate_model_architecture(model):
    """Check *model* against the GR00T-N1.5-3B reference and print a report.

    The report covers:
    - reference parameter shapes (missing names or mismatches count as errors);
    - parameter dtypes (bf16 backbone / fp32 action head; violations count as
      warnings);
    - per-component parameter counts and the lm_head size;
    - the total parameter count compared with the expected total.

    Returns:
        bool: True when no shape errors were found.
    """
    print("\n" + "=" * 60)
    print("ARCHITECTURE VALIDATION")
    print("=" * 60)

    param_dict = dict(model.named_parameters())

    position_params = [name for name in param_dict if name.startswith("action_head.position")]
    if position_params:
        print("\nFound position embedding parameters:")
        for name in position_params[:5]:
            print(f" {name}: {param_dict[name].shape}")

    errors = _check_shapes(param_dict, _EXPECTED_SHAPES)
    # Keep every dtype issue so the count and the "... and N more" line are accurate.
    warnings = _check_dtypes(param_dict)
    component_params = _count_by_component(param_dict)
    lm_head_sizes = [p.numel() for name, p in param_dict.items() if "lm_head" in name]

    print("\nValidation Results:")
    print(f" Errors: {len(errors)}")
    print(f" Warnings: {len(warnings)}")

    if errors:
        print("\n❌ ERRORS:")
        for error in errors:
            print(f" - {error}")

    if warnings:
        print("\n⚠️ WARNINGS (showing first 5):")
        for warning in warnings[:5]:
            print(f" - {warning}")
        if len(warnings) > 5:
            print(f" ... and {len(warnings) - 5} more")

    print("\n📊 Parameter Summary:")
    total = sum(component_params.values())
    print(f" Total: {total/1e6:,.1f}M")
    print(f" Backbone: {component_params['backbone']/1e6:,.1f}M")
    print(f" Action Head: {component_params['action_head']/1e6:,.1f}M")
    if component_params["other"] > 0:
        print(f" Other: {component_params['other']/1e6:,.1f}M")

    print(f"\n lm_head found: {'✓' if lm_head_sizes else '✗'}")
    if lm_head_sizes:
        rows, cols = _EXPECTED_SHAPES[_LM_HEAD_KEY]
        # Derive the expectation from the reference shape (151680 x 2048 = 310.6M).
        print(f" lm_head params: {sum(lm_head_sizes)/1e6:.1f}M (expected: {rows * cols / 1e6:.1f}M)")

    expected_total = _EXPECTED_TOTAL_M
    actual_total = total / 1e6
    diff = actual_total - expected_total
    print(f"\n Expected total: {expected_total}M")
    print(f" Actual total: {actual_total:.1f}M")
    print(f" Difference: {diff:+.1f}M")

    ok = not errors
    # A matching total is not enough: shape errors also mean a mismatch.
    if ok and abs(diff) < 1:
        print("\n✅ Model architecture matches expected specification!")
    else:
        print("\n❌ Model architecture does NOT match specification!")

    return ok
|
|
|
|
|
def main(device: str, out_dir: str):
    """Build, check and save a randomly initialised GR00T-N1.5-3B checkpoint.

    Steps:
    1. Build the model with ``build_blank``.
    2. Install a fresh ``lm_head`` (diagnostics are printed before and after).
    3. Apply the bf16/fp32 layout and move the model to *device*.
    4. Run ``validate_model_architecture``.
    5. Save the weights in 2GB shards, together with the tokenizer files and a
       README, into *out_dir*.

    Args:
        device: torch device string the model is moved to (e.g. ``"cpu"``).
        out_dir: output directory; ``~`` is expanded and missing parents are created.
    """
    print("=" * 60)
    print("Creating blank GR00T-N1.5-3B model")
    print("=" * 60)

    model = build_blank()

    print("\nBefore adding lm_head:")
    diagnose_model(model)

    maybe_add_lm_head(model)

    print("\nAfter adding lm_head:")
    diagnose_model(model)

    set_mixed(model)
    model = model.to(device)

    valid = validate_model_architecture(model)

    out = Path(out_dir).expanduser()
    out.mkdir(parents=True, exist_ok=True)

    print(f"\nSaving model to {out}...")
    model.save_pretrained(out, max_shard_size="2GB")
    copy_tokenizer(out)
    (out / "README.md").write_text(
        "Random GR00T-N1.5-3B | backbone bf16 | action_head fp32 | Apache-2.0\n",
        encoding="utf-8",
    )

    n_params = sum(p.numel() for p in model.parameters())
    lm_head_params = model.backbone.eagle_model.language_model.lm_head.weight.numel()

    print("\n" + "=" * 60)
    print("FINAL SUMMARY")
    print("=" * 60)
    print(f"✅ Saved blank model ({n_params/1e6:,.0f}M params) → {out}")
    print(f"✅ Model has lm_head with {lm_head_params/1e6:.1f}M params")
    if not valid:
        # The validation result used to be discarded; surface it in the summary.
        print("⚠️ Architecture validation reported errors (see above)")
    print("✅ Ready for training with Apache-2.0 license")
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: --device (default "cpu") and --out_dir.
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cpu")
    parser.add_argument("--out_dir", default="DolphinGR00T-N1.5-3B-Zero")
    cli = parser.parse_args()
    main(cli.device, cli.out_dir)