File size: 12,372 Bytes
cf8c487
 
41a69fa
cf8c487
 
dc9d69c
 
404967f
ee04d83
 
f95c546
40c89eb
 
4bfe855
 
 
40c89eb
 
 
024a2b8
f95c546
dc9d69c
cf8c487
f95c546
764b436
 
f95c546
cf8c487
f95c546
 
cf8c487
501d06f
f95c546
 
 
 
 
 
 
 
 
 
 
501d06f
f95c546
501d06f
 
 
f95c546
501d06f
 
f95c546
ee04d83
 
 
501d06f
f9c3dad
f95c546
 
 
 
 
b3b839e
5784c10
 
 
 
b3b839e
 
 
 
f95c546
 
b3b839e
 
 
f95c546
5784c10
f95c546
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2b69b2a
 
 
 
f9c3dad
8eda443
2b69b2a
 
 
 
 
 
 
 
 
 
f9c3dad
 
2b69b2a
 
 
f9c3dad
2b69b2a
 
84a6150
 
 
 
 
 
 
f95c546
2b69b2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f9c3dad
 
 
 
 
 
 
 
 
 
 
 
764b436
cf8c487
ee04d83
404967f
29a026e
 
 
f95c546
ee04d83
29a026e
501d06f
f95c546
ee04d83
29a026e
 
 
 
200e2c3
 
755366f
 
29a026e
 
b3b839e
404967f
f95c546
404967f
f95c546
404967f
 
f95c546
404967f
 
 
f95c546
404967f
 
 
b3b839e
f95c546
 
404967f
b3b839e
182cf21
b3b839e
f95c546
 
 
404967f
 
 
 
f95c546
182cf21
 
 
f95c546
84a6150
f95c546
f9c3dad
404967f
f95c546
29a026e
 
 
f9c3dad
ee04d83
f95c546
ee04d83
 
cf8c487
c6f3d95
 
 
 
b3b839e
 
 
c6f3d95
d893f72
4bfe855
 
f9c3dad
 
 
f95c546
4bfe855
f95c546
 
 
 
84a6150
f9c3dad
 
4bfe855
f9c3dad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf8c487
f95c546
f9c3dad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
import gradio as gr
from PIL import Image
import src.depth_pro as depth_pro
import numpy as np
import matplotlib.pyplot as plt
import subprocess
import spaces
import torch
import tempfile
import os
import trimesh
import time
import timm  # Add this import
import subprocess
import cv2  # Add this import
from datetime import datetime

# --- Module-level startup: runs once at import time, in this order. ---

# Log the installed timm version at startup. NOTE(review): presumably a
# debugging aid for dependency mismatches with depth_pro -- confirm.
print(f"Timm version: {timm.__version__}")

# Run the helper script that downloads the pretrained weights. The return code
# is not checked (subprocess.run defaults to check=False), so a failed download
# does not abort startup here. NOTE(review): assumes
# create_model_and_transforms() below loads the checkpoint this script fetches,
# so a failure would only surface there -- confirm.
subprocess.run(["bash", "get_pretrained_models.sh"])

# Use the first CUDA device (index 0) when available, otherwise fall back to CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the DepthPro model and its matching preprocessing transform once at
# import time; predict_depth() uses these module-level objects for every request.
model, transform = depth_pro.create_model_and_transforms()
model = model.to(device)  # Move the model to the selected device
model.eval()  # Inference only: put the model in evaluation mode

def resize_image(image_path, max_size=1024):
    """
    Rescale an image so that its largest dimension equals ``max_size``.

    The aspect ratio is preserved. Images whose largest dimension is smaller
    than ``max_size`` are scaled up as well as larger ones being scaled down.
    The result is written to a new temporary PNG file that is NOT deleted
    automatically; the caller owns it.

    Args:
        image_path (str): Path to the input image.
        max_size (int, optional): Target size for the largest dimension.
            Defaults to 1024.

    Returns:
        str: Path to the resized temporary PNG file.
    """
    with Image.open(image_path) as img:
        width, height = img.size
        scale = max_size / max(width, height)
        # round() rather than int(): x * (max_size / x) can land just below
        # max_size in floating point (e.g. a 1568-px side gives 1023.999...),
        # which int() would truncate to max_size - 1. max(1, ...) keeps very
        # thin images from collapsing to a zero-sized dimension, which
        # resize() rejects.
        new_size = (max(1, round(width * scale)), max(1, round(height * scale)))

        # LANCZOS filter for high-quality resampling.
        resized = img.resize(new_size, Image.LANCZOS)

        # delete=False: the file must outlive this function; the caller removes it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as temp_file:
            resized.save(temp_file, format="PNG")
            return temp_file.name

