import os, sys, json

from safetensors_file import SafeTensorsFile

def _need_force_overwrite(output_file:str,cmdLine:dict) -> bool:
    if cmdLine["force_overwrite"]==False:
        if os.path.exists(output_file):
            print(f'output file "{output_file}" already exists, use -f flag to force overwrite',file=sys.stderr)
            return True
    return False

def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_file:str) -> int:
    if _need_force_overwrite(output_file,cmdLine): return -1

    with open(in_json_file,"rt") as f:
        inmeta=json.load(f)
    if "__metadata__" not in inmeta:
        print(f"file {in_json_file} does not contain a top-level __metadata__ item",file=sys.stderr)
        #json.dump(inmeta,fp=sys.stdout,indent=2)
        return -2
    inmeta=inmeta["__metadata__"] #keep only metadata
    #json.dump(inmeta,fp=sys.stdout,indent=2)

    s=SafeTensorsFile.open_file(in_st_file)
    js=s.get_header()

    if inmeta==[]:
        js.pop("__metadata__",0)
        print("loaded __metadata__ is an empty list, output file will not contain __metadata__ in header")
    else:
        print("adding __metadata__ to header:")
        json.dump(inmeta,fp=sys.stdout,indent=2)
        # __metadata__ values are expected to be strings, so stringify everything
        if isinstance(inmeta,dict):
            for k in inmeta: inmeta[k]=str(inmeta[k])
        else:
            inmeta=str(inmeta)
        #js["__metadata__"]=json.dumps(inmeta,ensure_ascii=False)
        js["__metadata__"]=inmeta
        print()

    # file layout: 8-byte little-endian header length, then the JSON header
    # (padded here with spaces to a multiple of 8), then the raw tensor data
    newhdrbuf=json.dumps(js,separators=(',',':'),ensure_ascii=False).encode('utf-8')
    newhdrlen:int=int(len(newhdrbuf))
    pad:int=((newhdrlen+7)&(~7))-newhdrlen #pad to multiple of 8

    with open(output_file,"wb") as f:
        f.write(int(newhdrlen+pad).to_bytes(8,'little'))
        f.write(newhdrbuf)
        if pad>0: f.write(bytearray([32]*pad))
        i:int=s.copy_data_to_file(f)
    if i==0:
        print(f"file {output_file} saved successfully")
    else:
        print(f"error {i} occurred when writing to file {output_file}")
    return i

def PrintHeader(cmdLine:dict,input_file:str) -> int:
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    js=s.get_header()

    # All the .safetensors files I've seen have long key names, and as a result,
    # neither json nor pprint package prints text in very readable format,
    # so we print it ourselves, putting key name & value on one long line.
    # Note the print out is in Python format, not valid JSON format.
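    # Each header entry (assumption based on the safetensors format and on how
    # js[key]['shape'] is used elsewhere in this module) typically looks like
    #   "tensor.name": {"dtype":"F16","shape":[320,4,3,3],"data_offsets":[0,23040]}
    # plus an optional "__metadata__" dict of string key/value pairs.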
    firstKey=True
    print("{")
    for key in js:
        if firstKey: firstKey=False
        else: print(",")
        json.dump(key,fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
        print(": ",end='')
        json.dump(js[key],fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
    print("\n}")
    return 0

def _ParseMore(d:dict):
    '''Basically try to turn this:

    "ss_dataset_dirs":"{\"abc\": {\"n_repeats\": 2, \"img_count\": 60}}",

    into this:

    "ss_dataset_dirs":{
     "abc":{
      "n_repeats":2,
      "img_count":60
     }
    },
    '''
    for key in d:
        value=d[key]
        #print("+++",key,value,type(value),"+++",sep='|')
        if isinstance(value,str):
            try:
                v2=json.loads(value)
                d[key]=v2
                value=v2
            except json.JSONDecodeError:
                pass
        if isinstance(value,dict):
            _ParseMore(value)

def PrintMetadata(cmdLine:dict,input_file:str) -> int:
    with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s:
        js=s.get_header()
        if "__metadata__" not in js:
            print("file header does not contain a __metadata__ item",file=sys.stderr)
            return -2

        md=js["__metadata__"]
        if cmdLine['parse_more']:
            _ParseMore(md)
        json.dump({"__metadata__":md},fp=sys.stdout,ensure_ascii=False,separators=(',',':'),indent=1)
    return 0

def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int:
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    js=s.get_header()

    _lora_keys:list[tuple[str,bool]]=[] # use list to sort by name
    for key in js:
        if key=='__metadata__': continue
        v=js[key]
        # a tensor with an empty shape list is a scalar
        isScalar=False
        if isinstance(v,dict):
            if 'shape' in v:
                if 0==len(v['shape']):
                    isScalar=True
        _lora_keys.append((key,isScalar))
    _lora_keys.sort(key=lambda x:x[0])

    def printkeylist(kl):
        firstKey=True
        for key in kl:
            if firstKey: firstKey=False
            else: print(",")
            print(key,end='')
        print()

    print("# use list to keep insertion order")
    print("_lora_keys:list[tuple[str,bool]]=[")
    printkeylist(_lora_keys)
    print("]")
    return 0

def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int:
    if _need_force_overwrite(output_file,cmdLine): return -1

    s=SafeTensorsFile.open_file(input_file,parseHeader=False)
    if s.error!=0: return s.error

    hdrbuf=s.hdrbuf
    s.close_file() #close it in case user wants to write back to input_file itself
    with open(output_file,"wb") as fo:
        wn=fo.write(hdrbuf)
        if wn!=len(hdrbuf):
            print(f"write output file failed, tried to write {len(hdrbuf)} bytes, only wrote {wn} bytes",file=sys.stderr)
            return -1
    print(f"raw header saved to file {output_file}")
    return 0

def _CheckLoRA_internal(s:SafeTensorsFile)->int:
    import lora_keys_sd15 as lora_keys
    js=s.get_header()
    set_scalar=set()
    set_nonscalar=set()
    for x in lora_keys._lora_keys:
        if x[1]==True: set_scalar.add(x[0])
        else: set_nonscalar.add(x[0])

    bad_unknowns:list[str]=[] # unrecognized keys
    bad_scalars:list[str]=[] #bad scalar
    bad_nonscalars:list[str]=[] #bad nonscalar
    for key in js:
        if key in set_nonscalar:
            if js[key]['shape']==[]: bad_nonscalars.append(key)
            set_nonscalar.remove(key)
        elif key in set_scalar:
            if js[key]['shape']!=[]: bad_scalars.append(key)
            set_scalar.remove(key)
        else:
            if "__metadata__"!=key: bad_unknowns.append(key)

    hasError=False

    if len(bad_unknowns)!=0:
        print("INFO: unrecognized items:")
        for x in bad_unknowns:
            print(" ",x)
        #hasError=True

    if len(set_scalar)>0:
        print("missing scalar keys:")
        for x in set_scalar:
            print(" ",x)
        hasError=True

    if len(set_nonscalar)>0:
        print("missing nonscalar keys:")
        for x in set_nonscalar:
            print(" ",x)
        hasError=True

    if len(bad_scalars)!=0:
        print("keys expected to be scalar but are nonscalar:")
        for x in bad_scalars:
            print(" ",x)
        hasError=True

    if len(bad_nonscalars)!=0:
        print("keys expected to be nonscalar but are scalar:")
        for x in bad_nonscalars:
print(" ",x) hasError=True return (1 if hasError else 0) def CheckLoRA(cmdLine:dict,input_file:str)->int: s=SafeTensorsFile.open_file(input_file) i:int=_CheckLoRA_internal(s) if i==0: print("looks like an OK SD 1.x LoRA file") return 0 def ExtractData(cmdLine:dict,input_file:str,key_name:str,output_file:str)->int: if _need_force_overwrite(output_file,cmdLine): return -1 s=SafeTensorsFile.open_file(input_file,cmdLine['quiet']) if s.error!=0: return s.error bindata=s.load_one_tensor(key_name) s.close_file() #close it just in case user wants to write back to input_file itself if bindata is None: print(f'key "{key_name}" not found in header (key names are case-sensitive)',file=sys.stderr) return -1 with open(output_file,"wb") as fo: wn=fo.write(bindata) if wn!=len(bindata): print(f"write output file failed, tried to write {len(bindata)} bytes, only wrote {wn} bytes",file=sys.stderr) return -1 if cmdLine['quiet']==False: print(f"{key_name} saved to {output_file}, len={wn}") return 0