Juan Sebastian Giraldo
committed on
Commit
·
75cf81d
1
Parent(s):
57ff33f
Upload Lora app
Browse files- .gitignore +3 -0
- app.py +211 -0
- requirements.txt +0 -0
- safetensors_file.py +125 -0
- safetensors_util.py +98 -0
- safetensors_worker.py +243 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
**/__pycache__/
|
2 |
+
/.venv/
|
3 |
+
/scripts/
|
app.py
ADDED
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
import sys
|
4 |
+
import io
|
5 |
+
import subprocess
|
6 |
+
import tempfile
|
7 |
+
from pathlib import Path
|
8 |
+
from safetensors_worker import PrintMetadata
|
9 |
+
|
10 |
+
class Context:
    """Minimal stand-in for a click context: carries only the options dict."""

    def __init__(self):
        # Defaults mirror the CLI flags consumed by safetensors_worker:
        # quiet suppresses informational prints, parse_more expands nested JSON.
        self.obj = dict(quiet=True, parse_more=True)

# Shared module-level context used by every worker call from the UI.
ctx = Context()
15 |
+
|
16 |
+
def debug_log(message: str) -> None:
    """Emit *message* on stdout, prefixed so debug lines are easy to grep."""
    print("[DEBUG] " + message)
18 |
+
|
19 |
+
def load_metadata(file_path: str) -> tuple:
    """Load and parse the metadata of an uploaded .safetensors file.

    `file_path` is a gradio File object; its `.name` attribute is the path
    of the uploaded temp file on disk.

    Returns a 5-tuple for the UI outputs:
    (full metadata dict, key-metrics dict, pretty JSON string for the
    editor, error/status message, source file path).
    """
    try:
        debug_log(f"Loading file: {file_path}")

        if not file_path:
            return {"status": "Awaiting input"}, {}, "", "", ""

        # PrintMetadata writes its result to stdout, so capture it.
        # Bug fix: restore sys.stdout in a finally block — previously a
        # raising worker left stdout permanently redirected to the buffer.
        old_stdout = sys.stdout
        sys.stdout = buffer = io.StringIO()
        try:
            exit_code = PrintMetadata(ctx.obj, file_path.name)
        finally:
            sys.stdout = old_stdout

        metadata_str = buffer.getvalue().strip()

        if exit_code != 0:
            error_msg = f"Error code {exit_code}"
            return {"error": error_msg}, {}, "", error_msg, ""

        try:
            full_metadata = json.loads(metadata_str)
        except json.JSONDecodeError:
            error_msg = "Invalid metadata structure"
            return {"error": error_msg}, {}, "", error_msg, ""

        # Surface the handful of training parameters users care about most.
        training_params = full_metadata.get("__metadata__", {})
        key_metrics = {
            key: training_params.get(key, "N/A")
            for key in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }

        return full_metadata, key_metrics, json.dumps(full_metadata, indent=2), "", file_path.name

    except Exception as e:
        return {"error": str(e)}, {}, "", str(e), ""
56 |
+
|
57 |
+
def validate_json(edited_json: str) -> tuple:
    """Parse *edited_json* and report the outcome.

    Returns (is_valid, parsed_object_or_None, error_text).
    """
    try:
        parsed = json.loads(edited_json)
    except Exception as exc:
        return False, None, str(exc)
    return True, parsed, ""
62 |
+
|
63 |
+
def update_metadata(edited_json: str) -> tuple:
    """Recompute the key-metrics preview from the edited JSON text.

    Returns (key_fields, full modified dict, status string). While the
    user is mid-edit the text is often invalid JSON; in that case the UI
    components are left untouched via gr.update() placeholders.
    """
    try:
        modified_data = json.loads(edited_json)
        metadata = modified_data.get("__metadata__", {})

        key_fields = {
            param: metadata.get(param, "N/A")
            for param in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }
        return key_fields, modified_data, ""
    except Exception:
        # Bug fix: narrowed from a bare `except:` so KeyboardInterrupt and
        # SystemExit are no longer swallowed while the user types.
        return gr.update(), gr.update(), ""
