Juan Sebastian Giraldo committed on
Commit
75cf81d
·
1 Parent(s): 57ff33f

Upload Lora app

Browse files
Files changed (6) hide show
  1. .gitignore +3 -0
  2. app.py +211 -0
  3. requirements.txt +0 -0
  4. safetensors_file.py +125 -0
  5. safetensors_util.py +98 -0
  6. safetensors_worker.py +243 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ **/__pycache__/
2
+ /.venv/
3
+ /scripts/
app.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import json
3
+ import sys
4
+ import io
5
+ import subprocess
6
+ import tempfile
7
+ from pathlib import Path
8
+ from safetensors_worker import PrintMetadata
9
+
10
class Context:
    """Minimal stand-in for a click context.

    The safetensors_worker functions expect their options in a dict named
    ``obj``, mirroring ``click.Context.obj``.
    """

    def __init__(self):
        # quiet: suppress informational prints from the worker;
        # parse_more: expand JSON-encoded strings inside the metadata.
        self.obj = {"quiet": True, "parse_more": True}
13
+
14
+ ctx = Context()
15
+
16
def debug_log(message: str) -> None:
    """Emit *message* to stdout with a ``[DEBUG]`` prefix."""
    print("[DEBUG] " + message)
18
+
19
def load_metadata(file_path: str) -> tuple:
    """Read the __metadata__ block from an uploaded .safetensors file.

    ``file_path`` is a gradio File wrapper (it exposes ``.name``) despite the
    ``str`` annotation — TODO confirm against the upload handler.

    Returns a 5-tuple for the gradio outputs:
    (full metadata dict, key-metrics dict, editable JSON string,
    error message, source file path).
    """
    try:
        debug_log(f"Loading file: {file_path}")

        if not file_path:
            return {"status": "Awaiting input"}, {}, "", "", ""

        # PrintMetadata writes its JSON to stdout, so capture it.  Restore
        # stdout in a finally block: previously an exception raised inside
        # PrintMetadata left sys.stdout pointing at the StringIO buffer,
        # silently eating all later prints.
        old_stdout = sys.stdout
        sys.stdout = buffer = io.StringIO()
        try:
            exit_code = PrintMetadata(ctx.obj, file_path.name)
        finally:
            sys.stdout = old_stdout

        metadata_str = buffer.getvalue().strip()

        if exit_code != 0:
            error_msg = f"Error code {exit_code}"
            return {"error": error_msg}, {}, "", error_msg, ""

        try:
            full_metadata = json.loads(metadata_str)
        except json.JSONDecodeError:
            error_msg = "Invalid metadata structure"
            return {"error": error_msg}, {}, "", error_msg, ""

        # Pull out the handful of training parameters shown as "key metrics".
        training_params = full_metadata.get("__metadata__", {})
        key_metrics = {
            key: training_params.get(key, "N/A")
            for key in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }

        return full_metadata, key_metrics, json.dumps(full_metadata, indent=2), "", file_path.name

    except Exception as e:
        # Top-level UI boundary: report the error instead of crashing gradio.
        return {"error": str(e)}, {}, "", str(e), ""
56
+
57
def validate_json(edited_json: str) -> tuple:
    """Try to parse *edited_json*.

    Returns ``(True, parsed_object, "")`` on success or
    ``(False, None, error_message)`` on failure.
    """
    try:
        parsed = json.loads(edited_json)
    except Exception as exc:
        return False, None, str(exc)
    return True, parsed, ""
62
+
63
def update_metadata(edited_json: str) -> tuple:
    """Re-parse the edited JSON for the live preview panes.

    Returns (key_metrics dict, full parsed dict, status string).  While the
    user is mid-edit the JSON is routinely invalid, so parse errors leave the
    gradio components unchanged via ``gr.update()``.
    """
    try:
        modified_data = json.loads(edited_json)
        metadata = modified_data.get("__metadata__", {})

        key_fields = {
            param: metadata.get(param, "N/A")
            for param in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }
        return key_fields, modified_data, ""
    except Exception:
        # Bug fix: this was a bare `except:`, which also swallowed
        # SystemExit and KeyboardInterrupt.  Partial JSON is expected while
        # typing, so keep the preview unchanged.
        return gr.update(), gr.update(), ""
