# Hugging Face Spaces page banner (web-scrape residue, not part of the program): "Spaces: Running on Zero"
import cv2 | |
import flow_vis | |
import gradio as gr | |
import numpy as np | |
import torch | |
from PIL import Image | |
import sys | |
import subprocess | |
import sys | |
import subprocess | |
import importlib | |
from functools import lru_cache | |
import spaces | |
# Import the UFM packages, installing them on first run if they are missing.
try:
    import uniception
    import uniflowmatch
except ImportError:
    # Only a failed import should trigger the installer. A bare ``except``
    # would also swallow KeyboardInterrupt/SystemExit and hide unrelated bugs.
    subprocess.check_call(["/bin/bash", "./install_package.sh"])
    # Optional: explicitly reload sys.path (only needed if install path changed)
    import site

    importlib.invalidate_caches()
    site.main()  # reload site-packages
    # Now try importing again; a second failure propagates to the caller.
    uniception = importlib.import_module("uniception")
    uniflowmatch = importlib.import_module("uniflowmatch")
from uniflowmatch.models.ufm import ( | |
UniFlowMatchClassificationRefinement, | |
UniFlowMatchConfidence, | |
) | |
from uniflowmatch.utils.viz import warp_image_with_flow | |
# Module-level flag, presumably the default model choice (False -> base model); not referenced elsewhere in the visible code.
USE_REFINEMENT_MODEL = False | |
def initialize_model(use_refinement: bool = False):
    """Load one of the pretrained UFM checkpoints.

    Args:
        use_refinement: When true, load ``infinity1096/UFM-Refine``;
            otherwise load ``infinity1096/UFM-Base``.

    Returns:
        A ``(success, model)`` tuple. On any exception during loading the
        error is printed and ``(False, None)`` is returned instead of raising.
    """
    try:
        if not use_refinement:
            repo_id = "infinity1096/UFM-Base"
            print(f"Loading UFM Base model from {repo_id}...")
            network = UniFlowMatchConfidence.from_pretrained(repo_id)
        else:
            repo_id = "infinity1096/UFM-Refine"
            print(f"Loading UFM Refinement model from {repo_id}...")
            network = UniFlowMatchClassificationRefinement.from_pretrained(repo_id)

        # Switch to evaluation mode when the loaded object supports it.
        if hasattr(network, "eval"):
            network.eval()

        print("Model loaded successfully!")
        return True, network
    except Exception as err:
        print(f"Error loading model: {err}")
        return False, None
def _to_rgb_array(image):
    """Return ``image`` as a 3-channel RGB numpy array.

    PIL images are converted with ``Image.convert("RGB")`` so every source
    mode (grayscale, palette, RGBA, ...) is handled by PIL itself. Array-like
    inputs are converted by channel count: 1 channel is replicated to RGB,
    4 channels have the alpha channel dropped, 3 channels are returned
    unchanged (assumed to already be RGB).

    Raises:
        ValueError: If an array-like input has an unsupported shape.
    """
    if isinstance(image, Image.Image):
        return np.array(image.convert("RGB"))

    pixels = np.array(image)
    if pixels.ndim == 2 or (pixels.ndim == 3 and pixels.shape[2] == 1):
        return cv2.cvtColor(pixels, cv2.COLOR_GRAY2RGB)
    if pixels.ndim == 3 and pixels.shape[2] == 4:
        return cv2.cvtColor(pixels, cv2.COLOR_RGBA2RGB)
    if pixels.ndim == 3 and pixels.shape[2] == 3:
        return pixels
    raise ValueError(f"Unsupported image shape: {pixels.shape}")