78 |
+
|
79 |
+
def save_metadata(edited_json: str, source_file: str, output_name: str) -> tuple:
    """Write the edited metadata into a copy of the source .safetensors file.

    Delegates the header rewrite to `safetensors_util.py writemd` run as a
    subprocess, passing the metadata through a temporary JSON file.

    Returns (output_path_or_None, error_message).
    """
    debug_log("Initiating save process")
    try:
        if not source_file:
            return None, "No source file provided"

        is_valid, parsed_data, error = validate_json(edited_json)
        if not is_valid:
            return None, f"Validation error: {error}"

        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
            json.dump(parsed_data, tmp, indent=2)
            temp_path = tmp.name

        try:
            source_path = Path(source_file)

            if output_name.strip():
                base_name = output_name.strip()
                if not base_name.endswith(".safetensors"):
                    base_name += ".safetensors"
            else:
                base_name = f"{source_path.stem}_modified.safetensors"

            # Version the chosen name on collision. Bug fix: the original
            # loop always rebuilt the name from the auto-generated stem,
            # silently discarding a user-supplied output name.
            output_path = Path(base_name)
            version = 1
            while output_path.exists():
                output_path = Path(f"{Path(base_name).stem}_{version}.safetensors")
                version += 1

            cmd = [
                sys.executable,
                "safetensors_util.py",
                "writemd",
                source_file,
                temp_path,
                str(output_path),
                "-f"
            ]

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=False
            )
        finally:
            # Bug fix: remove the temp JSON even when subprocess.run raises.
            Path(temp_path).unlink(missing_ok=True)

        if result.returncode != 0:
            return None, f"Save failure: {result.stderr}"

        return str(output_path), ""

    except Exception as e:
        return None, f"Critical error: {str(e)}"