78
+
79
def save_metadata(edited_json: str, source_file: str, output_name: str) -> tuple:
    """Validate the edited JSON and write it into a new .safetensors file.

    Delegates the actual header rewrite to ``safetensors_util.py writemd``
    run in a subprocess.  Returns (output_path or None, error message).
    """
    debug_log("Initiating save process")
    try:
        if not source_file:
            return None, "No source file provided"

        is_valid, parsed_data, error = validate_json(edited_json)
        if not is_valid:
            return None, f"Validation error: {error}"

        # Write the metadata to a temp JSON file the CLI tool can read.
        with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as tmp:
            json.dump(parsed_data, tmp, indent=2)
            temp_path = tmp.name

        try:
            source_path = Path(source_file)

            if output_name.strip():
                base_name = output_name.strip()
                if not base_name.endswith(".safetensors"):
                    base_name += ".safetensors"
            else:
                base_name = f"{source_path.stem}_modified.safetensors"

            # Avoid clobbering an existing file by appending a counter.
            # Bug fix: versioned names are now derived from the chosen base
            # name; previously they always used the source stem, silently
            # discarding a user-supplied output name on collision.
            stem = Path(base_name).stem
            output_path = Path(base_name)
            version = 1
            while output_path.exists():
                output_path = Path(f"{stem}_{version}.safetensors")
                version += 1

            # List-form argv (shell=False): no shell-injection surface.
            cmd = [
                sys.executable,
                "safetensors_util.py",
                "writemd",
                source_file,
                temp_path,
                str(output_path),
                "-f"
            ]

            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=False
            )
        finally:
            # Always clean up the temp JSON, even if building or running the
            # command raised.
            Path(temp_path).unlink(missing_ok=True)

        if result.returncode != 0:
            return None, f"Save failure: {result.stderr}"

        return str(output_path), ""

    except Exception as e:
        return None, f"Critical error: {str(e)}"
