import importlib
import subprocess
import sys
from functools import lru_cache

import cv2
import flow_vis
import gradio as gr
import numpy as np
import torch
from PIL import Image

import spaces

try:
    import uniception
    import uniflowmatch
except ImportError:
    # First run (e.g. on a fresh HF Space): install the project packages,
    # then refresh the import machinery so the new packages become visible.
    subprocess.check_call(["/bin/bash", "./install_package.sh"])

    # Optional: explicitly reload sys.path (only needed if install path changed)
    import site

    importlib.invalidate_caches()
    site.main()  # reload site-packages

    # Now try importing again
    uniception = importlib.import_module("uniception")
    uniflowmatch = importlib.import_module("uniflowmatch")

from uniflowmatch.models.ufm import (
    UniFlowMatchClassificationRefinement,
    UniFlowMatchConfidence,
)
from uniflowmatch.utils.viz import warp_image_with_flow

# Global model variable (kept for backward compatibility; the model choice is
# actually passed per-request through process_images).
USE_REFINEMENT_MODEL = False


@lru_cache(maxsize=2)
def initialize_model(use_refinement: bool = False):
    """Load and cache a UFM model from the HuggingFace Hub.

    The lru_cache (one slot per variant) makes repeated calls with the same
    flag free, so the UI can switch model types cheaply.

    Args:
        use_refinement: If True, load the classification-refinement variant
            ("infinity1096/UFM-Refine"); otherwise the base confidence model
            ("infinity1096/UFM-Base").

    Returns:
        Tuple ``(success, model)``: ``(True, model)`` on success,
        ``(False, None)`` if loading failed.
    """
    try:
        if use_refinement:
            print("Loading UFM Refinement model from infinity1096/UFM-Refine...")
            model_obj = UniFlowMatchClassificationRefinement.from_pretrained("infinity1096/UFM-Refine")
        else:
            print("Loading UFM Base model from infinity1096/UFM-Base...")
            model_obj = UniFlowMatchConfidence.from_pretrained("infinity1096/UFM-Base")

        # Inference only: freeze dropout / batch-norm behavior.
        if hasattr(model_obj, "eval"):
            model_obj.eval()

        print("Model loaded successfully!")
        return True, model_obj
    except Exception as e:
        print(f"Error loading model: {e}")
        return False, None


def _ensure_rgb(image_np):
    """Coerce a numpy image to 3-channel RGB.

    Handles grayscale (H, W) and RGBA (H, W, 4) inputs; (H, W, 3) arrays are
    passed through unchanged. (The original code applied COLOR_BGR2RGB to any
    non-3-channel array, which raises on 2-D and 4-channel inputs.)
    """
    if image_np.ndim == 2:
        return cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
    if image_np.shape[2] == 4:
        return cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
    return image_np


@spaces.GPU
def process_images(source_image, target_image, model_type_choice):
    """Run UFM correspondence prediction on a pair of uploaded images.

    Args:
        source_image: PIL image from the gradio Image component, or None.
        target_image: PIL image from the gradio Image component, or None.
        model_type_choice: "Base Model" or "Refinement Model".

    Returns:
        Tuple ``(flow_pil, covisibility_pil, warped_pil, status_msg)``;
        the three image entries are None when processing failed.
    """
    if source_image is None or target_image is None:
        return None, None, None, "Please upload both images."

    # Reinitialize model if type has changed (cached, so repeats are cheap).
    current_refinement = model_type_choice == "Refinement Model"
    print(f"Switching to {model_type_choice}...")
    _, model = initialize_model(current_refinement)

    if model is None:
        return None, None, None, "Model not loaded. Please restart the application."

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    try:
        # Convert PIL images to numpy arrays and ensure 3-channel RGB.
        source_rgb = _ensure_rgb(np.array(source_image))
        target_rgb = _ensure_rgb(np.array(target_image))

        print(f"Processing images with shapes: Source {source_rgb.shape}, Target {target_rgb.shape}")

        # === Predict Correspondences ===
        with torch.no_grad():
            result = model.predict_correspondences_batched(
                source_image=torch.from_numpy(source_rgb).to(device),
                target_image=torch.from_numpy(target_rgb).to(device),
            )

        # Extract results; flow is (2, H, W), covisibility is (H, W) in [0, 1]
        # — presumed from the transpose/indexing below, TODO confirm upstream.
        flow_output = result.flow.flow_output[0].cpu().numpy()
        covisibility = result.covisibility.mask[0].cpu().numpy()

        print(f"Flow output shape: {flow_output.shape}")
        print(f"Covisibility shape: {covisibility.shape}")

        # === Create Visualizations ===
        # 1. Color-coded optical flow.
        flow_vis_image = flow_vis.flow_to_color(flow_output.transpose(1, 2, 0))
        flow_pil = Image.fromarray(flow_vis_image.astype(np.uint8))

        # 2. Covisibility visualization - direct gray image.
        covisibility_gray = (covisibility * 255).astype(np.uint8)
        covisibility_pil = Image.fromarray(covisibility_gray, mode="L")

        # 3. Warp the target into the source frame; non-covisible pixels
        # blend to white so occlusions are visually obvious.
        warped_image = warp_image_with_flow(source_rgb, None, target_rgb, flow_output.transpose(1, 2, 0))
        warped_image = covisibility[..., None] * warped_image + (1 - covisibility[..., None]) * 255 * np.ones_like(
            warped_image
        )
        warped_image = (warped_image / 255.0).clip(0, 1)
        warped_pil = Image.fromarray((warped_image * 255).astype(np.uint8))

        status_msg = f"Processing completed with {model_type_choice}"
        return flow_pil, covisibility_pil, warped_pil, status_msg

    except Exception as e:
        error_msg = f"Error processing images: {str(e)}"
        print(error_msg)
        return None, None, None, error_msg


