| import gradio as gr | |
| import json | |
| import sys | |
| import io | |
| import subprocess | |
| import tempfile | |
| from pathlib import Path | |
| from safetensors_worker import PrintMetadata | |
class Context:
    """Holder for the options dictionary exposed as ``obj``."""

    def __init__(self):
        # Both flags are switched on for every instance.
        options = dict(quiet=True, parse_more=True)
        self.obj = options
# Module-wide context instance; its ``obj`` options dict is what load_metadata
# hands to PrintMetadata.
ctx = Context()
def debug_log(message: str):
    """Print *message* to stdout, prefixed with ``[DEBUG]``."""
    line = f"[DEBUG] {message}"
    print(line)
def load_metadata(file_path: str) -> tuple:
    """Read the metadata of an uploaded file for display and editing.

    Calls ``PrintMetadata`` with stdout captured and parses the captured text
    as JSON.

    Args:
        file_path: Value delivered by the upload component. Either a plain
            path string or an object exposing the path as ``.name`` is
            accepted. A falsy value means nothing has been uploaded yet.

    Returns:
        A 5-tuple ``(full_metadata, key_metrics, editor_text, status, source)``.
        On success ``full_metadata`` is the parsed JSON, ``key_metrics`` holds
        a fixed set of ``__metadata__`` entries (``"N/A"`` when absent),
        ``editor_text`` is the pretty-printed JSON, ``status`` is empty and
        ``source`` is the path that was read. On failure the first element is
        ``{"error": message}``, ``status`` carries the same message and the
        remaining elements are empty. Never raises: any exception is turned
        into the failure tuple.
    """
    try:
        debug_log(f"Loading file: {file_path}")
        if not file_path:
            return {"status": "Awaiting input"}, {}, "", "", ""

        # Accept both a file wrapper (path in ``.name``) and a bare path string.
        source = getattr(file_path, "name", file_path)

        buffer = io.StringIO()
        old_stdout = sys.stdout
        sys.stdout = buffer
        try:
            exit_code = PrintMetadata(ctx.obj, source)
        finally:
            # Restore stdout even when PrintMetadata raises; otherwise the
            # outer handler would swallow the error and leave every later
            # print() writing into the discarded buffer.
            sys.stdout = old_stdout
        metadata_str = buffer.getvalue().strip()

        if exit_code != 0:
            error_msg = f"Error code {exit_code}"
            return {"error": error_msg}, {}, "", error_msg, ""

        try:
            full_metadata = json.loads(metadata_str)
        except json.JSONDecodeError:
            error_msg = "Invalid metadata structure"
            return {"error": error_msg}, {}, "", error_msg, ""

        training_params = full_metadata.get("__metadata__", {})
        key_metrics = {
            key: training_params.get(key, "N/A")
            for key in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }
        return (
            full_metadata,
            key_metrics,
            json.dumps(full_metadata, indent=2),
            "",
            source,
        )
    except Exception as e:
        # UI boundary: report the failure in the status slot instead of raising.
        return {"error": str(e)}, {}, "", str(e), ""
def validate_json(edited_json: str) -> tuple:
    """Try to parse *edited_json*.

    Returns:
        ``(True, parsed_value, "")`` when parsing succeeds, otherwise
        ``(False, None, message)`` where ``message`` is the exception text.
    """
    try:
        parsed = json.loads(edited_json)
    except Exception as exc:
        return False, None, str(exc)
    return True, parsed, ""
def update_metadata(edited_json: str) -> tuple:
    """Recompute the preview values from the text currently in the editor.

    Args:
        edited_json: Raw editor contents, expected to be a JSON object.

    Returns:
        ``(key_fields, modified_data, "")`` when the text parses: the parsed
        document plus a fixed set of its ``__metadata__`` entries (``"N/A"``
        when absent). When the text cannot be processed, two ``gr.update()``
        values and an empty string are returned instead.
    """
    try:
        modified_data = json.loads(edited_json)
        metadata = modified_data.get("__metadata__", {})
        key_fields = {
            param: metadata.get(param, "N/A")
            for param in [
                "ss_optimizer", "ss_num_epochs", "ss_unet_lr",
                "ss_text_encoder_lr", "ss_steps"
            ]
        }
        return key_fields, modified_data, ""
    except Exception:
        # Half-typed JSON is expected while editing; skip this refresh.
        # ``Exception`` (not a bare ``except``) keeps KeyboardInterrupt and
        # SystemExit propagating.
        return gr.update(), gr.update(), ""