135
+
136
def create_interface():
    """Build the Gradio Blocks UI: a metadata viewer tab plus an edit/save tab.

    Returns the (unlaunched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="LoRA Metadata Editor") as app:
        gr.Markdown("# LoRA Metadata Editor")

        with gr.Tabs():
            # Bug fix: tab label typo "Metdata Viewer" -> "Metadata Viewer".
            with gr.Tab("Metadata Viewer"):
                gr.Markdown("### LoRa Upload")
                file_input = gr.File(
                    file_types=[".safetensors"],
                    show_label=False
                )

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Full Metadata")
                        full_viewer = gr.JSON(show_label=False)

                    with gr.Column():
                        gr.Markdown("### Key Metrics")
                        key_viewer = gr.JSON(show_label=False)

            with gr.Tab("Edit Metadata"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### JSON Workspace")
                        metadata_editor = gr.Textbox(
                            lines=25,
                            show_label=False,
                            placeholder="Edit metadata JSON here"
                        )
                        gr.Markdown("### Output Name")
                        filename_input = gr.Textbox(
                            placeholder="Leave empty for auto-naming",
                            show_label=False
                        )

                    with gr.Column():
                        gr.Markdown("### Live Preview")
                        modified_viewer = gr.JSON(show_label=False)
                        save_btn = gr.Button("💾 Save Metadata", variant="primary")
                        gr.Markdown("### Download Modified LoRa")
                        output_file = gr.File(
                            visible=False,
                            show_label=False
                        )

        status_display = gr.HTML(visible=False)
        # Remembers the uploaded file's path between events.
        source_tracker = gr.State()

        # Upload -> parse metadata and populate both viewers and the editor.
        file_input.upload(
            load_metadata,
            inputs=file_input,
            outputs=[full_viewer, key_viewer, metadata_editor, status_display, source_tracker]
        )

        # Live re-parse of the edited JSON into the preview panes.
        metadata_editor.change(
            update_metadata,
            inputs=metadata_editor,
            outputs=[key_viewer, modified_viewer, status_display]
        )

        # Save, then reveal the download widget pointing at the written file.
        save_btn.click(
            save_metadata,
            inputs=[metadata_editor, source_tracker, filename_input],
            outputs=[output_file, status_display],
        ).then(
            lambda x: gr.File(value=x, visible=True),
            inputs=output_file,
            outputs=output_file
        )

    return app
208
+
209
if __name__ == "__main__":
    # Build and serve the Gradio UI.
    create_interface().launch()
requirements.txt ADDED
Binary file (4.84 kB). View file
 
safetensors_file.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, json
2
+
3
class SafeTensorsException(Exception):
    """Error raised for structurally invalid .safetensors files."""

    def __init__(self, msg: str):
        self.msg = msg
        super().__init__(msg)

    @staticmethod
    def invalid_file(filename: str, whatiswrong: str):
        """Build an exception describing why *filename* is invalid.

        Bug fix: the message previously hard-coded "(unknown)" and ignored
        the ``filename`` argument entirely.
        """
        s = f"{filename} is not a valid .safetensors file: {whatiswrong}"
        return SafeTensorsException(msg=s)

    def __str__(self):
        return self.msg
15
+
16
class SafeTensorsChunk:
    """Descriptor for one tensor slot in a .safetensors file.

    Holds the tensor's name, dtype string, shape, and its
    [offset0, offset1) byte range within the file's data area.
    """

    def __init__(self, name: str, dtype: str, shape: list[int], offset0: int, offset1: int):
        (self.name, self.dtype, self.shape,
         self.offset0, self.offset1) = (name, dtype, shape, offset0, offset1)
23
+
24
class SafeTensorsFile:
    """Reader for the .safetensors container format.

    Layout: an 8-byte little-endian header length, a JSON header of that
    length, then the raw tensor data area.
    """

    def __init__(self):
        self.f=None #file handle
        self.hdrbuf=None #header byte buffer
        self.header=None #parsed header as a dict
        self.error=0

    def __del__(self):
        # Best-effort close if the caller forgot; the context-manager path
        # (__exit__) is preferred.
        self.close_file()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close_file()

    def close_file(self):
        """Close the underlying file handle, if open."""
        if self.f is not None:
            self.f.close()
            self.f=None
            self.filename=""

    #test file: duplicate_keys_in_header.safetensors
    def _CheckDuplicateHeaderKeys(self):
        """Raise SafeTensorsException if the raw JSON header repeats a key.

        json.loads silently keeps only the last duplicate, so the raw buffer
        is re-parsed with an object_pairs_hook that records every key name.
        """
        def parse_object_pairs(pairs):
            # Keep only the key names, in order, duplicates included.
            return [k for k,_ in pairs]

        keys=json.loads(self.hdrbuf,object_pairs_hook=parse_object_pairs)
        #print(keys)
        d={}
        for k in keys:
            if k in d: d[k]=d[k]+1
            else: d[k]=1
        hasError=False
        for k,v in d.items():
            if v>1:
                print(f"key {k} used {v} times in header",file=sys.stderr)
                hasError=True
        if hasError:
            raise SafeTensorsException.invalid_file(self.filename,"duplicate keys in header")

    @staticmethod
    def open_file(filename:str,quiet=False,parseHeader=True):
        """Construct a SafeTensorsFile and open *filename* (see open())."""
        s=SafeTensorsFile()
        s.open(filename,quiet,parseHeader)
        return s

    def open(self,fn:str,quiet=False,parseHeader=True)->int:
        """Open *fn*, read its header buffer, optionally parse it as JSON.

        Raises SafeTensorsException on structural problems (too short,
        truncated header, header extends past EOF, duplicate keys).
        Returns 0 on success.
        """
        st=os.stat(fn)
        if st.st_size<8: #test file: zero_len_file.safetensors
            raise SafeTensorsException.invalid_file(fn,"length less than 8 bytes")

        f=open(fn,"rb")
        b8=f.read(8) #read header size
        if len(b8)!=8:
            raise SafeTensorsException.invalid_file(fn,f"read only {len(b8)} bytes at start of file")
        headerlen=int.from_bytes(b8,'little',signed=False)

        if (8+headerlen>st.st_size): #test file: header_size_too_big.safetensors
            raise SafeTensorsException.invalid_file(fn,"header extends past end of file")

        if quiet==False:
            print(f"{fn}: length={st.st_size}, header length={headerlen}")
        hdrbuf=f.read(headerlen)
        if len(hdrbuf)!=headerlen:
            raise SafeTensorsException.invalid_file(fn,f"header size is {headerlen}, but read {len(hdrbuf)} bytes")
        self.filename=fn
        self.f=f
        self.st=st
        self.hdrbuf=hdrbuf
        self.error=0
        self.headerlen=headerlen
        if parseHeader==True:
            self._CheckDuplicateHeaderKeys()
            self.header=json.loads(self.hdrbuf)
        return 0

    def get_header(self):
        # NOTE(review): returns None when the file was opened with
        # parseHeader=False — callers are expected to know which mode they used.
        return self.header

    def load_one_tensor(self,tensor_name:str):
        """Read and return one tensor's raw bytes, or None if the key is absent."""
        self.get_header()
        if tensor_name not in self.header: return None

        t=self.header[tensor_name]
        # data_offsets are relative to the data area, which starts right after
        # the 8-byte length prefix plus the header.
        self.f.seek(8+self.headerlen+t['data_offsets'][0])
        bytesToRead=t['data_offsets'][1]-t['data_offsets'][0]
        bytes=self.f.read(bytesToRead) # NOTE(review): shadows the builtin `bytes`
        if len(bytes)!=bytesToRead:
            print(f"{tensor_name}: length={bytesToRead}, only read {len(bytes)} bytes",file=sys.stderr)
        return bytes

    def copy_data_to_file(self,file_handle) -> int:
        """Stream the entire data area (everything after the header) into
        *file_handle*. Returns 0."""
        self.f.seek(8+self.headerlen)
        bytesLeft:int=self.st.st_size - 8 - self.headerlen
        while bytesLeft>0:
            chunklen:int=min(bytesLeft,int(16*1024*1024)) #copy in blocks of 16 MB
            file_handle.write(self.f.read(chunklen))
            bytesLeft-=chunklen

        return 0