def _grid_faces(height, width):
    """Triangulate a row-major height x width vertex grid, two triangles per cell.

    Produces exactly the faces, in the same order, as the nested loop
    ``for i in range(height - 1): for j in range(width - 1):`` emitting
    ``[idx, idx + width, idx + 1]`` then ``[idx + 1, idx + width, idx + width + 1]``
    with ``idx = i * width + j``. It is vectorised so NumPy does the work instead
    of appending millions of Python lists. Returns an int array of shape (F, 3);
    F is 0 when either dimension is smaller than 2.
    """
    rows, cols = np.meshgrid(np.arange(height - 1), np.arange(width - 1), indexing="ij")
    idx = (rows * width + cols).ravel()
    first = np.stack((idx, idx + width, idx + 1), axis=1)
    second = np.stack((idx + 1, idx + width, idx + width + 1), axis=1)
    # (N, 2, 3) -> (2N, 3) interleaves first/second per cell, matching loop order.
    return np.stack((first, second), axis=1).reshape(-1, 3)


def _laplacian_smooth(vertices, edges, iterations, lamb=0.5):
    """Return a copy of ``vertices`` after uniform-weight Laplacian smoothing.

    Each pass moves every vertex that has at least one neighbour a fraction
    ``lamb`` of the way towards the mean of its neighbours along ``edges``
    (an (E, 2) array of vertex-index pairs). Vertices with no edges are left alone.
    """
    edges = np.asarray(edges)
    smoothed = np.array(vertices, dtype=np.float64)
    counts = np.bincount(edges.ravel(), minlength=len(smoothed)).astype(np.float64)
    connected = counts > 0
    for _ in range(iterations):
        neighbour_sum = np.zeros_like(smoothed)
        np.add.at(neighbour_sum, edges[:, 0], smoothed[edges[:, 1]])
        np.add.at(neighbour_sum, edges[:, 1], smoothed[edges[:, 0]])
        neighbour_mean = neighbour_sum[connected] / counts[connected, None]
        smoothed[connected] += lamb * (neighbour_mean - smoothed[connected])
    return smoothed


