|
import gradio as gr |
|
import json |
|
import sys |
|
import io |
|
import subprocess |
|
import tempfile |
|
from pathlib import Path |
|
from safetensors_worker import PrintMetadata |
|
|
|
class Context:
    """Holder for the options dict that is handed to ``PrintMetadata``."""

    def __init__(self):
        # Option flags forwarded as the first positional argument of
        # PrintMetadata; a fresh dict is built for every instance.
        self.obj = dict(quiet=True, parse_more=True)


# Shared module-level instance whose ``obj`` is passed to PrintMetadata by
# load_metadata.
ctx = Context()
|
|
|
def debug_log(message: str):
    """Print *message* to stdout, prefixed with ``[DEBUG]``."""
    # print() joins its arguments with a single space, giving "[DEBUG] <msg>".
    print("[DEBUG]", message)
|
|
|
def load_metadata(file_path: str) -> tuple:
    """Read a .safetensors file's metadata for display and editing.

    Runs ``PrintMetadata`` on the file, captures what it prints to stdout and
    parses that text as JSON.

    Args:
        file_path: The upload delivered by ``gr.File``. Depending on the
            Gradio version / ``type`` setting this is either a filepath
            string or a tempfile-like object exposing ``.name``; both are
            accepted. A falsy value means nothing has been uploaded yet.

    Returns:
        A 5-tuple ``(full_metadata, key_metrics, metadata_json, status,
        source_path)``:

        * ``full_metadata`` - the parsed JSON object, ``{"error": msg}`` on
          failure, or ``{"status": "Awaiting input"}`` when no file is given;
        * ``key_metrics`` - selected ``ss_*`` keys from
          ``full_metadata["__metadata__"]`` (``"N/A"`` when absent);
        * ``metadata_json`` - ``full_metadata`` pretty-printed for the editor;
        * ``status`` - ``""`` on success, otherwise the error message;
        * ``source_path`` - path of the loaded file, ``""`` on failure.
    """
    from contextlib import redirect_stdout

    try:
        debug_log(f"Loading file: {file_path}")

        if not file_path:
            return {"status": "Awaiting input"}, {}, "", "", ""

        # Accept both a plain path and a tempfile wrapper with ``.name``.
        if isinstance(file_path, (str, Path)):
            source_path = str(file_path)
        else:
            source_path = file_path.name

        # PrintMetadata reports via stdout, so capture it. redirect_stdout
        # restores sys.stdout even when PrintMetadata raises.
        # NOTE(review): sys.stdout is process-global; anything printed by
        # other concurrent requests during this window also lands in buffer.
        buffer = io.StringIO()
        with redirect_stdout(buffer):
            exit_code = PrintMetadata(ctx.obj, source_path)
        metadata_str = buffer.getvalue().strip()

        if exit_code != 0:
            error_msg = f"Error code {exit_code}"
            return {"error": error_msg}, {}, "", error_msg, ""

        try:
            full_metadata = json.loads(metadata_str)
        except json.JSONDecodeError:
            full_metadata = None
        if not isinstance(full_metadata, dict):
            # Unparseable output or a non-object JSON value.
            error_msg = "Invalid metadata structure"
            return {"error": error_msg}, {}, "", error_msg, ""

        training_params = full_metadata.get("__metadata__", {})
        if not isinstance(training_params, dict):
            training_params = {}
        key_metrics = {
            key: training_params.get(key, "N/A")
            for key in (
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps",
            )
        }

        # ensure_ascii=False keeps non-ASCII metadata readable in the editor
        # instead of showing \uXXXX escapes.
        metadata_json = json.dumps(full_metadata, indent=2, ensure_ascii=False)
        return full_metadata, key_metrics, metadata_json, "", source_path

    except Exception as e:
        # UI boundary: report any failure through the status output.
        return {"error": str(e)}, {}, "", str(e), ""
|
|
|
def validate_json(edited_json: str) -> tuple:
    """Try to parse *edited_json*.

    Returns:
        ``(True, parsed_value, "")`` when parsing succeeds, otherwise
        ``(False, None, error_text)`` with the exception message.
    """
    try:
        parsed = json.loads(edited_json)
    except Exception as exc:
        return False, None, str(exc)
    return True, parsed, ""
|
|
|
def update_metadata(edited_json: str) -> tuple:
    """Refresh the key-metric and live-preview panes from the editor text.

    Args:
        edited_json: Current contents of the JSON editor.

    Returns:
        ``(key_fields, modified_data, status)`` where ``key_fields`` maps the
        selected ``ss_*`` keys from ``modified_data["__metadata__"]`` to their
        values (``"N/A"`` when absent) and ``status`` is ``""``.

        While the text is not a JSON object (for example half-typed, empty,
        or None), all three outputs are ``gr.update()``. The previews, and
        any status message already shown, then stay as they are.
    """
    try:
        modified_data = json.loads(edited_json)
    except (TypeError, ValueError, RecursionError):
        # Not parseable (yet): leave the current display untouched.
        return gr.update(), gr.update(), gr.update()

    if not isinstance(modified_data, dict):
        return gr.update(), gr.update(), gr.update()

    metadata = modified_data.get("__metadata__", {})
    if not isinstance(metadata, dict):
        # A malformed "__metadata__" value shows as all-"N/A" rather than
        # freezing the preview.
        metadata = {}

    key_fields = {
        param: metadata.get(param, "N/A")
        for param in (
            "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
            "ss_text_encoder_lr", "ss_steps",
        )
    }
    return key_fields, modified_data, ""
