import json
import os
import sys
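
# On-disk layout of a .safetensors file, as assumed by the code below:
#   bytes [0, 8)      unsigned 64-bit little-endian length N of the JSON header
#   bytes [8, 8+N)    JSON header: tensor name -> {"dtype", "shape", "data_offsets"}
#   bytes [8+N, end)  raw tensor data; each data_offsets pair is relative to here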

class SafeTensorsException(Exception):
    def __init__(self, msg: str):
        self.msg = msg
        super().__init__(msg)

    @staticmethod
    def invalid_file(filename: str, whatiswrong: str):
        s = f"{filename} is not a valid .safetensors file: {whatiswrong}"
        return SafeTensorsException(msg=s)

    def __str__(self):
        return self.msg

class SafeTensorsChunk:
    """One header entry: a named tensor and its byte range in the data section."""
    def __init__(self, name: str, dtype: str, shape: list[int], offset0: int, offset1: int):
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.offset0 = offset0
        self.offset1 = offset1

class SafeTensorsFile:
    def __init__(self):
        self.f = None        # file handle
        self.hdrbuf = None   # header byte buffer
        self.header = None   # parsed header as a dict
        self.error = 0
        self.filename = ""   # set by open(); initialized here so close_file() is always safe
        self.headerlen = 0
        self.st = None       # os.stat() result for the open file

    def __del__(self):
        self.close_file()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close_file()

    def close_file(self):
        if self.f is not None:
            self.f.close()
            self.f = None
        self.filename = ""
    # test file: duplicate_keys_in_header.safetensors
    def _CheckDuplicateHeaderKeys(self):
        def parse_object_pairs(pairs):
            return [k for k, _ in pairs]

        # json.loads normally collapses duplicate keys; this hook preserves
        # every key that appears so duplicates can be counted
        keys = json.loads(self.hdrbuf, object_pairs_hook=parse_object_pairs)
        d = {}
        for k in keys:
            d[k] = d.get(k, 0) + 1
        hasError = False
        for k, v in d.items():
            if v > 1:
                print(f"key {k} used {v} times in header", file=sys.stderr)
                hasError = True
        if hasError:
            raise SafeTensorsException.invalid_file(self.filename, "duplicate keys in header")
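    # Example of why the hook above is needed: json.loads('{"a": 1, "a": 2}')
    # returns {"a": 2} with no warning, silently dropping the first value.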
    @staticmethod
    def open_file(filename: str, quiet=False, parseHeader=True):
        s = SafeTensorsFile()
        s.open(filename, quiet, parseHeader)
        return s

    def open(self, fn: str, quiet=False, parseHeader=True) -> int:
        st = os.stat(fn)
        if st.st_size < 8:  # test file: zero_len_file.safetensors
            raise SafeTensorsException.invalid_file(fn, "length less than 8 bytes")
        f = open(fn, "rb")
        try:
            b8 = f.read(8)  # 8-byte little-endian header size
            if len(b8) != 8:
                raise SafeTensorsException.invalid_file(fn, f"read only {len(b8)} bytes at start of file")
            headerlen = int.from_bytes(b8, 'little', signed=False)
            if 8 + headerlen > st.st_size:  # test file: header_size_too_big.safetensors
                raise SafeTensorsException.invalid_file(fn, "header extends past end of file")
            if not quiet:
                print(f"{fn}: length={st.st_size}, header length={headerlen}")
            hdrbuf = f.read(headerlen)
            if len(hdrbuf) != headerlen:
                raise SafeTensorsException.invalid_file(fn, f"header size is {headerlen}, but only read {len(hdrbuf)} bytes")
        except Exception:
            f.close()  # don't leak the handle when validation fails
            raise
        self.filename = fn
        self.f = f
        self.st = st
        self.hdrbuf = hdrbuf
        self.error = 0
        self.headerlen = headerlen
        if parseHeader:
            self._CheckDuplicateHeaderKeys()
            self.header = json.loads(self.hdrbuf)
        return 0

    def get_header(self):
        return self.header
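    # Illustrative header entry (tensor name and numbers are hypothetical):
    #   "embedding.weight": {"dtype": "F16", "shape": [4096, 320], "data_offsets": [0, 2621440]}
    # (4096 * 320 elements * 2 bytes per F16 value = 2,621,440 bytes)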
    def load_one_tensor(self, tensor_name: str):
        if self.header is None or tensor_name not in self.header:
            return None
        t = self.header[tensor_name]
        # data_offsets are relative to the start of the data section, which
        # begins right after the 8-byte size field and the header itself
        self.f.seek(8 + self.headerlen + t['data_offsets'][0])
        bytesToRead = t['data_offsets'][1] - t['data_offsets'][0]
        buf = self.f.read(bytesToRead)
        if len(buf) != bytesToRead:
            print(f"{tensor_name}: length={bytesToRead}, only read {len(buf)} bytes", file=sys.stderr)
        return buf

    def copy_data_to_file(self, file_handle) -> int:
        self.f.seek(8 + self.headerlen)  # start of the data section
        bytesLeft: int = self.st.st_size - 8 - self.headerlen
        while bytesLeft > 0:
            chunklen: int = min(bytesLeft, 16 * 1024 * 1024)  # copy in blocks of 16 MB
            file_handle.write(self.f.read(chunklen))
            bytesLeft -= chunklen
        return 0
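
if __name__ == "__main__":
    # Minimal usage sketch: print the size of every tensor in a .safetensors
    # file whose path is given as the first command-line argument.
    with SafeTensorsFile.open_file(sys.argv[1], quiet=True) as sf:
        for name in sf.get_header():
            if name == "__metadata__":  # optional free-form metadata, not a tensor
                continue
            data = sf.load_one_tensor(name)
            print(name, 0 if data is None else len(data), "bytes")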