def create_demo():
    """Create the Gradio interface.

    Returns:
        A gr.Blocks app with image inputs, a model-type radio, a process
        button, example pairs, and auto-processing on input change.
    """
    with gr.Blocks(title="UniFlowMatch Demo") as demo:
        gr.Markdown("# UniFlowMatch Demo")
        gr.Markdown("Upload two images to see optical flow visualization")

        # Input section
        with gr.Row():
            source_input = gr.Image(label="Source Image", type="pil")
            target_input = gr.Image(label="Target Image", type="pil")

        # Model selection
        model_type = gr.Radio(choices=["Base Model", "Refinement Model"], value="Base Model", label="Model Type")

        # Process button
        process_btn = gr.Button("Process Images")

        # Status
        status_output = gr.Textbox(label="Status", interactive=False)

        # Output section
        with gr.Row():
            flow_output = gr.Image(label="Flow Visualization")
            covisibility_output = gr.Image(label="Covisibility Mask")
            warped_output = gr.Image(label="Warped Target Image")

        # Example images
        gr.Examples(
            examples=[
                ["examples/image_pairs/fire_academy_0.png", "examples/image_pairs/fire_academy_1.png"],
                ["examples/image_pairs/scene_0.png", "examples/image_pairs/scene_1.png"],
                ["examples/image_pairs/bike_0.png", "examples/image_pairs/bike_1.png"],
                ["examples/image_pairs/cook_0.png", "examples/image_pairs/cook_1.png"],
                ["examples/image_pairs/building_0.png", "examples/image_pairs/building_1.png"],
            ],
            inputs=[source_input, target_input],
            label="Example Image Pairs",
        )

        # Event handlers
        process_btn.click(
            fn=process_images,
            inputs=[source_input, target_input, model_type],
            outputs=[flow_output, covisibility_output, warped_output, status_output],
        )

        # Auto-process when both images are uploaded
        def auto_process(source, target, model_choice):
            if source is not None and target is not None:
                return process_images(source, target, model_choice)
            return None, None, None, "Upload both images to start processing."

        for input_component in [source_input, target_input, model_type]:
            input_component.change(
                fn=auto_process,
                inputs=[source_input, target_input, model_type],
                outputs=[flow_output, covisibility_output, warped_output, status_output],
            )

    return demo


if __name__ == "__main__":
    # Initialize model (start with the base model).
    print("Initializing UniFlowMatch model...")
    # initialize_model returns (success, model); the previous truthiness test
    # on the whole tuple could never fail, so unpack the success flag.
    model_loaded, _ = initialize_model(use_refinement=False)

    if not model_loaded:
        print("Error: Model failed to load. Please check your model installation and HuggingFace access.")
        print("Make sure you have:")
        print("1. Installed uniflowmatch package")
        print("2. Have internet access for downloading pretrained models")
        print("3. All required dependencies installed")
        sys.exit(1)

    # Create and launch demo
    demo = create_demo()
    demo.launch(
        share=True,  # Set to True to create a public link
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,  # Default Gradio port
        show_error=True,
    )