|
|
|
def save_metadata(edited_json: str, source_file: str, output_name: str) -> tuple:
    """Write the edited metadata into a new copy of the source .safetensors file.

    The steps are:

    * validate the edited JSON;
    * dump it to a temporary ``.json`` file;
    * hand that file, the source file and the chosen output path to
      ``safetensors_util.py writemd`` in a subprocess.

    An existing file is never reused as the output path. Instead ``_1``,
    ``_2``, ... is appended to the chosen name until the name is free.

    Args:
        edited_json: JSON text from the editor.
        source_file: Path of the originally loaded file.
        output_name: Requested output filename; ``.safetensors`` is appended
            when missing. Blank or None selects
            ``<source stem>_modified.safetensors``.

    Returns:
        ``(output_path, "")`` on success, otherwise ``(None, error_message)``.
    """
    debug_log("Initiating save process")
    suffix = ".safetensors"
    temp_path = None
    try:
        if not source_file:
            return None, "No source file provided"

        is_valid, parsed_data, error = validate_json(edited_json)
        if not is_valid:
            return None, f"Validation error: {error}"

        requested = (output_name or "").strip()
        if requested:
            base_name = requested if requested.endswith(suffix) else requested + suffix
        else:
            base_name = f"{Path(source_file).stem}_modified{suffix}"

        # Version the chosen name itself (custom or automatic), so a
        # user-supplied name is kept when it collides with an existing file.
        stem = base_name[: -len(suffix)]
        output_path = Path(base_name)
        version = 1
        while output_path.exists():
            output_path = Path(f"{stem}_{version}{suffix}")
            version += 1

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as tmp:
            # Record the path first so the finally-clause can clean it up
            # even if the dump fails.
            temp_path = tmp.name
            json.dump(parsed_data, tmp, indent=2)

        # NOTE: "safetensors_util.py" and the relative output path are both
        # resolved against the current working directory.
        cmd = [
            sys.executable,
            "safetensors_util.py",
            "writemd",
            source_file,
            temp_path,
            str(output_path),
            "-f",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=False)

        if result.returncode != 0:
            # Fall back to stdout in case the script reports its error there.
            details = result.stderr.strip() or result.stdout.strip()
            return None, f"Save failure: {details}"

        return str(output_path), ""

    except Exception as e:
        # UI boundary: report any failure through the status output.
        return None, f"Critical error: {e}"
    finally:
        # Always remove the temporary JSON, including when the subprocess
        # could not be launched.
        if temp_path is not None:
            Path(temp_path).unlink(missing_ok=True)
|
|
|
def create_interface():
    """Build the Gradio UI for viewing and editing LoRA metadata.

    The layout has two tabs and a shared status box:

    * "Metadata Viewer" - upload plus full and key-metric JSON views;
    * "Edit Metadata" - JSON editor, output-name box, live preview, save
      button and download slot;
    * a status box shared by both tabs.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` to serve it.
    """

    def _save_and_reveal(edited_json, source_file, output_name):
        # Save, then show the download slot only when a file was produced;
        # a failed save returns None and keeps the slot hidden.
        saved_path, status = save_metadata(edited_json, source_file, output_name)
        return gr.update(value=saved_path, visible=saved_path is not None), status

    with gr.Blocks(title="LoRA Metadata Editor") as app:
        gr.Markdown("# LoRA Metadata Editor")

        with gr.Tabs():
            with gr.Tab("Metadata Viewer"):
                gr.Markdown("### LoRA Upload")
                file_input = gr.File(
                    file_types=[".safetensors"],
                    show_label=False,
                )

                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Full Metadata")
                        full_viewer = gr.JSON(show_label=False)

                    with gr.Column():
                        gr.Markdown("### Key Metrics")
                        key_viewer = gr.JSON(show_label=False)

            with gr.Tab("Edit Metadata"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### JSON Workspace")
                        metadata_editor = gr.Textbox(
                            lines=25,
                            show_label=False,
                            placeholder="Edit metadata JSON here",
                        )
                        gr.Markdown("### Output Name")
                        filename_input = gr.Textbox(
                            placeholder="Leave empty for auto-naming",
                            show_label=False,
                        )

                    with gr.Column():
                        gr.Markdown("### Live Preview")
                        modified_viewer = gr.JSON(show_label=False)
                        save_btn = gr.Button("💾 Save Metadata", variant="primary")
                        gr.Markdown("### Download Modified LoRA")
                        output_file = gr.File(
                            visible=False,
                            show_label=False,
                        )

        # Load/save errors are routed here, so the box must be visible or
        # they are silently lost. A read-only Textbox (rather than HTML)
        # renders messages containing '<' or '&' literally.
        status_display = gr.Textbox(label="Status", interactive=False)
        # Carries the loaded file's path from load_metadata to save_metadata.
        source_tracker = gr.State()

        file_input.upload(
            load_metadata,
            inputs=file_input,
            outputs=[full_viewer, key_viewer, metadata_editor, status_display, source_tracker],
        )

        metadata_editor.change(
            update_metadata,
            inputs=metadata_editor,
            outputs=[key_viewer, modified_viewer, status_display],
        )

        save_btn.click(
            _save_and_reveal,
            inputs=[metadata_editor, source_tracker, filename_input],
            outputs=[output_file, status_display],
        )

    return app
|
|
|
if __name__ == "__main__":
    # Build the UI and serve it with Gradio's default launch() settings.
    create_interface().launch()