def process_images(source_image, target_image, model_type_choice):
    """Run UFM on an image pair and build the three output visualizations.

    Args:
        source_image: Source image (PIL image or array-like), or ``None``.
        target_image: Target image (PIL image or array-like), or ``None``.
        model_type_choice: ``"Refinement Model"`` selects the refinement
            checkpoint; any other value selects the base checkpoint.

    Returns:
        A 4-tuple ``(flow_image, covisibility_image, warped_image, status)``.
        The three images are PIL images on success and ``None`` on failure;
        ``status`` is a human-readable message in both cases.
    """
    if source_image is None or target_image is None:
        return None, None, None, "Please upload both images."

    # NOTE: the model is reloaded on every call, not only when the choice
    # changes; initialize_model() does no caching.
    use_refinement = model_type_choice == "Refinement Model"
    print(f"Switching to {model_type_choice}...")
    loaded, model = initialize_model(use_refinement)
    if not loaded or model is None:
        return None, None, None, "Model not loaded. Please restart the application."

    # Decide the device once so the model and both inputs always agree.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)

    try:
        # Normalize both inputs to 3-channel RGB arrays.
        source_rgb = _to_rgb_array(source_image)
        target_rgb = _to_rgb_array(target_image)
        print(f"Processing images with shapes: Source {source_rgb.shape}, Target {target_rgb.shape}")

        # === Predict Correspondences ===
        with torch.no_grad():
            result = model.predict_correspondences_batched(
                source_image=torch.from_numpy(source_rgb).to(device),
                target_image=torch.from_numpy(target_rgb).to(device),
            )

        # Take the first batch element of each prediction.
        flow_output = result.flow.flow_output[0].cpu().numpy()
        covisibility = result.covisibility.mask[0].cpu().numpy()
        print(f"Flow output shape: {flow_output.shape}")
        print(f"Covisibility shape: {covisibility.shape}")

        # === Create Visualizations ===
        # 1. Flow visualization (flow is moved to channels-last for flow_vis).
        flow_hwc = flow_output.transpose(1, 2, 0)
        flow_vis_image = flow_vis.flow_to_color(flow_hwc)
        flow_pil = Image.fromarray(flow_vis_image.astype(np.uint8))

        # 2. Covisibility visualization - direct gray image.
        #    Assumes mask values lie in [0, 1] -- TODO confirm against the model.
        covisibility_gray = (covisibility * 255).astype(np.uint8)
        covisibility_pil = Image.fromarray(covisibility_gray, mode="L")

        # 3. Warped image using the library warp function, blended towards
        #    white where the covisibility mask is low.
        warped_image = warp_image_with_flow(source_rgb, None, target_rgb, flow_hwc)
        warped_image = covisibility[..., None] * warped_image + (1 - covisibility[..., None]) * 255 * np.ones_like(
            warped_image
        )
        warped_image = (warped_image / 255.0).clip(0, 1)
        warped_pil = Image.fromarray((warped_image * 255).astype(np.uint8))

        status_msg = f"Processing completed with {model_type_choice}"
        return flow_pil, covisibility_pil, warped_pil, status_msg
    except Exception as e:
        # UI boundary: report the failure in the status box instead of raising.
        error_msg = f"Error processing images: {str(e)}"
        print(error_msg)
        return None, None, None, error_msg
def create_demo():
    """Build and return the Gradio Blocks UI for the UniFlowMatch demo.

    The layout has two image inputs, a model selector, a process button, a
    status box, three image outputs and a set of example image pairs.
    Processing runs on button click and whenever any input changes.
    """
    example_scenes = ("fire_academy", "scene", "bike", "cook", "building")

    def auto_process(source, target, model_choice):
        # Only run the model once both images are present.
        if source is None or target is None:
            return None, None, None, "Upload both images to start processing."
        return process_images(source, target, model_choice)

    with gr.Blocks(title="UniFlowMatch Demo") as demo:
        gr.Markdown("# UniFlowMatch Demo")
        gr.Markdown("Upload two images to see optical flow visualization")

        # Inputs
        with gr.Row():
            source_input = gr.Image(label="Source Image", type="pil")
            target_input = gr.Image(label="Target Image", type="pil")

        # Model selection, trigger button and status line
        model_type = gr.Radio(
            choices=["Base Model", "Refinement Model"],
            value="Base Model",
            label="Model Type",
        )
        process_btn = gr.Button("Process Images")
        status_output = gr.Textbox(label="Status", interactive=False)

        # Outputs
        with gr.Row():
            flow_output = gr.Image(label="Flow Visualization")
            covisibility_output = gr.Image(label="Covisibility Mask")
            warped_output = gr.Image(label="Warped Target Image")

        # Example pairs follow the "<scene>_0.png" / "<scene>_1.png" naming.
        gr.Examples(
            examples=[
                [
                    f"examples/image_pairs/{scene}_0.png",
                    f"examples/image_pairs/{scene}_1.png",
                ]
                for scene in example_scenes
            ],
            inputs=[source_input, target_input],
            label="Example Image Pairs",
        )

        handler_inputs = [source_input, target_input, model_type]
        handler_outputs = [flow_output, covisibility_output, warped_output, status_output]

        # Explicit processing via the button
        process_btn.click(
            fn=process_images,
            inputs=handler_inputs,
            outputs=handler_outputs,
        )

        # Automatic processing whenever an input changes
        for watched in (source_input, target_input, model_type):
            watched.change(
                fn=auto_process,
                inputs=handler_inputs,
                outputs=handler_outputs,
            )

    return demo
if __name__ == "__main__":
    # Startup check: make sure the base model can be loaded before serving.
    print("Initializing UniFlowMatch model...")
    # initialize_model() returns a (success, model) tuple. Testing the tuple
    # itself is always truthy, so unpack it and test the success flag.
    model_loaded, _ = initialize_model(use_refinement=False)  # Start with base model
    if not model_loaded:
        print("Error: Model failed to load. Please check your model installation and HuggingFace access.")
        print("Make sure you have:")
        print("1. Installed uniflowmatch package")
        print("2. Have internet access for downloading pretrained models")
        print("3. All required dependencies installed")
        # sys.exit is used because the ``exit`` builtin is only a convenience
        # injected by the ``site`` module.
        sys.exit(1)

    # Create and launch demo
    demo = create_demo()
    demo.launch(
        share=True,  # Set to True to create a public link
        server_name="0.0.0.0",  # Allow external connections
        server_port=7860,  # Default Gradio port
        show_error=True,
    )