"""Gradio UI for viewing and editing the metadata of LoRA ``.safetensors`` files.

Metadata is read through ``safetensors_worker.PrintMetadata`` (whose stdout is
captured and parsed as JSON) and written back by running
``safetensors_util.py writemd`` in a subprocess.
"""

import contextlib
import html
import io
import json
import subprocess
import sys
import tempfile
from pathlib import Path

import gradio as gr

from safetensors_worker import PrintMetadata

# Training parameters surfaced in the "Key Metrics" panel; they are looked up
# in the "__metadata__" section of the parsed header.
KEY_METRIC_FIELDS = (
    "ss_optimizer",
    "ss_num_epochs",
    "ss_unet_lr",
    "ss_text_encoder_lr",
    "ss_steps",
)

SAFETENSORS_SUFFIX = ".safetensors"


class Context:
    """Options object handed to ``PrintMetadata`` as its first argument."""

    def __init__(self):
        # NOTE(review): these flags are interpreted by safetensors_worker;
        # presumably "quiet" suppresses extra output and "parse_more" decodes
        # nested values — confirm there.
        self.obj = {'quiet': True, 'parse_more': True}


ctx = Context()


def debug_log(message: str):
    """Print *message* to stdout with a ``[DEBUG]`` prefix."""
    print(f"[DEBUG] {message}")


def _resolve_path(file_obj) -> str:
    """Return a filesystem path for a Gradio file value.

    Older Gradio versions pass a tempfile-like object exposing ``.name``;
    newer ones pass the path string itself. Both are accepted.
    """
    return str(getattr(file_obj, "name", file_obj))


def _extract_key_metrics(full_metadata: dict) -> dict:
    """Pick ``KEY_METRIC_FIELDS`` out of ``full_metadata["__metadata__"]``.

    Missing fields (or a missing / non-object ``__metadata__`` section) map
    to ``"N/A"``.
    """
    training_params = full_metadata.get("__metadata__")
    if not isinstance(training_params, dict):
        training_params = {}
    return {key: training_params.get(key, "N/A") for key in KEY_METRIC_FIELDS}


def load_metadata(file_path) -> tuple:
    """Read the header metadata of an uploaded ``.safetensors`` file.

    Args:
        file_path: Gradio file value — a path string or an object with ``.name``.

    Returns:
        ``(full_metadata, key_metrics, editor_json, status_message, source_path)``
        in the order of the upload event's outputs. On failure the first element
        is ``{"error": ...}``, the status message carries the error text and
        ``source_path`` is ``""``.
    """
    try:
        debug_log(f"Loading file: {file_path}")
        if not file_path:
            return {"status": "Awaiting input"}, {}, "", "", ""
        source_path = _resolve_path(file_path)

        # PrintMetadata prints the header JSON to stdout; capture it.
        # redirect_stdout restores sys.stdout even if PrintMetadata raises.
        with contextlib.redirect_stdout(io.StringIO()) as buffer:
            exit_code = PrintMetadata(ctx.obj, source_path)
        metadata_str = buffer.getvalue().strip()

        if exit_code != 0:
            error_msg = f"Error code {exit_code}"
            return {"error": error_msg}, {}, "", error_msg, ""

        try:
            full_metadata = json.loads(metadata_str)
        except json.JSONDecodeError:
            full_metadata = None
        # Anything other than a JSON object cannot be a safetensors header.
        if not isinstance(full_metadata, dict):
            error_msg = "Invalid metadata structure"
            return {"error": error_msg}, {}, "", error_msg, ""

        key_metrics = _extract_key_metrics(full_metadata)
        return (
            full_metadata,
            key_metrics,
            json.dumps(full_metadata, indent=2),
            "",
            source_path,
        )
    except Exception as e:
        # UI callback boundary: report the failure instead of crashing the app.
        debug_log(f"load_metadata failed: {e}")
        return {"error": str(e)}, {}, "", str(e), ""


def validate_json(edited_json: str) -> tuple:
    """Parse *edited_json*.

    Returns:
        ``(True, parsed, "")`` on success, ``(False, None, error_text)`` otherwise.
    """
    try:
        return True, json.loads(edited_json), ""
    except (ValueError, TypeError) as e:
        # ValueError covers JSONDecodeError; TypeError covers a None input.
        return False, None, str(e)


def update_metadata(edited_json: str) -> tuple:
    """Refresh the key-metric and preview panels from the editor text.

    Returns:
        ``(key_metrics, parsed_json, "")``. While the text is not a valid JSON
        object (e.g. half-typed), both panels are left unchanged via
        ``gr.update()``.
    """
    try:
        modified_data = json.loads(edited_json)
    except (ValueError, TypeError):
        return gr.update(), gr.update(), ""
    if not isinstance(modified_data, dict):
        return gr.update(), gr.update(), ""
    return _extract_key_metrics(modified_data), modified_data, ""


