Spaces:
Runtime error
Runtime error
| """ | |
| Pure python version of Safetensors safe_open | |
| From https://gist.github.com/Narsil/3edeec2669a5e94e4707aa0f901d2282 | |
| """ | |
| import json | |
| import mmap | |
| import os | |
| import torch | |
class SafetensorsWrapper:
    """In-memory handle returned by :func:`safe_open`.

    Mirrors the read API of the object returned by ``safetensors.safe_open``
    (``keys()``, ``get_tensor()``, ``metadata()``), including use as a
    context manager: ``with safe_open(path) as f: ...``.
    """

    def __init__(self, metadata, tensors):
        # metadata: free-form header metadata (safe_open passes the
        #   "__metadata__" entry, or {} when the header has none).
        # tensors: mapping of tensor name -> tensor.
        self._metadata = metadata
        self._tensors = tensors

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # The wrapper only holds references to already-built objects, so
        # there is nothing to release. Returning False lets any exception
        # raised inside the ``with`` body propagate unchanged.
        return False

    def metadata(self):
        """Return the metadata mapping given at construction."""
        return self._metadata

    def keys(self):
        """Return a view of the stored tensor names."""
        return self._tensors.keys()

    def get_tensor(self, k):
        """Return the tensor stored under name *k* (``KeyError`` if absent)."""
        return self._tensors[k]
# Map safetensors header "dtype" tags to torch dtypes.
# U16/U32/U64 and the F8 variants are deliberately absent: older torch
# releases have no matching dtype, so loading such tensors raises KeyError.
DTYPES = {
    "F64": torch.float64,
    "F32": torch.float32,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I64": torch.int64,
    "I32": torch.int32,
    "I16": torch.int16,
    "I8": torch.int8,
    "U8": torch.uint8,
    "BOOL": torch.bool,
}
def create_tensor(storage, info, offset):
    """Build one tensor from its header entry as a view over *storage*.

    Args:
        storage: Untyped byte storage holding the file contents (as built
            by ``safe_open``).
        info: Header entry with ``"dtype"``, ``"shape"`` and
            ``"data_offsets"``; offsets are a ``[start, stop)`` byte range
            relative to the start of the data buffer.
        offset: Position in *storage* where the data buffer begins.

    Returns:
        A tensor of dtype ``DTYPES[info["dtype"]]`` reshaped to
        ``info["shape"]``.

    Raises:
        KeyError: If the dtype tag is not present in ``DTYPES``.
        ValueError: If the byte range does not match shape x itemsize, or
            falls outside *storage*.
    """
    dtype = DTYPES[info["dtype"]]
    shape = info["shape"]
    start, stop = info["data_offsets"]
    # Validate up front: an inconsistent or truncated file would otherwise
    # fail later with an opaque error from view()/reshape().
    itemsize = torch.empty(0, dtype=dtype).element_size()
    expected = torch.Size(shape).numel() * itemsize
    if stop - start != expected:
        raise ValueError(
            f"data_offsets [{start}, {stop}) span {stop - start} bytes, but "
            f"dtype {info['dtype']} with shape {shape} needs {expected} bytes"
        )
    if start < 0 or stop + offset > len(storage):
        raise ValueError(
            f"data_offsets [{start}, {stop}) fall outside the "
            f"{len(storage) - offset}-byte data buffer"
        )
    return (
        torch.asarray(storage[start + offset : stop + offset], dtype=torch.uint8)
        .view(dtype=dtype)
        .reshape(shape)
    )
def safe_open(filename, framework="pt", device="cpu"):
    """Load every tensor of a ``.safetensors`` file.

    Pure-Python stand-in for ``safetensors.safe_open`` (PyTorch only). All
    tensors are built eagerly and moved to *device*.

    Args:
        filename: Path to the ``.safetensors`` file.
        framework: Must be ``"pt"``.
        device: Target passed to ``Tensor.to`` for every tensor.

    Returns:
        A ``SafetensorsWrapper`` with the header's ``"__metadata__"`` entry
        (``{}`` if absent) and a name -> tensor mapping.

    Raises:
        ValueError: If *framework* is not ``"pt"``, or the file is too short
            for the 8-byte length prefix or for the header it declares.
    """
    if framework != "pt":
        raise ValueError("`framework` must be 'pt'")

    # File layout: 8-byte little-endian header length n, n bytes of JSON
    # header, then the tensor data buffer. The header is binary data, so
    # the file is read in binary mode.
    with open(filename, mode="rb") as file_obj:
        size = os.fstat(file_obj.fileno()).st_size
        header = file_obj.read(8)
        if len(header) != 8:
            raise ValueError(f"{filename!r} is too small to be a safetensors file")
        n = int.from_bytes(header, "little")
        if n + 8 > size:
            raise ValueError(
                f"{filename!r}: declared header length {n} exceeds file size {size}"
            )
        metadata = json.loads(file_obj.read(n))

    storage = torch.ByteStorage.from_file(filename, shared=False, size=size).untyped()
    offset = n + 8  # data buffer starts right after prefix + header
    return SafetensorsWrapper(
        metadata=metadata.get("__metadata__", {}),
        tensors={
            name: create_tensor(storage, info, offset).to(device)
            for name, info in metadata.items()
            if name != "__metadata__"
        },
    )