import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os

# Run the script to download the pretrained model checkpoints
subprocess.run(["bash", "get_pretrained_models.sh"])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the model and its preprocessing transform
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)
model.eval()


def resize_image(image_path, max_size=1024):
    """Resize an image to fit within max_size while preserving aspect ratio.

    Returns the path of a temporary PNG file holding the resized image.
    """
    with Image.open(image_path) as img:
        # Calculate the new size while maintaining the aspect ratio
        ratio = max_size / max(img.size)
        new_size = tuple(int(x * ratio) for x in img.size)

        # Resize the image with a high-quality filter
        img = img.resize(new_size, Image.LANCZOS)

        # Write the result to a temporary file for the loader
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            img.save(temp_file, format="PNG")
            return temp_file.name


@spaces.GPU(duration=20)
def predict_depth(input_image):
    temp_file = None
    try:
        # Resize the input image
        temp_file = resize_image(input_image)

        # Preprocess the image
        result = depth_pro.load_rgb(temp_file)
        image = result[0]
        f_px = result[-1]  # Assuming f_px is the last item in the returned tuple
        image = transform(image)
        image = image.to(device)

        # Run inference
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth in [m]
        focallength_px = prediction["focallength_px"]  # Focal length in pixels

        # Convert depth to a numpy array if it's a torch tensor
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Ensure depth is a 2D numpy array
        if depth.ndim != 2:
            depth = depth.squeeze()

        # Do NOT normalize the depth map: the values are already metric (meters).
        # The range is recorded only for display in the plot title.
        depth_min = np.min(depth)
        depth_max = np.max(depth)

        # Render the depth map with a color bar labeled in meters
        plt.figure(figsize=(10, 10))
        plt.imshow(depth, cmap='viridis')
        plt.colorbar(label='Depth [m]')
        plt.title(f'Predicted Depth Map - Min: {depth_min:.1f}m, Max: {depth_max:.1f}m')
        plt.axis('off')

        # Save the plot to a file
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()

        # Save the raw depth data as CSV
        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth, delimiter=',')

        return output_path, f"Focal length: {focallength_px:.2f} pixels", raw_depth_path
    except Exception as e:
        return None, f"An error occurred: {str(e)}", None
    finally:
        # Clean up the temporary file
        if temp_file and os.path.exists(temp_file):
            os.remove(temp_file)


# Create the Gradio interface
iface = gr.Interface(
    fn=predict_depth,
    inputs=gr.Image(type="filepath"),
    outputs=[
        gr.Image(type="filepath", label="Depth Map"),
        gr.Textbox(label="Focal Length or Error Message"),
        gr.File(label="Download Raw Depth Map (CSV)")
    ],
    title="DepthPro Demo in Meters",
    description="[DepthPro](https://huggingface.co/apple/DepthPro) is a fast metric depth prediction model. Simply upload an image to predict its depth map and focal length. Large images will be automatically resized. You can also download the raw depth map data as a CSV file."
)

# Launch the interface
iface.launch()
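
# --- Usage sketch (not part of the original demo; for illustration only) ---
# A minimal example of reloading the exported depth map for offline analysis,
# assuming the "raw_depth_map.csv" file that predict_depth writes above. Kept
# as a comment so the demo's runtime behavior is unchanged:
#
#   import numpy as np
#   depth = np.loadtxt("raw_depth_map.csv", delimiter=",")  # (H, W) array of meters
#   print(depth.shape, float(depth.min()), float(depth.max()))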