135 |
+
|
136 |
+
def create_interface():
    """Build the gradio Blocks UI: a viewer tab plus an edit/save tab."""
    with gr.Blocks(title="LoRA Metadata Editor") as app:
        gr.Markdown("# LoRA Metadata Editor")

        with gr.Tabs():
            # Bug fix: tab label typo "Metdata Viewer" -> "Metadata Viewer".
            with gr.Tab("Metadata Viewer"):
                gr.Markdown("### LoRa Upload")
                file_input = gr.File(
                    file_types=[".safetensors"],
                    show_label=False
                )

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Full Metadata")
                        full_viewer = gr.JSON(show_label=False)

                    with gr.Column():
                        gr.Markdown("### Key Metrics")
                        key_viewer = gr.JSON(show_label=False)

            with gr.Tab("Edit Metadata"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### JSON Workspace")
                        metadata_editor = gr.Textbox(
                            lines=25,
                            show_label=False,
                            placeholder="Edit metadata JSON here"
                        )
                        gr.Markdown("### Output Name")
                        filename_input = gr.Textbox(
                            placeholder="Leave empty for auto-naming",
                            show_label=False
                        )

                    with gr.Column():
                        gr.Markdown("### Live Preview")
                        modified_viewer = gr.JSON(show_label=False)
                        save_btn = gr.Button("💾 Save Metadata", variant="primary")
                        gr.Markdown("### Download Modified LoRa")
                        output_file = gr.File(
                            visible=False,
                            show_label=False
                        )

        # Hidden status area plus a State holding the uploaded file's path,
        # so save_metadata knows which file to copy from.
        status_display = gr.HTML(visible=False)
        source_tracker = gr.State()

        # Upload -> parse metadata and populate both viewers and the editor.
        file_input.upload(
            load_metadata,
            inputs=file_input,
            outputs=[full_viewer, key_viewer, metadata_editor, status_display, source_tracker]
        )

        # Re-derive the key metrics live as the JSON text is edited.
        metadata_editor.change(
            update_metadata,
            inputs=metadata_editor,
            outputs=[key_viewer, modified_viewer, status_display]
        )

        # Save, then reveal the download component with the produced file.
        save_btn.click(
            save_metadata,
            inputs=[metadata_editor, source_tracker, filename_input],
            outputs=[output_file, status_display],
        ).then(
            lambda x: gr.File(value=x, visible=True),
            inputs=output_file,
            outputs=output_file
        )

    return app
208 |
+
|
209 |
+
# Script entry point: build the gradio app and start the local server.
if __name__ == "__main__":
    interface = create_interface()
    interface.launch()
|
requirements.txt
ADDED
Binary file (4.84 kB). View file
|
|
safetensors_file.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys, json
|
2 |
+
|
3 |
+
class SafeTensorsException(Exception):
    """Raised when a .safetensors file fails structural validation."""

    def __init__(self, msg: str):
        self.msg = msg
        super().__init__(msg)

    @staticmethod
    def invalid_file(filename: str, whatiswrong: str):
        """Build an exception describing why *filename* is invalid.

        Bug fix: the message previously hard-coded "(unknown)" and ignored
        the filename argument entirely.
        """
        s = f"{filename} is not a valid .safetensors file: {whatiswrong}"
        return SafeTensorsException(msg=s)

    def __str__(self):
        return self.msg
15 |
+
|
16 |
+
class SafeTensorsChunk:
    """One tensor entry from a safetensors header: name, dtype, shape, and
    its [offset0, offset1) byte range inside the data section."""

    def __init__(self, name: str, dtype: str, shape: list[int], offset0: int, offset1: int):
        (self.name, self.dtype, self.shape) = (name, dtype, shape)
        (self.offset0, self.offset1) = (offset0, offset1)
23 |
+
|
24 |
+
class SafeTensorsFile:
    """Low-level reader for .safetensors files.

    File layout: an 8-byte little-endian header length, then that many
    bytes of JSON header, then the raw tensor data section.
    """

    def __init__(self):
        self.f=None #file handle
        self.hdrbuf=None #header byte buffer
        self.header=None #parsed header as a dict
        self.error=0

    def __del__(self):
        # Best-effort cleanup if the caller forgot to close; using the
        # context-manager protocol below is preferred.
        self.close_file()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close_file()

    def close_file(self):
        # Close the handle and reset state; safe to call more than once.
        if self.f is not None:
            self.f.close()
            self.f=None
            self.filename=""

    #test file: duplicate_keys_in_header.safetensors
    def _CheckDuplicateHeaderKeys(self):
        # Re-parse the raw header with an object_pairs_hook that keeps every
        # key occurrence — plain json.loads would silently keep only the
        # last duplicate. Raises SafeTensorsException on any repeat.
        def parse_object_pairs(pairs):
            # Return just the key names, duplicates included.
            return [k for k,_ in pairs]

        keys=json.loads(self.hdrbuf,object_pairs_hook=parse_object_pairs)
        #print(keys)
        d={}
        for k in keys:
            if k in d: d[k]=d[k]+1
            else: d[k]=1
        hasError=False
        for k,v in d.items():
            if v>1:
                print(f"key {k} used {v} times in header",file=sys.stderr)
                hasError=True
        if hasError:
            raise SafeTensorsException.invalid_file(self.filename,"duplicate keys in header")

    @staticmethod
    def open_file(filename:str,quiet=False,parseHeader=True):
        # Convenience constructor: create an instance and open in one call.
        s=SafeTensorsFile()
        s.open(filename,quiet,parseHeader)
        return s

    def open(self,fn:str,quiet=False,parseHeader=True)->int:
        """Open *fn*, read its header bytes and (optionally) parse them.

        Raises SafeTensorsException on structural problems. Returns 0.
        """
        st=os.stat(fn)
        if st.st_size<8: #test file: zero_len_file.safetensors
            raise SafeTensorsException.invalid_file(fn,"length less than 8 bytes")

        f=open(fn,"rb")
        b8=f.read(8) #read header size
        if len(b8)!=8:
            raise SafeTensorsException.invalid_file(fn,f"read only {len(b8)} bytes at start of file")
        headerlen=int.from_bytes(b8,'little',signed=False)

        if (8+headerlen>st.st_size): #test file: header_size_too_big.safetensors
            raise SafeTensorsException.invalid_file(fn,"header extends past end of file")

        if quiet==False:
            print(f"{fn}: length={st.st_size}, header length={headerlen}")
        hdrbuf=f.read(headerlen)
        if len(hdrbuf)!=headerlen:
            raise SafeTensorsException.invalid_file(fn,f"header size is {headerlen}, but read {len(hdrbuf)} bytes")
        # Keep the open handle and sizes around for later data reads.
        self.filename=fn
        self.f=f
        self.st=st
        self.hdrbuf=hdrbuf
        self.error=0
        self.headerlen=headerlen
        if parseHeader==True:
            self._CheckDuplicateHeaderKeys()
            self.header=json.loads(self.hdrbuf)
        return 0

    def get_header(self):
        # Parsed header dict; None when opened with parseHeader=False.
        return self.header

    def load_one_tensor(self,tensor_name:str):
        """Return the raw bytes of one tensor, or None if the name is not
        present in the header."""
        self.get_header()
        if tensor_name not in self.header: return None

        t=self.header[tensor_name]
        # Data section starts right after the 8-byte prefix and the header;
        # data_offsets are relative to the start of the data section.
        self.f.seek(8+self.headerlen+t['data_offsets'][0])
        bytesToRead=t['data_offsets'][1]-t['data_offsets'][0]
        bytes=self.f.read(bytesToRead)
        if len(bytes)!=bytesToRead:
            # Short read: warn but still return what we got.
            print(f"{tensor_name}: length={bytesToRead}, only read {len(bytes)} bytes",file=sys.stderr)
        return bytes

    def copy_data_to_file(self,file_handle) -> int:
        """Stream the whole tensor-data section (everything after the
        header) into *file_handle*. Returns 0."""
        self.f.seek(8+self.headerlen)
        bytesLeft:int=self.st.st_size - 8 - self.headerlen
        while bytesLeft>0:
            chunklen:int=min(bytesLeft,int(16*1024*1024)) #copy in blocks of 16 MB
            file_handle.write(self.f.read(chunklen))
            bytesLeft-=chunklen

        return 0
safetensors_util.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys, click

import safetensors_worker
# This file deals with command line only. If the command line is parsed successfully,
# we will call one of the functions in safetensors_worker.py.

# Reusable click argument/option decorators shared by the subcommands below.
readonly_input_file=click.argument("input_file", metavar='input_file',
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
output_file=click.argument("output_file", metavar='output_file',
    type=click.Path(file_okay=True, dir_okay=False, writable=True))

force_overwrite_flag=click.option("-f","--force-overwrite",default=False,is_flag=True, show_default=True,
    help="overwrite existing files")
fix_ued_flag=click.option("-pm","--parse-more",default=False,is_flag=True, show_default=True,
    help="when printing metadata, unescaped doublequotes to make text more readable" )
quiet_flag=click.option("-q","--quiet",default=False,is_flag=True, show_default=True,
    help="Quiet mode, don't print informational stuff" )
18 |
+
|
19 |
+
# Top-level command group; stores global flags in ctx.obj for subcommands.
# (Comment only, no docstring: click would show a docstring in --help.)
@click.group()
@click.version_option(version=7)
@quiet_flag
@click.pass_context
def cli(ctx,quiet:bool):
    # ensure that ctx.obj exists and is a dict (in case `cli()` is called
    # by means other than the `if` block below)
    ctx.ensure_object(dict)
    ctx.obj['quiet'] = quiet
29 |
+
|
30 |
+
|
31 |
+
# "header": dump the full JSON header of a safetensors file to stdout.
@cli.command(name="header",short_help="print file header")
@readonly_input_file
@click.pass_context
def cmd_header(ctx,input_file:str) -> int:
    # Worker return value becomes the process exit code.
    sys.exit( safetensors_worker.PrintHeader(ctx.obj,input_file) )
|
36 |
+
|
37 |
+
|
38 |
+
# "metadata": print only the __metadata__ item of the header.
@cli.command(name="metadata",short_help="print only __metadata__ in file header")
@readonly_input_file
@fix_ued_flag
@click.pass_context
def cmd_meta(ctx,input_file:str,parse_more:bool)->int:
    # -pm/--parse-more expands metadata values that are JSON-encoded strings.
    ctx.obj['parse_more'] = parse_more
    sys.exit( safetensors_worker.PrintMetadata(ctx.obj,input_file) )
|
45 |
+
|
46 |
+
|
47 |
+
# "listkeys": emit header key names as ready-to-paste Python source.
@cli.command(name="listkeys",short_help="print header key names (except __metadata__) as a Python list")
@readonly_input_file
@click.pass_context
def cmd_keyspy(ctx,input_file:str) -> int:
    sys.exit( safetensors_worker.HeaderKeysToLists(ctx.obj,input_file) )
|
52 |
+
|
53 |
+
|
54 |
+
# "writemd": rewrite a safetensors file with __metadata__ taken from a
# JSON file; refuses to overwrite the output unless -f is given.
@cli.command(name="writemd",short_help="read __metadata__ from json and write to safetensors file")
@click.argument("in_st_file", metavar='input_st_file',
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
@click.argument("in_json_file", metavar='input_json_file',
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_writemd(ctx,in_st_file:str,in_json_file:str,output_file:str,force_overwrite:bool) -> int:
    """Read "__metadata__" from json file and write to safetensors header"""
    ctx.obj['force_overwrite'] = force_overwrite
    sys.exit( safetensors_worker.WriteMetadataToHeader(ctx.obj,in_st_file,in_json_file,output_file) )
|
66 |
+
|
67 |
+
|
68 |
+
# "extracthdr": save the raw (unparsed) header bytes to a file.
@cli.command(name="extracthdr",short_help="extract file header and save to output file")
@readonly_input_file
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_extractheader(ctx,input_file:str,output_file:str,force_overwrite:bool) -> int:
    ctx.obj['force_overwrite'] = force_overwrite
    sys.exit( safetensors_worker.ExtractHeader(ctx.obj,input_file,output_file) )
|
76 |
+
|
77 |
+
|
78 |
+
# "extractdata": save the raw bytes of one named tensor to a file.
@cli.command(name="extractdata",short_help="extract one tensor and save to file")
@readonly_input_file
@click.argument("key_name", metavar='key_name',type=click.STRING)
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_extractdata(ctx,input_file:str,key_name:str,output_file:str,force_overwrite:bool) -> int:
    # Bug fix: renamed from cmd_extractheader — the original duplicated the
    # name of the "extracthdr" handler above, shadowing it at module level.
    # The CLI command name ("extractdata") is unchanged.
    ctx.obj['force_overwrite'] = force_overwrite
    sys.exit( safetensors_worker.ExtractData(ctx.obj,input_file,key_name,output_file) )
|
87 |
+
|
88 |
+
|
89 |
+
# "checklora": sanity-check the file's keys against the SD 1.x LoRA layout.
@cli.command(name="checklora",short_help="see if input file is a SD 1.x LoRA file")
@readonly_input_file
@click.pass_context
def cmd_checklora(ctx,input_file:str)->int:
    sys.exit( safetensors_worker.CheckLoRA(ctx.obj,input_file) )
|
94 |
+
|
95 |
+
|
96 |
+
if __name__ == '__main__':
    # Force UTF-8 stdout so metadata containing non-ASCII text prints cleanly.
    sys.stdout.reconfigure(encoding='utf-8')
    cli(obj={},max_content_width=96)
|
safetensors_worker.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os, sys, json
|
2 |
+
from safetensors_file import SafeTensorsFile
|
3 |
+
|
4 |
+
def _need_force_overwrite(output_file:str,cmdLine:dict) -> bool:
|
5 |
+
if cmdLine["force_overwrite"]==False:
|
6 |
+
if os.path.exists(output_file):
|
7 |
+
print(f'output file "{output_file}" already exists, use -f flag to force overwrite',file=sys.stderr)
|
8 |
+
return True
|
9 |
+
return False
|
10 |
+
|
11 |
+
def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_file:str) -> int:
    """Replace __metadata__ in in_st_file's header with the one read from
    in_json_file, writing the result to output_file.

    Tensor data is copied through unchanged; only the JSON header (and its
    8-byte length prefix) is rebuilt. Returns 0 on success, negative on
    error.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1

    with open(in_json_file,"rt") as f:
        inmeta=json.load(f)
    if not "__metadata__" in inmeta:
        print(f"file {in_json_file} does not contain a top-level __metadata__ item",file=sys.stderr)
        #json.dump(inmeta,fp=sys.stdout,indent=2)
        return -2
    inmeta=inmeta["__metadata__"] #keep only metadata
    #json.dump(inmeta,fp=sys.stdout,indent=2)

    s=SafeTensorsFile.open_file(in_st_file)
    js=s.get_header()

    if inmeta==[]:
        # An empty list requests removal of __metadata__ from the header.
        js.pop("__metadata__",0)
        print("loaded __metadata__ is an empty list, output file will not contain __metadata__ in header")
    else:
        print("adding __metadata__ to header:")
        json.dump(inmeta,fp=sys.stdout,indent=2)
        # Values are coerced to str before writing.
        if isinstance(inmeta,dict):
            for k in inmeta:
                inmeta[k]=str(inmeta[k])
        else:
            inmeta=str(inmeta)
        #js["__metadata__"]=json.dumps(inmeta,ensure_ascii=False)
        js["__metadata__"]=inmeta
        print()

    # Rebuild the header buffer and pad its length to a multiple of 8.
    newhdrbuf=json.dumps(js,separators=(',',':'),ensure_ascii=False).encode('utf-8')
    newhdrlen:int=int(len(newhdrbuf))
    pad:int=((newhdrlen+7)&(~7))-newhdrlen #pad to multiple of 8

    with open(output_file,"wb") as f:
        f.write(int(newhdrlen+pad).to_bytes(8,'little'))
        f.write(newhdrbuf)
        if pad>0: f.write(bytearray([32]*pad)) # pad with ASCII spaces
        i:int=s.copy_data_to_file(f)
    if i==0:
        print(f"file {output_file} saved successfully")
    else:
        print(f"error {i} occurred when writing to file {output_file}")
    return i
|
55 |
+
|
56 |
+
def PrintHeader(cmdLine:dict,input_file:str) -> int:
    """Print the entire parsed header, one key/value pair per line.

    Returns 0.
    """
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    js=s.get_header()

    # All the .safetensors files I've seen have long key names, and as a result,
    # neither json nor pprint package prints text in very readable format,
    # so we print it ourselves, putting key name & value on one long line.
    # Note the print out is in Python format, not valid JSON format.
    firstKey=True
    print("{")
    for key in js:
        if firstKey:
            firstKey=False
        else:
            print(",")
        json.dump(key,fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
        print(": ",end='')
        json.dump(js[key],fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
    print("\n}")
    return 0
|
76 |
+
|
77 |
+
def _ParseMore(d:dict):
|
78 |
+
'''Basically try to turn this:
|
79 |
+
|
80 |
+
"ss_dataset_dirs":"{\"abc\": {\"n_repeats\": 2, \"img_count\": 60}}",
|
81 |
+
|
82 |
+
into this:
|
83 |
+
|
84 |
+
"ss_dataset_dirs":{
|
85 |
+
"abc":{
|
86 |
+
"n_repeats":2,
|
87 |
+
"img_count":60
|
88 |
+
}
|
89 |
+
},
|
90 |
+
|
91 |
+
'''
|
92 |
+
for key in d:
|
93 |
+
value=d[key]
|
94 |
+
#print("+++",key,value,type(value),"+++",sep='|')
|
95 |
+
if isinstance(value,str):
|
96 |
+
try:
|
97 |
+
v2=json.loads(value)
|
98 |
+
d[key]=v2
|
99 |
+
value=v2
|
100 |
+
except json.JSONDecodeError as e:
|
101 |
+
pass
|
102 |
+
if isinstance(value,dict):
|
103 |
+
_ParseMore(value)
|
104 |
+
|
105 |
+
def PrintMetadata(cmdLine:dict,input_file:str) -> int:
    """Print the __metadata__ section of input_file's header as JSON on
    stdout. Returns 0 on success, -2 when no __metadata__ is present."""
    with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s:
        js=s.get_header()

        if not "__metadata__" in js:
            print("file header does not contain a __metadata__ item",file=sys.stderr)
            return -2

        md=js["__metadata__"]
        if cmdLine['parse_more']:
            # Expand metadata values that are themselves JSON-encoded strings.
            _ParseMore(md)
        json.dump({"__metadata__":md},fp=sys.stdout,ensure_ascii=False,separators=(',',':'),indent=1)
    return 0
|
118 |
+
|
119 |
+
def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int:
    """Print the header's tensor key names (sorted, excluding __metadata__)
    as Python source for a `_lora_keys` list of (name, is_scalar) tuples.

    Returns 0.
    """
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    js=s.get_header()

    # Bug fix: annotation was list[tuple(str,bool)] — tuple(str,bool) is a
    # two-argument *call* to tuple(), not a type subscript.
    _lora_keys:list[tuple[str,bool]]=[] # use list to sort by name
    for key in js:
        if key=='__metadata__': continue
        v=js[key]
        # A tensor whose shape list is empty is treated as a scalar.
        isScalar=False
        if isinstance(v,dict):
            if 'shape' in v:
                if 0==len(v['shape']):
                    isScalar=True
        _lora_keys.append((key,isScalar))
    _lora_keys.sort(key=lambda x:x[0])

    def printkeylist(kl):
        # Comma-separate the tuples, one per line, no trailing comma.
        firstKey=True
        for key in kl:
            if firstKey: firstKey=False
            else: print(",")
            print(key,end='')
        print()

    print("# use list to keep insertion order")
    print("_lora_keys:list[tuple[str,bool]]=[")
    printkeylist(_lora_keys)
    print("]")

    return 0
|
149 |
+
|
150 |
+
|
151 |
+
def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int:
    """Save the raw (unparsed) JSON header bytes of input_file to
    output_file. Returns 0 on success, negative on error."""
    if _need_force_overwrite(output_file,cmdLine): return -1

    # parseHeader=False: we only need the raw bytes, not a parsed dict.
    s=SafeTensorsFile.open_file(input_file,parseHeader=False)
    if s.error!=0: return s.error

    hdrbuf=s.hdrbuf
    s.close_file() #close it in case user wants to write back to input_file itself
    with open(output_file,"wb") as fo:
        wn=fo.write(hdrbuf)
        if wn!=len(hdrbuf):
            print(f"write output file failed, tried to write {len(hdrbuf)} bytes, only wrote {wn} bytes",file=sys.stderr)
            return -1
    print(f"raw header saved to file {output_file}")
    return 0
|
166 |
+
|
167 |
+
|
168 |
+
def _CheckLoRA_internal(s:SafeTensorsFile)->int:
    """Compare *s*'s header keys against the known SD 1.x LoRA key list.

    Prints findings to stdout; returns 1 if a hard error was found, else 0.
    """
    # Local import: the key list module is only needed by this check.
    import lora_keys_sd15 as lora_keys
    js=s.get_header()
    # Split the expected keys into scalar vs nonscalar name sets.
    set_scalar=set()
    set_nonscalar=set()
    for x in lora_keys._lora_keys:
        if x[1]==True: set_scalar.add(x[0])
        else: set_nonscalar.add(x[0])

    bad_unknowns:list[str]=[] # unrecognized keys
    bad_scalars:list[str]=[] #bad scalar
    bad_nonscalars:list[str]=[] #bad nonscalar
    for key in js:
        if key in set_nonscalar:
            # Expected nonscalar: an empty shape means it is actually scalar.
            if js[key]['shape']==[]: bad_nonscalars.append(key)
            set_nonscalar.remove(key)
        elif key in set_scalar:
            if js[key]['shape']!=[]: bad_scalars.append(key)
            set_scalar.remove(key)
        else:
            if "__metadata__"!=key:
                bad_unknowns.append(key)

    # Any names left in the sets were expected but missing from the file.
    hasError=False

    if len(bad_unknowns)!=0:
        # Unknown keys are informational only, not a hard error.
        print("INFO: unrecognized items:")
        for x in bad_unknowns: print(" ",x)
        #hasError=True

    if len(set_scalar)>0:
        print("missing scalar keys:")
        for x in set_scalar: print(" ",x)
        hasError=True
    if len(set_nonscalar)>0:
        print("missing nonscalar keys:")
        for x in set_nonscalar: print(" ",x)
        hasError=True

    if len(bad_scalars)!=0:
        print("keys expected to be scalar but are nonscalar:")
        for x in bad_scalars: print(" ",x)
        hasError=True

    if len(bad_nonscalars)!=0:
        print("keys expected to be nonscalar but are scalar:")
        for x in bad_nonscalars: print(" ",x)
        hasError=True

    return (1 if hasError else 0)
|
218 |
+
|
219 |
+
def CheckLoRA(cmdLine:dict,input_file:str)->int:
    """Validate *input_file* against the SD 1.x LoRA key list.

    Returns the check result (0 = OK, 1 = errors found). Bug fix: this
    previously returned a hard-coded 0, so the CLI's exit code claimed
    success even when the check failed.
    """
    s=SafeTensorsFile.open_file(input_file)
    i:int=_CheckLoRA_internal(s)
    if i==0: print("looks like an OK SD 1.x LoRA file")
    return i
|
224 |
+
|
225 |
+
def ExtractData(cmdLine:dict,input_file:str,key_name:str,output_file:str)->int:
    """Extract the raw bytes of one tensor (*key_name*) into output_file.

    Returns 0 on success, -1 on error (overwrite refused, key missing, or
    short write)."""
    if _need_force_overwrite(output_file,cmdLine): return -1

    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    if s.error!=0: return s.error

    bindata=s.load_one_tensor(key_name)
    s.close_file() #close it just in case user wants to write back to input_file itself
    if bindata is None:
        print(f'key "{key_name}" not found in header (key names are case-sensitive)',file=sys.stderr)
        return -1

    with open(output_file,"wb") as fo:
        wn=fo.write(bindata)
        if wn!=len(bindata):
            print(f"write output file failed, tried to write {len(bindata)} bytes, only wrote {wn} bytes",file=sys.stderr)
            return -1
    if cmdLine['quiet']==False: print(f"{key_name} saved to {output_file}, len={wn}")
    return 0
|