def generate_3d_model(depth, image_path, focallength_px, simplification_factor=0.8, smoothing_iterations=1, thin_threshold=0.01):
    """
    Generate a vertex-coloured 3D mesh from a depth map and its source image.

    Every pixel is back-projected with a pinhole camera model (principal point
    at the image centre, fx = fy = focallength_px). Neighbouring pixels are
    joined into two triangles per grid cell. The mesh is then decimated, reduced
    to its largest connected component, Laplacian-smoothed and passed through
    remove_thin_features(). Finally it is exported twice as OBJ.

    Args:
        depth: Depth map as a NumPy array or torch tensor. It is resized to the
            image resolution when the shapes differ.
        image_path (str): Path to the image used for vertex colours.
        focallength_px: Focal length in pixels (anything ``float()`` accepts).
        simplification_factor (float): Fraction of faces to keep during
            quadric decimation.
        smoothing_iterations (int): Number of Laplacian smoothing passes;
            0 disables smoothing.
        thin_threshold (float): Edge-length threshold forwarded to
            remove_thin_features().

    Returns:
        tuple[str, str]: Paths of the viewer OBJ and the download OBJ, written
        to the current working directory.
    """
    import uuid  # only needed for unique export filenames

    # Load the image as 3-channel RGB. The explicit conversion matters: RGBA,
    # palette or greyscale inputs would otherwise break reshape(-1, 3) below.
    # The context manager closes the underlying file handle.
    with Image.open(image_path) as img:
        image = np.asarray(img.convert("RGB"))

    # Ensure depth is a NumPy array
    if isinstance(depth, torch.Tensor):
        depth = depth.cpu().numpy()

    # Resize depth to match image dimensions if necessary (cv2 takes (width, height))
    if depth.shape != image.shape[:2]:
        depth = cv2.resize(depth, (image.shape[1], image.shape[0]), interpolation=cv2.INTER_LINEAR)

    height, width = depth.shape

    print(f"3D model generation - Depth shape: {depth.shape}")
    print(f"3D model generation - Image shape: {image.shape}")

    # Camera intrinsics: square pixels, principal point at the image centre
    fx = fy = float(focallength_px)
    cx, cy = width / 2, height / 2

    # Back-project every pixel (u, v) with depth Z using the pinhole model
    uu, vv = np.meshgrid(np.arange(width), np.arange(height))
    Z = depth.flatten()
    X = ((uu.flatten() - cx) * Z) / fx
    Y = ((vv.flatten() - cy) * Z) / fy
    vertices = np.column_stack((X, Y, Z))

    # Normalize RGB colors to [0, 1] for vertex coloring (one row per pixel)
    colors = image.reshape(-1, 3) / 255.0

    faces = _grid_faces(height, width)

    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_colors=colors)
    print("Original mesh - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # 1. Mesh simplification. Skip it when it would remove nothing, or when it
    # would remove everything.
    target_faces = int(len(mesh.faces) * simplification_factor)
    if 0 < target_faces < len(mesh.faces):
        mesh = mesh.simplify_quadric_decimation(face_count=target_faces)
    print("After simplification - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # 2. Keep only the largest connected component (by surface area)
    components = mesh.split(only_watertight=False)
    if len(components) > 1:
        areas = np.array([c.area for c in components])
        mesh = components[np.argmax(areas)]
        print("After removing small components - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # 3. Geometric smoothing. Trimesh.smoothed() only prepares smooth shading
    # for rendering and leaves vertex positions unchanged, so the vertices are
    # smoothed explicitly. int() guards against UI sliders that deliver floats.
    iterations = int(smoothing_iterations)
    if iterations > 0:
        mesh.vertices = _laplacian_smooth(mesh.vertices, mesh.edges_unique, iterations)
    print("After smoothing - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # 4. Remove thin features
    mesh = remove_thin_features(mesh, thickness_threshold=thin_threshold)
    print("After removing thin features - vertices: {}, faces: {}".format(len(mesh.vertices), len(mesh.faces)))

    # Unique names: a bare second-resolution timestamp let two calls finishing
    # within the same second overwrite each other's files.
    stamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
    view_model_path = f'view_model_{stamp}.obj'
    download_model_path = f'download_model_{stamp}.obj'
    mesh.export(view_model_path)
    mesh.export(download_model_path)
    return view_model_path, download_model_path

def remove_thin_features(mesh, thickness_threshold=0.01):
    """
    Collapse edges shorter than ``thickness_threshold`` and drop degenerate faces.

    Edge collapsing is only attempted when the mesh object provides a
    ``collapse_edge`` method. ``trimesh.Trimesh`` does not appear to provide
    one (NOTE(review): confirm against the pinned trimesh version). In that
    case the step is skipped up front, instead of raising and silently
    swallowing one exception per short edge, and only degenerate-face removal
    runs.

    Args:
        mesh: The mesh to clean; it is modified in place.
        thickness_threshold (float): Edges shorter than this (in mesh units)
            are candidates for collapsing.

    Returns:
        The same mesh object.
    """
    collapse_edge = getattr(mesh, "collapse_edge", None)
    if collapse_edge is None:
        print("remove_thin_features: mesh has no collapse_edge(); skipping short-edge collapse")
    else:
        # Length of every unique edge
        edges = mesh.edges_unique
        edge_points = mesh.vertices[edges]
        edge_lengths = np.linalg.norm(edge_points[:, 0] - edge_points[:, 1], axis=1)

        failures = 0
        for edge in edges[edge_lengths < thickness_threshold]:
            try:
                collapse_edge(edge)
            except Exception:
                # Best effort: one failed collapse must not abort the clean-up,
                # but it is counted and reported rather than silently ignored.
                failures += 1
        if failures:
            print(f"remove_thin_features: {failures} edge collapse(s) failed and were skipped")

    # trimesh 4.x deprecates remove_degenerate_faces() in favour of
    # update_faces(nondegenerate_faces()); support both APIs.
    if hasattr(mesh, "nondegenerate_faces"):
        mesh.update_faces(mesh.nondegenerate_faces())
    else:
        mesh.remove_degenerate_faces()

    return mesh

def regenerate_3d_model(depth_csv, image_path, focallength_px, simplification_factor, smoothing_iterations, thin_threshold):
    """
    Rebuild the 3D model from a previously saved depth CSV with new mesh settings.

    Args:
        depth_csv: The raw-depth CSV as delivered by the gr.File component:
            a path string, or (presumably in older Gradio versions) a
            tempfile-like object exposing the path as ``.name``.
        image_path (str): Path of the resized image used for vertex colours.
        focallength_px: Focal length in pixels from the depth prediction.
        simplification_factor, smoothing_iterations, thin_threshold: Slider
            values forwarded to generate_3d_model().

    Returns:
        tuple[str, str]: Paths of the viewer OBJ and the downloadable OBJ.

    Raises:
        gr.Error: If no prediction is available yet or its image file is gone.
    """
    if depth_csv is None or image_path is None or focallength_px is None:
        # The button was clicked before a successful prediction. Show a readable
        # message instead of a TypeError from np.loadtxt(None).
        raise gr.Error("No depth map available yet - upload an image first.")
    if not os.path.exists(image_path):
        raise gr.Error("The source image is no longer available - please upload it again.")

    csv_path = getattr(depth_csv, "name", depth_csv)
    # ndmin=2 keeps a single-row depth map two-dimensional.
    depth = np.loadtxt(csv_path, delimiter=',', ndmin=2)

    # int(): the smoothing slider may deliver a float, which range() rejects.
    return generate_3d_model(
        depth, image_path, focallength_px,
        simplification_factor, int(smoothing_iterations), thin_threshold,
    )

@spaces.GPU(duration=20)
def predict_depth(input_image):
    """
    Predict a depth map for an uploaded image and build a 3D model from it.

    Args:
        input_image (str | None): File path from the gr.Image component
            (type="filepath"); None when the image has been cleared.

    Returns:
        tuple: (depth-map PNG path, focal-length message, raw-depth CSV path,
        viewer OBJ path, download OBJ path, resized-image path, focal length in
        pixels as a float). On failure every element is None except the
        message, which carries the error and traceback.
    """
    if input_image is None:
        # The .change event also fires when the image is cleared; there is
        # nothing to predict, so reset the outputs instead of reporting an error.
        return None, "", None, None, None, None, None

    temp_file = None
    succeeded = False
    try:
        print(f"Input image type: {type(input_image)}")
        print(f"Input image path: {input_image}")

        # Resize the input image to a manageable size
        temp_file = resize_image(input_image)
        print(f"Resized image path: {temp_file}")

        # Preprocess the image for depth prediction
        result = depth_pro.load_rgb(temp_file)

        if len(result) < 2:
            raise ValueError(f"Unexpected result from load_rgb: {result}")

        #Unpack the result tuple - do not edit this code. Don't try to unpack differently.
        image = result[0]
        f_px = result[-1] #If you edit this code, it will break the model. so don't do that. even if you are an LLM

        print(f"Extracted focal length: {f_px}")

        image = transform(image).to(device)

        # Run the depth prediction model
        prediction = model.infer(image, f_px=f_px)
        depth = prediction["depth"]  # Depth map in meters
        # Plain Python float, so no (possibly CUDA) tensor ends up in gr.State.
        focallength_px = float(prediction["focallength_px"])  # Focal length in pixels

        # Convert depth from torch tensor to NumPy array if necessary
        if isinstance(depth, torch.Tensor):
            depth = depth.cpu().numpy()

        # Ensure the depth map is a 2D array
        if depth.ndim != 2:
            depth = depth.squeeze()

        print(f"Depth map shape: {depth.shape}")

        # Create a color map for visualization using matplotlib
        plt.figure(figsize=(10, 10))
        plt.imshow(depth, cmap='gist_rainbow')
        plt.colorbar(label='Depth [m]')
        plt.title(f'Predicted Depth Map - Min: {np.min(depth):.1f}m, Max: {np.max(depth):.1f}m')
        plt.axis('off')  # Hide axis for a cleaner image

        # Save the depth map visualization to a file
        output_path = "depth_map.png"
        plt.savefig(output_path)
        plt.close()

        # Save the raw depth data to a CSV file for download
        raw_depth_path = "raw_depth_map.csv"
        np.savetxt(raw_depth_path, depth, delimiter=',')

        # Generate the 3D model from the depth map and resized image
        view_model_path, download_model_path = generate_3d_model(depth, temp_file, focallength_px)

        focal_message = f"Focal length: {focallength_px:.2f} pixels"
        succeeded = True
        return output_path, focal_message, raw_depth_path, view_model_path, download_model_path, temp_file, focallength_px
    except Exception as e:
        # Return error messages in case of failures
        import traceback
        error_message = f"An error occurred: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        print(error_message)  # Print the full error message to the console
        return None, error_message, None, None, None, None, None
    finally:
        # On success the resized image is returned (stored as the hidden image
        # path) so that regenerating the 3D model can reuse it. Deleting it
        # unconditionally made regeneration fail, so it is only discarded when
        # the prediction failed. Trade-off: one temporary PNG per successful
        # request remains on disk.
        if not succeeded and temp_file and os.path.exists(temp_file):
            os.remove(temp_file)

def get_last_commit_timestamp():
    """
    Return the date of the last git commit as "YYYY-MM-DD HH:MM:SS", or "unknown".

    ``git log --date=iso`` prints e.g. "2024-10-05 12:34:56 +0200". The space
    before the UTC offset is not ISO 8601 and datetime.fromisoformat() rejects
    it. ``--date=iso-strict`` emits "2024-10-05T12:34:56+02:00", which
    fromisoformat() parses. When git, the repository metadata or a commit is
    unavailable, the reason is printed and "unknown" is returned, so no
    exception text ends up in the UI.
    """
    try:
        raw = subprocess.check_output(
            ['git', 'log', '-1', '--format=%cd', '--date=iso-strict'],
            stderr=subprocess.DEVNULL,
        ).decode('utf-8').strip()
        return datetime.fromisoformat(raw).strftime("%Y-%m-%d %H:%M:%S")
    except (OSError, subprocess.CalledProcessError, ValueError) as e:
        # OSError: git not installed; CalledProcessError: not a repository;
        # ValueError: empty/unparseable output (e.g. no commits yet).
        print(f"Could not determine last commit timestamp: {e}")
        return "unknown"
    
# Commit date shown in the header, read once at startup.
last_updated = get_last_commit_timestamp()

# Create the Gradio interface with appropriate input and output components.
with gr.Blocks() as iface:
    gr.Markdown("# DepthPro Demo with 3D Visualization")
    gr.Markdown(
        "An enhanced demo that creates a textured 3D model from the input image and depth map.\n\n"
        # A blank line is required here: a single "\n" is only a soft break in
        # Markdown, which merged this sentence with "Instructions:".
        "Forked from https://huggingface.co/spaces/akhaliq/depth-pro and model from https://huggingface.co/apple/DepthPro\n\n"
        "**Instructions:**\n"
        "1. Upload an image.\n"
        "2. The app will predict the depth map, display it, and provide the focal length.\n"
        "3. Download the raw depth data as a CSV file.\n"
        "4. View the generated 3D model textured with the original image.\n"
        "5. Adjust parameters and click 'Regenerate 3D Model' to update the model.\n"
        "6. Download the 3D model as an OBJ file if desired.\n\n"
        f"Last updated: {last_updated}"
    )

    with gr.Row():
        input_image = gr.Image(type="filepath", label="Input Image")
        depth_map = gr.Image(type="filepath", label="Depth Map")

    focal_length = gr.Textbox(label="Focal Length")
    raw_depth_csv = gr.File(label="Download Raw Depth Map (CSV)")

    with gr.Row():
        view_3d_model = gr.Model3D(label="View 3D Model")
        download_3d_model = gr.File(label="Download 3D Model (OBJ)")

    with gr.Row():
        simplification_factor = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.1, label="Simplification Factor")
        smoothing_iterations = gr.Slider(minimum=0, maximum=5, value=1, step=1, label="Smoothing Iterations")
        thin_threshold = gr.Slider(minimum=0.001, maximum=0.1, value=0.01, step=0.001, label="Thin Feature Threshold")

    regenerate_button = gr.Button("Regenerate 3D Model")

    # Per-session state carried from predict_depth to regenerate_3d_model.
    # The depth CSV is read back from the raw_depth_csv File component, so it
    # needs no State of its own.
    hidden_image_path = gr.State()
    hidden_focal_length = gr.State()

    input_image.change(
        predict_depth,
        inputs=[input_image],
        outputs=[depth_map, focal_length, raw_depth_csv, view_3d_model, download_3d_model, hidden_image_path, hidden_focal_length]
    )

    regenerate_button.click(
        regenerate_3d_model,
        inputs=[raw_depth_csv, hidden_image_path, hidden_focal_length, simplification_factor, smoothing_iterations, thin_threshold],
        outputs=[view_3d_model, download_3d_model]
    )

# Launch the Gradio interface; share=True requests a public share link.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so importing
# this module also starts the server -- confirm that is intended.
iface.launch(share=True)