ultralytics
Eval Results
YOLO11 / tools /pt_bundle.py
bayramsn
feat(tools): add pt_bundle to bundle multiple .pt files without modifying originals\nfix(apps): webcam_app device auto-resolution and session_state safety
8c738c4
raw
history blame
6.7 kB
#!/usr/bin/env python
"""
PT bundler: Bundle multiple .pt files into a single archive without modifying originals.
Supports two formats:
1) ZIP archive (recommended) – exact bytes of each .pt preserved.
2) PT container – a single .pt (pickle) file containing a dict {relative_path: bytes}.
CLI examples (PowerShell):
# Create ZIP bundle from current repo
python tools/pt_bundle.py zip --source . --out models_bundle.zip
# Create PT container bundle
python tools/pt_bundle.py pt --source . --out models_multi.pt
# List contents
python tools/pt_bundle.py list --bundle models_bundle.zip
python tools/pt_bundle.py list --bundle models_multi.pt
# Extract a single model from bundle to a path
python tools/pt_bundle.py extract --bundle models_multi.pt --member path/to/model.pt --out C:/tmp/model.pt
"""
from __future__ import annotations

import argparse
import io
import os
import shutil
import sys
import zipfile
from pathlib import Path
from typing import Iterable, List

try:
    import torch  # Only needed for PT container
except Exception:  # pragma: no cover - optional for ZIP-only usage
    torch = None  # type: ignore
def find_pt_files(source: Path, include: Iterable[str] | None = None, exclude: Iterable[str] | None = None) -> List[Path]:
    """Collect the ``.pt`` files below *source* that pass the include/exclude filters.

    Every regular file found by a recursive ``*.pt`` search is tested by its path
    relative to *source* (Windows separators normalised to ``/``) against the glob
    patterns with :meth:`pathlib.PurePath.match`, which matches relative patterns
    from the right-hand end of the path.

    Args:
        source: Root directory to scan recursively.
        include: Patterns of which at least one must match. ``None`` or an empty
            iterable falls back to ``["*.pt"]``.
        exclude: Patterns that drop a file when any of them matches.

    Returns:
        The selected files, sorted by relative path so that bundles built from the
        same tree always contain their members in the same order.
    """
    include_patterns = list(include or ["*.pt"])  # default: include every .pt
    exclude_patterns = list(exclude or [])
    selected = []
    for candidate in source.rglob("*.pt"):
        # rglob also yields directories (and dangling symlinks) whose names end in
        # ".pt"; neither can be read as a file, so skip them up front.
        if not candidate.is_file():
            continue
        rel_str = str(candidate.relative_to(source)).replace("\\", "/")
        rel_path = Path(rel_str)
        if include_patterns and not any(rel_path.match(pat) for pat in include_patterns):
            continue
        if exclude_patterns and any(rel_path.match(pat) for pat in exclude_patterns):
            continue
        selected.append((rel_str, candidate))
    # rglob order depends on the filesystem; sort for reproducible bundles.
    selected.sort(key=lambda item: item[0])
    return [path for _, path in selected]
def create_zip_bundle(source: Path, out_path: Path, includes: Iterable[str] | None = None, excludes: Iterable[str] | None = None) -> int:
    """Archive the selected ``.pt`` files under *source* into a ZIP file at *out_path*.

    Each file is stored under its path relative to *source*; DEFLATE compression is
    lossless, so the bytes of every model are preserved exactly. Parent directories
    of *out_path* are created as needed and an existing file there is overwritten.

    Args:
        source: Root directory to scan (see ``find_pt_files``).
        out_path: Destination archive path.
        includes: Include glob patterns forwarded to ``find_pt_files``.
        excludes: Exclude glob patterns forwarded to ``find_pt_files``.

    Returns:
        Number of files written to the archive.
    """
    out_resolved = out_path.resolve()
    # If the archive path itself matches the filters (e.g. an --out ending in .pt inside
    # --source), a previous run's output would otherwise be written into the new archive.
    files = [f for f in find_pt_files(source, includes, excludes) if f.resolve() != out_resolved]
    out_path.parent.mkdir(parents=True, exist_ok=True)
    # strict_timestamps=False clamps mtimes outside the ZIP date range (e.g. epoch-0
    # files from reproducible builds) instead of raising ValueError mid-archive.
    with zipfile.ZipFile(out_path, mode="w", compression=zipfile.ZIP_DEFLATED, strict_timestamps=False) as zf:
        for f in files:
            zf.write(f, f.relative_to(source))
    return len(files)