safetensors_util.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, click
2
+
3
+ import safetensors_worker
4
+ # This file deals with command line only. If the command line is parsed successfully,
5
+ # we will call one of the functions in safetensors_worker.py.
6
+
7
# Reusable click argument/option definitions shared by the commands below.
readonly_input_file=click.argument("input_file", metavar='input_file',
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
output_file=click.argument("output_file", metavar='output_file',
    type=click.Path(file_okay=True, dir_okay=False, writable=True))

force_overwrite_flag=click.option("-f","--force-overwrite",default=False,is_flag=True, show_default=True,
    help="overwrite existing files")
fix_ued_flag=click.option("-pm","--parse-more",default=False,is_flag=True, show_default=True,
    help="when printing metadata, unescaped doublequotes to make text more readable" )
quiet_flag=click.option("-q","--quiet",default=False,is_flag=True, show_default=True,
    help="Quiet mode, don't print informational stuff" )

@click.group()
@click.version_option(version=7)
@quiet_flag
@click.pass_context
def cli(ctx,quiet:bool):
    """Top-level command group; stashes shared options in ctx.obj."""
    # ensure that ctx.obj exists and is a dict (in case `cli()` is called
    # by means other than the `if` block below)
    ctx.ensure_object(dict)
    ctx.obj['quiet'] = quiet


@cli.command(name="header",short_help="print file header")
@readonly_input_file
@click.pass_context
def cmd_header(ctx,input_file:str) -> int:
    sys.exit( safetensors_worker.PrintHeader(ctx.obj,input_file) )


@cli.command(name="metadata",short_help="print only __metadata__ in file header")
@readonly_input_file
@fix_ued_flag
@click.pass_context
def cmd_meta(ctx,input_file:str,parse_more:bool)->int:
    ctx.obj['parse_more'] = parse_more
    sys.exit( safetensors_worker.PrintMetadata(ctx.obj,input_file) )


@cli.command(name="listkeys",short_help="print header key names (except __metadata__) as a Python list")
@readonly_input_file
@click.pass_context
def cmd_keyspy(ctx,input_file:str) -> int:
    sys.exit( safetensors_worker.HeaderKeysToLists(ctx.obj,input_file) )


@cli.command(name="writemd",short_help="read __metadata__ from json and write to safetensors file")
@click.argument("in_st_file", metavar='input_st_file',
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
@click.argument("in_json_file", metavar='input_json_file',
    type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True))
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_writemd(ctx,in_st_file:str,in_json_file:str,output_file:str,force_overwrite:bool) -> int:
    """Read "__metadata__" from json file and write to safetensors header"""
    ctx.obj['force_overwrite'] = force_overwrite
    sys.exit( safetensors_worker.WriteMetadataToHeader(ctx.obj,in_st_file,in_json_file,output_file) )


@cli.command(name="extracthdr",short_help="extract file header and save to output file")
@readonly_input_file
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_extractheader(ctx,input_file:str,output_file:str,force_overwrite:bool) -> int:
    ctx.obj['force_overwrite'] = force_overwrite
    sys.exit( safetensors_worker.ExtractHeader(ctx.obj,input_file,output_file) )


@cli.command(name="extractdata",short_help="extract one tensor and save to file")
@readonly_input_file
@click.argument("key_name", metavar='key_name',type=click.STRING)
@output_file
@force_overwrite_flag
@click.pass_context
def cmd_extractdata(ctx,input_file:str,key_name:str,output_file:str,force_overwrite:bool) -> int:
    # Bug fix: this function was also named cmd_extractheader, shadowing the
    # extracthdr command's function at module level.
    ctx.obj['force_overwrite'] = force_overwrite
    sys.exit( safetensors_worker.ExtractData(ctx.obj,input_file,key_name,output_file) )


@cli.command(name="checklora",short_help="see if input file is a SD 1.x LoRA file")
@readonly_input_file
@click.pass_context
def cmd_checklora(ctx,input_file:str)->int:
    sys.exit( safetensors_worker.CheckLoRA(ctx.obj,input_file) )


if __name__ == '__main__':
    sys.stdout.reconfigure(encoding='utf-8')
    cli(obj={},max_content_width=96)
safetensors_worker.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, sys, json
2
+ from safetensors_file import SafeTensorsFile
3
+
4
+ def _need_force_overwrite(output_file:str,cmdLine:dict) -> bool:
5
+ if cmdLine["force_overwrite"]==False:
6
+ if os.path.exists(output_file):
7
+ print(f'output file "{output_file}" already exists, use -f flag to force overwrite',file=sys.stderr)
8
+ return True
9
+ return False
10
+
11
def WriteMetadataToHeader(cmdLine:dict,in_st_file:str,in_json_file:str,output_file:str) -> int:
    """Replace __metadata__ in a .safetensors header from a JSON file.

    Reads the top-level "__metadata__" item of *in_json_file*, rewrites the
    header of *in_st_file*, and streams the unchanged tensor data into
    *output_file*.  Returns 0 on success, negative on error.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1

    with open(in_json_file,"rt") as f:
        inmeta=json.load(f)
    if not "__metadata__" in inmeta:
        print(f"file {in_json_file} does not contain a top-level __metadata__ item",file=sys.stderr)
        #json.dump(inmeta,fp=sys.stdout,indent=2)
        return -2
    inmeta=inmeta["__metadata__"] #keep only metadata
    #json.dump(inmeta,fp=sys.stdout,indent=2)

    s=SafeTensorsFile.open_file(in_st_file)
    js=s.get_header()

    if inmeta==[]:
        # Empty-list sentinel: remove __metadata__ from the output entirely.
        js.pop("__metadata__",0)
        print("loaded __metadata__ is an empty list, output file will not contain __metadata__ in header")
    else:
        print("adding __metadata__ to header:")
        json.dump(inmeta,fp=sys.stdout,indent=2)
        # Metadata values are stored as strings in the header, so stringify
        # every value (or the whole object when it is not a dict).
        if isinstance(inmeta,dict):
            for k in inmeta:
                inmeta[k]=str(inmeta[k])
        else:
            inmeta=str(inmeta)
        #js["__metadata__"]=json.dumps(inmeta,ensure_ascii=False)
        js["__metadata__"]=inmeta
        print()

    # Re-serialize the header compactly and pad it to a multiple of 8 bytes.
    newhdrbuf=json.dumps(js,separators=(',',':'),ensure_ascii=False).encode('utf-8')
    newhdrlen:int=int(len(newhdrbuf))
    pad:int=((newhdrlen+7)&(~7))-newhdrlen #pad to multiple of 8

    with open(output_file,"wb") as f:
        # 8-byte little-endian padded length, header, space (0x20) padding,
        # then the original tensor data copied verbatim.
        f.write(int(newhdrlen+pad).to_bytes(8,'little'))
        f.write(newhdrbuf)
        if pad>0: f.write(bytearray([32]*pad))
        i:int=s.copy_data_to_file(f)
    if i==0:
        print(f"file {output_file} saved successfully")
    else:
        print(f"error {i} occurred when writing to file {output_file}")
    return i
55
+
56
def PrintHeader(cmdLine:dict,input_file:str) -> int:
    """Print the full parsed header, one key/value pair per line.

    The .safetensors headers seen in practice have long key names, so neither
    json nor pprint formats them readably; this prints each pair on its own
    long line.  The output is Python-ish, not strictly valid JSON.
    """
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    js=s.get_header()

    print("{")
    sep=""
    for key,value in js.items():
        # Comma+newline before every entry except the first.
        print(sep,end="")
        sep=",\n"
        json.dump(key,fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
        print(": ",end='')
        json.dump(value,fp=sys.stdout,ensure_ascii=False,separators=(',',':'))
    print("\n}")
    return 0
76
+
77
+ def _ParseMore(d:dict):
78
+ '''Basically try to turn this:
79
+
80
+ "ss_dataset_dirs":"{\"abc\": {\"n_repeats\": 2, \"img_count\": 60}}",
81
+
82
+ into this:
83
+
84
+ "ss_dataset_dirs":{
85
+ "abc":{
86
+ "n_repeats":2,
87
+ "img_count":60
88
+ }
89
+ },
90
+
91
+ '''
92
+ for key in d:
93
+ value=d[key]
94
+ #print("+++",key,value,type(value),"+++",sep='|')
95
+ if isinstance(value,str):
96
+ try:
97
+ v2=json.loads(value)
98
+ d[key]=v2
99
+ value=v2
100
+ except json.JSONDecodeError as e:
101
+ pass
102
+ if isinstance(value,dict):
103
+ _ParseMore(value)
104
+
105
def PrintMetadata(cmdLine:dict,input_file:str) -> int:
    """Dump {"__metadata__": ...} from the file header to stdout as JSON.

    Honors cmdLine['quiet'] (suppresses the info line printed while opening)
    and cmdLine['parse_more'] (expands JSON-encoded string values).
    Returns 0 on success, -2 when the header has no __metadata__ item.
    """
    with SafeTensorsFile.open_file(input_file,cmdLine['quiet']) as s:
        js=s.get_header()

        if not "__metadata__" in js:
            print("file header does not contain a __metadata__ item",file=sys.stderr)
            return -2

        md=js["__metadata__"]
        if cmdLine['parse_more']:
            # Expand metadata values that are themselves JSON-encoded strings.
            _ParseMore(md)
        json.dump({"__metadata__":md},fp=sys.stdout,ensure_ascii=False,separators=(',',':'),indent=1)
    return 0
118
+
119
def HeaderKeysToLists(cmdLine:dict,input_file:str) -> int:
    """Print header key names (except __metadata__) as Python source: a
    sorted list of (name, is_scalar) tuples.  Returns 0.

    is_scalar is True when the tensor's shape is [].
    """
    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    js=s.get_header()

    # Bug fix: the annotation was list[tuple(str,bool)] — a call expression,
    # not a type; the correct subscripted form is list[tuple[str,bool]].
    _lora_keys:list[tuple[str,bool]]=[] # use list to sort by name
    for key in js:
        if key=='__metadata__': continue
        v=js[key]
        isScalar=False
        if isinstance(v,dict):
            if 'shape' in v:
                if 0==len(v['shape']):
                    isScalar=True
        _lora_keys.append((key,isScalar))
    _lora_keys.sort(key=lambda x:x[0])

    def printkeylist(kl):
        # Comma-separated entries, one per line, no trailing comma.
        firstKey=True
        for key in kl:
            if firstKey: firstKey=False
            else: print(",")
            print(key,end='')
        print()

    print("# use list to keep insertion order")
    print("_lora_keys:list[tuple[str,bool]]=[")
    printkeylist(_lora_keys)
    print("]")

    return 0
149
+
150
+
151
def ExtractHeader(cmdLine:dict,input_file:str,output_file:str)->int:
    """Save the raw (unparsed) JSON header bytes to *output_file*.

    Returns 0 on success, the open error code, or -1 on refusal/short write.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1

    s=SafeTensorsFile.open_file(input_file,parseHeader=False)
    if s.error!=0: return s.error

    hdrbuf=s.hdrbuf
    # Close before writing, in case the user targets input_file itself.
    s.close_file()
    with open(output_file,"wb") as fo:
        written=fo.write(hdrbuf)
    if written!=len(hdrbuf):
        print(f"write output file failed, tried to write {len(hdrbuf)} bytes, only wrote {written} bytes",file=sys.stderr)
        return -1
    print(f"raw header saved to file {output_file}")
    return 0
166
+
167
+
168
def _CheckLoRA_internal(s:SafeTensorsFile)->int:
    """Compare the file's header keys against the expected SD 1.x LoRA key
    list (from lora_keys_sd15) and print findings.

    Returns 1 if any hard error was found (missing keys, wrong scalar-ness),
    else 0.  Unrecognized keys are reported as INFO only.
    """
    import lora_keys_sd15 as lora_keys
    js=s.get_header()
    set_scalar=set()
    set_nonscalar=set()
    # Partition the expected keys by whether their tensor should be scalar
    # (x is a (name, is_scalar) tuple).
    for x in lora_keys._lora_keys:
        if x[1]==True: set_scalar.add(x[0])
        else: set_nonscalar.add(x[0])

    bad_unknowns:list[str]=[] # unrecognized keys
    bad_scalars:list[str]=[] #bad scalar
    bad_nonscalars:list[str]=[] #bad nonscalar
    for key in js:
        if key in set_nonscalar:
            # Expected nonscalar: an empty shape means it is wrongly scalar.
            if js[key]['shape']==[]: bad_nonscalars.append(key)
            set_nonscalar.remove(key)
        elif key in set_scalar:
            # Expected scalar: a non-empty shape means it is wrongly nonscalar.
            if js[key]['shape']!=[]: bad_scalars.append(key)
            set_scalar.remove(key)
        else:
            if "__metadata__"!=key:
                bad_unknowns.append(key)

    hasError=False

    if len(bad_unknowns)!=0:
        # Informational only — unknown extra keys do not fail the check.
        print("INFO: unrecognized items:")
        for x in bad_unknowns: print(" ",x)
        #hasError=True

    # Any expected key still left in the sets was missing from the file.
    if len(set_scalar)>0:
        print("missing scalar keys:")
        for x in set_scalar: print(" ",x)
        hasError=True
    if len(set_nonscalar)>0:
        print("missing nonscalar keys:")
        for x in set_nonscalar: print(" ",x)
        hasError=True

    if len(bad_scalars)!=0:
        print("keys expected to be scalar but are nonscalar:")
        for x in bad_scalars: print(" ",x)
        hasError=True

    if len(bad_nonscalars)!=0:
        print("keys expected to be nonscalar but are scalar:")
        for x in bad_nonscalars: print(" ",x)
        hasError=True

    return (1 if hasError else 0)
218
+
219
def CheckLoRA(cmdLine:dict,input_file:str)->int:
    """Check whether *input_file* looks like an SD 1.x LoRA file.

    Returns the internal check result (0 = OK, 1 = problems found).
    Bug fix: this previously returned 0 unconditionally, so the checklora
    CLI command always exited with success even when the check failed.
    """
    s=SafeTensorsFile.open_file(input_file)
    i:int=_CheckLoRA_internal(s)
    if i==0: print("looks like an OK SD 1.x LoRA file")
    return i
224
+
225
def ExtractData(cmdLine:dict,input_file:str,key_name:str,output_file:str)->int:
    """Extract the raw bytes of one tensor (*key_name*) into *output_file*.

    Returns 0 on success, the open error code, or -1 on refusal, unknown
    key, or short write.
    """
    if _need_force_overwrite(output_file,cmdLine): return -1

    s=SafeTensorsFile.open_file(input_file,cmdLine['quiet'])
    if s.error!=0: return s.error

    bindata=s.load_one_tensor(key_name)
    # Close before writing, in case the user targets input_file itself.
    s.close_file()
    if bindata is None:
        print(f'key "{key_name}" not found in header (key names are case-sensitive)',file=sys.stderr)
        return -1

    with open(output_file,"wb") as fo:
        written=fo.write(bindata)
    if written!=len(bindata):
        print(f"write output file failed, tried to write {len(bindata)} bytes, only wrote {written} bytes",file=sys.stderr)
        return -1
    if not cmdLine['quiet']: print(f"{key_name} saved to {output_file}, len={written}")
    return 0