def save_metadata(edited_json: str, source_file: str, output_name: str) -> tuple:
    """Write the edited metadata into a new ``.safetensors`` file.

    The edited JSON is dumped to a temporary file and ``safetensors_util.py
    writemd`` is run in a subprocess to produce the output file.

    Args:
        edited_json: Metadata JSON text from the editor.
        source_file: Path of the original file whose metadata is replaced.
        output_name: Desired output file name; ``.safetensors`` is appended if
            missing. Empty or ``None`` selects ``<source stem>_modified``.

    Returns:
        ``(output_path, "")`` on success, or ``(None, message)`` when the
        input is missing or invalid, the subprocess exits non-zero, or any
        exception occurs. Never raises from the save steps themselves.
    """
    debug_log("Initiating save process")
    extension = ".safetensors"
    try:
        if not source_file:
            return None, "No source file provided"

        is_valid, parsed_data, error = validate_json(edited_json)
        if not is_valid:
            return None, f"Validation error: {error}"

        source_path = Path(source_file)
        requested = (output_name or "").strip()
        if requested:
            base_name = requested
            if not base_name.endswith(extension):
                base_name += extension
        else:
            base_name = f"{source_path.stem}_modified{extension}"

        # NOTE(review): a user-supplied name is used verbatim as a path, so it
        # may contain directory components; confirm that is acceptable.
        first_choice = Path(base_name)
        stem = first_choice.name[:-len(extension)]
        output_path = first_choice
        version = 1
        # On collision append a counter to the *chosen* name, so a custom
        # output name is not silently replaced by one derived from the source.
        while output_path.exists():
            output_path = first_choice.with_name(f"{stem}_{version}{extension}")
            version += 1

        temp_path = None
        try:
            with tempfile.NamedTemporaryFile(
                mode='w', suffix='.json', delete=False
            ) as tmp:
                temp_path = tmp.name
                json.dump(parsed_data, tmp, indent=2)

            cmd = [
                sys.executable,
                "safetensors_util.py",
                "writemd",
                source_file,
                temp_path,
                str(output_path),
                "-f",
            ]
            result = subprocess.run(
                cmd,
                capture_output=True,
                text=True,
                check=False,
            )
        finally:
            # delete=False means cleanup is ours; do it on every exit path.
            if temp_path is not None:
                Path(temp_path).unlink(missing_ok=True)

        if result.returncode != 0:
            error_msg = f"Save failure: {result.stderr}"
            return None, error_msg
        return str(output_path), ""
    except Exception as e:
        # UI boundary: surface the failure as a status message.
        return None, f"Critical error: {str(e)}"
def create_interface():
    """Build the Gradio Blocks UI and wire its event handlers.

    The UI has a viewer tab (file upload, full metadata, key metrics) and an
    edit tab (JSON editor, output name, live preview, save button, download).

    Returns:
        The ``gr.Blocks`` application, not yet launched.
    """
    # NOTE(review): the nesting of the layout below was reconstructed from
    # flattened text; confirm it matches the intended layout.
    with gr.Blocks(title="LoRA Metadata Editor") as app:
        gr.Markdown("# LoRA Metadata Editor")
        with gr.Tabs():
            with gr.Tab("Metadata Viewer"):
                gr.Markdown("### LoRa Upload")
                file_input = gr.File(
                    file_types=[".safetensors"],
                    show_label=False
                )
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Full Metadata")
                        full_viewer = gr.JSON(show_label=False)
                    with gr.Column():
                        gr.Markdown("### Key Metrics")
                        key_viewer = gr.JSON(show_label=False)
            with gr.Tab("Edit Metadata"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### JSON Workspace")
                        metadata_editor = gr.Textbox(
                            lines=25,
                            show_label=False,
                            placeholder="Edit metadata JSON here"
                        )
                        gr.Markdown("### Output Name")
                        filename_input = gr.Textbox(
                            placeholder="Leave empty for auto-naming",
                            show_label=False
                        )
                    with gr.Column():
                        gr.Markdown("### Live Preview")
                        modified_viewer = gr.JSON(show_label=False)
                save_btn = gr.Button("💾 Save Metadata", variant="primary")
                gr.Markdown("### Download Modified LoRa")
                output_file = gr.File(
                    visible=False,
                    show_label=False
                )
        # NOTE(review): this component is created hidden and no handler makes
        # it visible, so the status/error strings routed to it are presumably
        # never shown to the user; confirm whether that is intended.
        status_display = gr.HTML(visible=False)
        # Carries the uploaded file's path from load_metadata to save_metadata.
        source_tracker = gr.State()

        file_input.upload(
            load_metadata,
            inputs=file_input,
            outputs=[full_viewer, key_viewer, metadata_editor, status_display, source_tracker]
        )
        metadata_editor.change(
            update_metadata,
            inputs=metadata_editor,
            outputs=[key_viewer, modified_viewer, status_display]
        )
        save_btn.click(
            save_metadata,
            inputs=[metadata_editor, source_tracker, filename_input],
            outputs=[output_file, status_display],
        ).then(
            # Reveal the download component once the save step has run.
            lambda x: gr.File(value=x, visible=True),
            inputs=output_file,
            outputs=output_file
        )
    return app
if __name__ == "__main__":
    # Script entry point: build the UI and start serving it.
    create_interface().launch()