def create_pt_container(source: Path, out_path: Path, includes: Iterable[str] | None = None, excludes: Iterable[str] | None = None) -> int:
    """Save the selected ``.pt`` files under *source* as a single ``torch.save``'d dict.

    The container maps each file's path relative to *source* (with ``/`` separators)
    to that file's raw bytes, so the original checkpoints are never deserialized or
    altered. Parent directories of *out_path* are created as needed.

    Args:
        source: Root directory to scan (see ``find_pt_files``).
        out_path: Destination container path (e.g. ``models_multi.pt``).
        includes: Include glob patterns forwarded to ``find_pt_files``.
        excludes: Exclude glob patterns forwarded to ``find_pt_files``.

    Returns:
        Number of files stored in the container.

    Raises:
        RuntimeError: torch is not installed.
    """
    if torch is None:
        raise RuntimeError("torch is required for PT container mode. Install torch and retry.")
    out_resolved = out_path.resolve()
    payload = {}
    for f in find_pt_files(source, includes, excludes):
        # The container is itself a .pt file: when it lives under *source*, the output of
        # a previous run would otherwise be nested inside the new container on every rerun.
        if f.resolve() == out_resolved:
            continue
        rel = str(f.relative_to(source)).replace("\\", "/")
        payload[rel] = f.read_bytes()  # store exact bytes (no mutation)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(payload, out_path)
    return len(payload)
def list_bundle(bundle: Path) -> List[str]:
    """Return the member names stored in *bundle*.

    The format is chosen by file suffix. A ``.zip`` file (any letter case) is read
    with :mod:`zipfile`, and its file entries are returned in archive order with
    directory entries omitted. Anything else is loaded as a PT container, and its
    keys are returned as sorted strings.

    Raises:
        RuntimeError: a PT container is given but torch is not installed.
        ValueError: the PT container does not hold a dict.
    """
    if bundle.suffix.lower() != ".zip":
        if torch is None:
            raise RuntimeError("torch is required to list PT container contents.")
        # NOTE(security): torch.load unpickles the file; only list bundles from trusted sources.
        container = torch.load(bundle, map_location="cpu")
        if not isinstance(container, dict):
            raise ValueError("Unsupported PT container format: expected dict mapping.")
        return sorted(str(key) for key in container)

    with zipfile.ZipFile(bundle, "r") as archive:
        names = [entry.filename for entry in archive.infolist() if not entry.is_dir()]
    return names