def _next_free_output_path(source_path: Path, output_name) -> Path:
    """Choose an output path for the modified file that does not exist yet.

    Uses *output_name* (``.safetensors`` appended when missing) or, when it
    is blank, ``<source stem>_modified.safetensors``. If that file exists, a
    ``_1``, ``_2``, ... suffix is added to the chosen stem.
    """
    base_name = (output_name or "").strip()
    if base_name:
        if not base_name.endswith(SAFETENSORS_SUFFIX):
            base_name += SAFETENSORS_SUFFIX
    else:
        base_name = f"{source_path.stem}_modified{SAFETENSORS_SUFFIX}"

    base_path = Path(base_name)
    candidate = base_path
    version = 1
    while candidate.exists():
        candidate = base_path.with_name(f"{base_path.stem}_{version}{base_path.suffix}")
        version += 1
    return candidate


def save_metadata(edited_json: str, source_file: str, output_name: str) -> tuple:
    """Write *edited_json* as the metadata of a copy of *source_file*.

    The JSON is dumped to a temporary file and passed to
    ``safetensors_util.py writemd`` together with the source and output paths.

    Returns:
        ``(output_path, "")`` on success, ``(None, error_message)`` on failure.
    """
    debug_log("Initiating save process")
    temp_path = None
    try:
        if not source_file:
            return None, "No source file provided"

        is_valid, parsed_data, error = validate_json(edited_json)
        if not is_valid:
            return None, f"Validation error: {error}"

        with tempfile.NamedTemporaryFile(
            mode='w', suffix='.json', delete=False, encoding='utf-8'
        ) as tmp:
            # Record the path before writing so `finally` can clean up even if
            # json.dump fails.
            temp_path = tmp.name
            json.dump(parsed_data, tmp, indent=2)

        output_path = _next_free_output_path(Path(source_file), output_name)

        # NOTE(review): "safetensors_util.py" is resolved relative to the
        # current working directory, not this script's directory.
        cmd = [
            sys.executable,
            "safetensors_util.py",
            "writemd",
            source_file,
            temp_path,
            str(output_path),
            "-f",
        ]
        result = subprocess.run(cmd, capture_output=True, text=True, check=False)

        if result.returncode != 0:
            # Some failures are reported on stdout only; never show a blank reason.
            return None, f"Save failure: {result.stderr or result.stdout}"
        return str(output_path), ""
    except Exception as e:
        return None, f"Critical error: {str(e)}"
    finally:
        if temp_path is not None:
            Path(temp_path).unlink(missing_ok=True)


def _format_status(message: str) -> str:
    """Render a status/error message as escaped HTML; empty input clears it."""
    if not message:
        return ""
    # Messages may contain subprocess stderr or exception text: escape them.
    return f'<p style="color: #b00020;">{html.escape(message)}</p>'


def _handle_upload(file_value) -> tuple:
    """Upload callback: ``load_metadata`` with its status rendered as HTML."""
    full_metadata, key_metrics, editor_json, status, source_path = load_metadata(file_value)
    return full_metadata, key_metrics, editor_json, _format_status(status), source_path


def _handle_save(edited_json: str, source_file: str, output_name: str) -> tuple:
    """Save callback: show the download widget only when a file was written."""
    saved_path, message = save_metadata(edited_json, source_file, output_name)
    return (
        gr.File(value=saved_path, visible=saved_path is not None),
        _format_status(message),
    )


def create_interface():
    """Build the viewer/editor Gradio app and wire its events.

    Returns:
        The un-launched ``gr.Blocks`` application.
    """
    with gr.Blocks(title="LoRA Metadata Editor") as app:
        gr.Markdown("# LoRA Metadata Editor")

        with gr.Tabs():
            with gr.Tab("Metadata Viewer"):
                gr.Markdown("### LoRA Upload")
                file_input = gr.File(
                    file_types=[".safetensors"],
                    show_label=False,
                )
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Full Metadata")
                        full_viewer = gr.JSON(show_label=False)
                    with gr.Column():
                        gr.Markdown("### Key Metrics")
                        key_viewer = gr.JSON(show_label=False)

            with gr.Tab("Edit Metadata"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### JSON Workspace")
                        metadata_editor = gr.Textbox(
                            lines=25,
                            show_label=False,
                            placeholder="Edit metadata JSON here",
                        )
                        gr.Markdown("### Output Name")
                        filename_input = gr.Textbox(
                            placeholder="Leave empty for auto-naming",
                            show_label=False,
                        )
                    with gr.Column():
                        gr.Markdown("### Live Preview")
                        modified_viewer = gr.JSON(show_label=False)
                        save_btn = gr.Button("💾 Save Metadata", variant="primary")
                        gr.Markdown("### Download Modified LoRA")
                        output_file = gr.File(
                            visible=False,
                            show_label=False,
                        )

        # Shared status area below the tabs; an empty value renders nothing,
        # so it can stay visible and show errors from loading or saving.
        status_display = gr.HTML()
        # Path of the currently loaded source file, carried from upload to save.
        source_tracker = gr.State()

        file_input.upload(
            _handle_upload,
            inputs=file_input,
            outputs=[full_viewer, key_viewer, metadata_editor, status_display, source_tracker],
        )
        metadata_editor.change(
            update_metadata,
            inputs=metadata_editor,
            outputs=[key_viewer, modified_viewer, status_display],
        )
        save_btn.click(
            _handle_save,
            inputs=[metadata_editor, source_tracker, filename_input],
            outputs=[output_file, status_display],
        )

    return app


if __name__ == "__main__":
    interface = create_interface()
    interface.launch()