import logging
import os
import shutil
import subprocess
import sys
import tempfile

import gradio as gr
import spaces
from huggingface_hub import HfApi

# Configure logging once at import time instead of on every request.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# NOTE(review): @spaces.GPU requests a (ZeroGPU) GPU for the duration of the
# decorated call, but this helper only writes a small text file. It is kept
# here to avoid changing how/when GPU time is requested (moving it onto
# merge_and_upload would put the whole merge under the default GPU duration
# limit) — confirm where GPU allocation is actually wanted.
@spaces.GPU
def write_repo(base_model, model_to_merge, path="repo.txt"):
    """Write the two model identifiers to *path*, one per line.

    The resulting file is passed to ``hf_merge.py`` as its list of models.

    Args:
        base_model: Identifier of the base model (written on the first line).
        model_to_merge: Identifier of the model to merge (second line).
        path: Destination file. Defaults to ``repo.txt`` in the current
            working directory for backward compatibility.
    """
    with open(path, "w", encoding="utf-8") as repo:
        repo.write(base_model + "\n" + model_to_merge)


def merge_and_upload(base_model, model_to_merge, scaling_factor, weight_drop_prob, repo_name):
    """Merge two models with ``hf_merge.py`` and upload the result to the Hub.

    Each call works in its own temporary directory, so concurrent requests do
    not overwrite each other's ``repo.txt`` or merge output. That directory
    (and only that directory) is removed when the call finishes, whether it
    succeeded or failed.

    Args:
        base_model: Identifier of the base model.
        model_to_merge: Identifier of the model merged into the base.
        scaling_factor: Passed to ``hf_merge.py`` as ``-lambda``.
        weight_drop_prob: Passed to ``hf_merge.py`` as ``-p``.
        repo_name: Hub repo id to create (if missing) and upload into.

    Returns:
        A human-readable status message for the Gradio output textbox. Errors
        are reported in the message rather than raised.
    """
    base_model = (base_model or "").strip()
    model_to_merge = (model_to_merge or "").strip()
    repo_name = (repo_name or "").strip()
    if not (base_model and model_to_merge and repo_name):
        return "Error: base model, model to merge and repo name are all required."

    work_dir = tempfile.mkdtemp(prefix="model_merge_")
    try:
        repo_file = os.path.join(work_dir, "repo.txt")
        output_path = os.path.join(work_dir, "output")

        write_repo(base_model, model_to_merge, repo_file)

        # Argument list (no shell), so user-supplied model names cannot inject
        # shell syntax. sys.executable keeps the child on this interpreter/env.
        command = [
            sys.executable,
            "hf_merge.py",
            "-p", str(weight_drop_prob),
            "-lambda", str(scaling_factor),
            repo_file,
            output_path,
        ]
        result = subprocess.run(command, capture_output=True, text=True)

        if result.stdout:
            logger.info(result.stdout)
        if result.returncode != 0:
            logger.error(result.stderr)
            return f"Error in merging models: {result.stderr}"
        if result.stderr:
            # stderr on success is usually progress/warnings, not a failure.
            logger.warning(result.stderr)

        api = HfApi()
        try:
            api.create_repo(repo_id=repo_name, exist_ok=True)
            if os.path.isdir(output_path):
                # The merge script may write a model directory rather than a
                # single file; upload_file cannot handle a directory.
                api.upload_folder(folder_path=output_path, repo_id=repo_name)
            else:
                api.upload_file(
                    path_or_fileobj=output_path,
                    path_in_repo="merged_model.safetensors",
                    repo_id=repo_name,
                )
        except Exception as e:  # report any Hub failure back to the UI
            logger.exception("Upload to %s failed", repo_name)
            return f"Error uploading to Hugging Face Hub: {str(e)}"

        return f"Model merged and uploaded successfully to {repo_name}!"
    finally:
        # Remove only this request's scratch space — never all of /tmp.
        shutil.rmtree(work_dir, ignore_errors=True)


# Define the Gradio interface.
# `auth` is not a gr.Interface parameter (authentication is configured via
# launch(auth=...) with credentials), so the former auth="huggingface" is
# dropped.
# NOTE(review): confirm "dark" is a theme name the installed Gradio accepts.
iface = gr.Interface(
    fn=merge_and_upload,
    inputs=[
        gr.Textbox(label="Base Model"),
        gr.Textbox(label="Model to Merge"),
        gr.Slider(minimum=0, maximum=10, value=3.0, label="Scaling Factor"),
        gr.Slider(minimum=0, maximum=1, value=0.3, label="Weight Drop Probability"),
        gr.Textbox(label="Repo Name"),
    ],
    outputs=gr.Textbox(label="Output"),
    title="Model Merger and Uploader",
    description="Merge two models using the Super Mario merge method and upload to Hugging Face Hub.",
    theme="dark",
)

# Launch the interface
iface.launch()