def extract_member(bundle: Path, member: str, out_path: Path) -> None:
    """Write one member of *bundle* to *out_path*, byte-for-byte.

    Args:
        bundle: A ``.zip`` archive (detected by suffix, any letter case) or a PT
            container produced by ``create_pt_container``.
        member: Member name inside the bundle, using the ``/`` separators it was
            stored with.
        out_path: Destination file. Missing parent directories are created for both
            bundle formats, and an existing file is overwritten.

    Raises:
        KeyError: *member* is not in a ZIP bundle (raised by :mod:`zipfile`).
        RuntimeError: a PT container is given but torch is not installed.
        ValueError: the PT container does not hold a dict.
        FileNotFoundError: *member* is not in the PT container.
    """
    if bundle.suffix.lower() == ".zip":
        with zipfile.ZipFile(bundle, "r") as zf:
            # Open the member first so that a missing name (KeyError) leaves neither an
            # empty destination file nor freshly created directories behind.
            with zf.open(member, "r") as src:
                out_path.parent.mkdir(parents=True, exist_ok=True)
                with open(out_path, "wb") as dst:
                    # Stream in chunks rather than holding a whole checkpoint in memory.
                    shutil.copyfileobj(src, dst)
    else:
        if torch is None:
            raise RuntimeError("torch is required to extract from PT container.")
        # NOTE(security): torch.load unpickles the bundle; only open containers from trusted sources.
        data = torch.load(bundle, map_location="cpu")
        if not isinstance(data, dict):
            raise ValueError("Unsupported PT container format: expected dict mapping.")
        if member not in data:
            raise FileNotFoundError(f"member not found in container: {member}")
        out_path.parent.mkdir(parents=True, exist_ok=True)
        with open(out_path, "wb") as fh:
            fh.write(data[member])
def _add_scan_args(cmd_parser: argparse.ArgumentParser, out_help: str) -> None:
    """Register the scan/output options shared by the ``zip`` and ``pt`` sub-commands."""
    cmd_parser.add_argument("--source", default=".", help="Root directory to scan for .pt files.")
    cmd_parser.add_argument("--out", required=True, help=out_help)
    cmd_parser.add_argument("--include", nargs="*", default=["*.pt"], help="Glob patterns to include.")
    cmd_parser.add_argument("--exclude", nargs="*", default=[], help="Glob patterns to exclude.")


def main(argv: List[str] | None = None) -> int:
    """Command-line entry point.

    Sub-commands: ``zip`` and ``pt`` build a bundle, ``list`` prints a bundle's
    members one per line, and ``extract`` copies one member out to a file.

    Args:
        argv: Arguments to parse; ``None`` lets argparse read ``sys.argv[1:]``.

    Returns:
        0 once a sub-command completes; 1 if no known sub-command was dispatched.
        Invalid arguments make argparse exit with status 2.
    """
    parser = argparse.ArgumentParser(description="Bundle multiple .pt files without modifying originals.")
    commands = parser.add_subparsers(dest="cmd", required=True)

    _add_scan_args(commands.add_parser("zip", help="Create a ZIP archive of .pt files."), "Output ZIP path.")
    _add_scan_args(
        commands.add_parser("pt", help="Create a single .pt container (dict of bytes)."),
        "Output PT path (e.g., models_multi.pt).",
    )

    list_cmd = commands.add_parser("list", help="List contents of a bundle (ZIP or PT container).")
    list_cmd.add_argument("--bundle", required=True, help="Path to models_bundle.zip or models_multi.pt.")

    extract_cmd = commands.add_parser("extract", help="Extract a single member from the bundle.")
    extract_cmd.add_argument("--bundle", required=True, help="Bundle path (ZIP or PT container).")
    extract_cmd.add_argument("--member", required=True, help="Member path inside the bundle.")
    extract_cmd.add_argument("--out", required=True, help="Destination file path to write.")

    opts = parser.parse_args(argv)
    cmd = opts.cmd
    if cmd == "zip":
        written = create_zip_bundle(Path(opts.source), Path(opts.out), opts.include, opts.exclude)
        print(f"ZIP bundle written: {opts.out} ({written} files)")
    elif cmd == "pt":
        written = create_pt_container(Path(opts.source), Path(opts.out), opts.include, opts.exclude)
        print(f"PT container written: {opts.out} ({written} files)")
    elif cmd == "list":
        for name in list_bundle(Path(opts.bundle)):
            print(name)
    elif cmd == "extract":
        extract_member(Path(opts.bundle), opts.member, Path(opts.out))
        print(f"Extracted {opts.member} -> {opts.out}")
    else:
        # Unreachable while the sub-command is required; kept as a defensive fallback.
        parser.print_help()
        return 1
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    sys.exit(main())