mostlycached committed
Commit 2698a3f · verified · Parent: b344378

Update app.py

Files changed (1):
  1. app.py +80 -121

app.py CHANGED
@@ -2,10 +2,11 @@ import gradio as gr
 import torch
 import numpy as np
 import cv2
-from PIL import Image, ImageOps
+from PIL import Image
 from transformers import SamModel, SamProcessor
 from diffusers import StableDiffusionInpaintPipeline
-import os
+import requests
+from io import BytesIO
 
 # Set up device
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -23,35 +24,27 @@ inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
     torch_dtype=torch.float16 if device == "cuda" else torch.float32
 ).to(device)
 
-def get_importance_map(image, points=None):
-    """Get importance map using SAM model to identify key content regions"""
-    # Convert to numpy if needed
-    if isinstance(image, Image.Image):
-        image_np = np.array(image)
+def get_sam_mask(image, points=None):
+    """Get segmentation mask using SAM model"""
+    if points is None:
+        # If no points provided, use center point
+        height, width = image.shape[:2]
+        points = [[[width // 2, height // 2]]]
+
+    # Convert to PIL if needed
+    if not isinstance(image, Image.Image):
+        image_pil = Image.fromarray(image)
     else:
-        image_np = image
+        image_pil = image
 
-    h, w = image_np.shape[:2]
-
-    # If no points provided, use grid sampling to identify important areas
-    if points is None:
-        # Create a grid of points to sample the image
-        x_points = np.linspace(w//4, 3*w//4, 5, dtype=int)
-        y_points = np.linspace(h//4, 3*h//4, 5, dtype=int)
-        grid_points = []
-        for y in y_points:
-            for x in x_points:
-                grid_points.append([x, y])
-        points = [grid_points]
-
-    # Process image through SAM
+    # Process the image and point prompts
     inputs = sam_processor(
-        images=image_np,
+        images=image_pil,
         input_points=points,
         return_tensors="pt"
     ).to(device)
 
-    # Generate masks
+    # Generate mask
     with torch.no_grad():
         outputs = sam_model(**inputs)
     masks = sam_processor.image_processor.post_process_masks(
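Note on the new `get_sam_mask`: the point prompt `[[[width // 2, height // 2]]]` follows the `SamProcessor` nesting of image batch → points per image → (x, y) coordinates, and `post_process_masks` returns one tensor per image of shape (point_batch, num_candidates, height, width). SAM predicts three candidate masks per point, so `masks[0][0]` in the next hunk is the full candidate stack rather than a single mask; if one mask is wanted, picking the candidate with the highest predicted IoU is the usual refinement. A minimal sketch, assuming the `outputs` and `masks` variables from this function (not part of the commit):

```python
# Select SAM's best candidate mask by predicted IoU.
scores = outputs.iou_scores[0, 0]          # (num_candidates,) predicted IoU per mask
best_idx = int(scores.argmax())
best_mask = masks[0][0][best_idx].numpy()  # single (H, W) boolean mask
```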
@@ -60,123 +53,86 @@ def get_importance_map(image, points=None):
         inputs["reshaped_input_sizes"].cpu()
     )
 
-    # Combine all masks to create importance map
-    importance_map = np.zeros((h, w), dtype=np.float32)
-    for i in range(len(masks[0])):
-        importance_map += masks[0][i].numpy().astype(np.float32)
-
-    # Normalize to 0-1
-    if importance_map.max() > 0:
-        importance_map = importance_map / importance_map.max()
-
-    return importance_map
+    # Get the mask
+    mask = masks[0][0].numpy()
+    return mask
 
-def find_optimal_placement(importance_map, original_size, new_size):
-    """Find the optimal placement for the original image within the new canvas based on importance"""
-    oh, ow = original_size
-    nh, nw = new_size
-
-    # If the new size is smaller in any dimension, then just center it
-    if nh <= oh or nw <= ow:
-        x_offset = max(0, (nw - ow) // 2)
-        y_offset = max(0, (nh - oh) // 2)
-        return x_offset, y_offset
-
-    # Calculate all possible positions
-    possible_x = nw - ow + 1
-    possible_y = nh - oh + 1
-
-    best_score = -np.inf
-    best_x = 0
-    best_y = 0
-
-    # Create a border-weighted importance map (gives extra weight to content near borders)
-    y_coords, x_coords = np.ogrid[:oh, :ow]
-    border_weight = np.minimum(np.minimum(x_coords, ow-1-x_coords), np.minimum(y_coords, oh-1-y_coords))
-    border_weight = 1.0 - border_weight / border_weight.max()
-    weighted_importance = importance_map * (1.0 + 0.5 * border_weight)
-
-    # Optimize for 9 positions (corners, center of edges, and center)
-    positions = [
-        (0, 0),                                  # Top-left
-        (0, (possible_y-1)//2),                  # Middle-left
-        (0, possible_y-1),                       # Bottom-left
-        ((possible_x-1)//2, 0),                  # Top-center
-        ((possible_x-1)//2, (possible_y-1)//2),  # Center
-        ((possible_x-1)//2, possible_y-1),       # Bottom-center
-        (possible_x-1, 0),                       # Top-right
-        (possible_x-1, (possible_y-1)//2),       # Middle-right
-        (possible_x-1, possible_y-1)             # Bottom-right
-    ]
-
-    # Find position with highest importance score
-    for x, y in positions:
-        # Calculate importance score for this position
-        score = weighted_importance.sum()
-        if score > best_score:
-            best_score = score
-            best_x = x
-            best_y = y
-
-    return best_x, best_y
-
-def adjust_aspect_ratio(image, target_ratio, prompt=""):
+def adjust_aspect_ratio(image, mask, target_ratio, prompt=""):
     """Adjust image to target aspect ratio while preserving important content"""
     # Convert PIL to numpy if needed
     if isinstance(image, Image.Image):
-        image_pil = image
         image_np = np.array(image)
     else:
         image_np = image
-        image_pil = Image.fromarray(image_np)
 
-    # Get dimensions
     h, w = image_np.shape[:2]
     current_ratio = w / h
     target_ratio_value = eval(target_ratio.replace(':', '/'))
 
-    # Generate importance map to identify key regions
-    importance_map = get_importance_map(image_np)
-
-    # Calculate new dimensions
+    # Determine if we need to add width or height
     if current_ratio < target_ratio_value:
         # Need to add width (outpaint left/right)
         new_width = int(h * target_ratio_value)
         new_height = h
+
+        # Calculate padding
+        pad_width = new_width - w
+        pad_left = pad_width // 2
+        pad_right = pad_width - pad_left
+
+        # Create canvas with padding
+        result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
+        # Place original image in the center
+        result[:, pad_left:pad_left+w, :] = image_np
+
+        # Create mask for inpainting
+        inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
+        inpaint_mask[:, pad_left:pad_left+w] = 0
+
+        # Perform outpainting using Stable Diffusion
+        result = outpaint_regions(result, inpaint_mask, prompt)
+
     else:
         # Need to add height (outpaint top/bottom)
         new_width = w
         new_height = int(w / target_ratio_value)
+
+        # Calculate padding
+        pad_height = new_height - h
+        pad_top = pad_height // 2
+        pad_bottom = pad_height - pad_top
+
+        # Create canvas with padding
+        result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
+        # Place original image in the center
+        result[pad_top:pad_top+h, :, :] = image_np
+
+        # Create mask for inpainting
+        inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
+        inpaint_mask[pad_top:pad_top+h, :] = 0
+
+        # Perform outpainting using Stable Diffusion
+        result = outpaint_regions(result, inpaint_mask, prompt)
 
-    # Find optimal placement based on importance map
-    x_offset, y_offset = find_optimal_placement(importance_map, (h, w), (new_height, new_width))
-
-    # Create new canvas
-    result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
-    mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
-
-    # Place original image at calculated position
-    result[y_offset:y_offset+h, x_offset:x_offset+w] = image_np
-    mask[y_offset:y_offset+h, x_offset:x_offset+w] = 0
-
-    # Convert to PIL for inpainting
-    result_pil = Image.fromarray(result)
+    return result
+
+def outpaint_regions(image, mask, prompt):
+    """Use Stable Diffusion to outpaint masked regions"""
+    # Convert to PIL images
+    image_pil = Image.fromarray(image)
     mask_pil = Image.fromarray(mask)
 
-    # Use default prompt if none provided
+    # If prompt is empty, use a generic one
     if not prompt or prompt.strip() == "":
-        if len(image_np.shape) == 3 and image_np.shape[2] == 4:  # Check if image has alpha channel
-            prompt = "seamless extension of the image, same style and content"
-        else:
-            prompt = "seamless extension of the image, same style, same scene, consistent lighting"
+        prompt = "seamless extension of the image, same style, same scene"
 
-    # Perform outpainting using Stable Diffusion
+    # Generate the outpainting
     output = inpaint_model(
         prompt=prompt,
-        image=result_pil,
+        image=image_pil,
         mask_image=mask_pil,
         guidance_scale=7.5,
-        num_inference_steps=30
+        num_inference_steps=25
     ).images[0]
 
     return np.array(output)
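Two remarks on this hunk. First, the deleted `find_optimal_placement` scored every candidate position with `weighted_importance.sum()`, a value that does not depend on `(x, y)`, so the loop always kept the first (top-left) position; replacing it with centered padding removes that dead logic. Second, `target_ratio_value = eval(target_ratio.replace(':', '/'))` survives the rewrite. It is harmless for the fixed dropdown choices but evaluates arbitrary strings; a `Fraction`-based parser avoids `eval` entirely. A suggested alternative, not part of this commit:

```python
from fractions import Fraction

def parse_ratio(ratio: str) -> float:
    """Parse a 'W:H' string such as '16:9' into a float without eval()."""
    w, h = ratio.split(":")
    return float(Fraction(int(w), int(h)))

assert abs(parse_ratio("16:9") - 16 / 9) < 1e-9
```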
@@ -184,7 +140,7 @@ def adjust_aspect_ratio(image, target_ratio, prompt=""):
 def process_image(input_image, target_ratio="16:9", prompt=""):
     """Main processing function for the Gradio interface"""
     try:
-        # Convert from Gradio format if needed
+        # Convert from Gradio format
         if isinstance(input_image, dict) and 'image' in input_image:
             image = input_image['image']
         else:
@@ -196,8 +152,11 @@ def process_image(input_image, target_ratio="16:9", prompt=""):
         else:
             image_np = image
 
+        # Get SAM mask to identify important regions
+        mask = get_sam_mask(image_np)
+
         # Adjust aspect ratio while preserving content
-        result = adjust_aspect_ratio(image_np, target_ratio, prompt)
+        result = adjust_aspect_ratio(image_np, mask, target_ratio, prompt)
 
         # Convert result to PIL for visualization
         result_pil = Image.fromarray(result)
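One caveat on the padded canvases built in `adjust_aspect_ratio`: Stable Diffusion's VAE downsamples by a factor of 8, so the inpainting pipeline expects width and height divisible by 8, and `int(h * target_ratio_value)` on arbitrary uploads will often break that. A common safeguard, not present in this commit, is to snap the target dimensions first:

```python
def snap_to_multiple(x: int, base: int = 8) -> int:
    """Round a dimension down to the nearest multiple of `base`, at least `base`."""
    return max(base, (x // base) * base)

# Hypothetical integration point, before allocating the canvas:
# new_width, new_height = snap_to_multiple(new_width), snap_to_multiple(new_height)
```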
@@ -209,9 +168,9 @@ def process_image(input_image, target_ratio="16:9", prompt=""):
         return None
 
 # Create the Gradio interface
-with gr.Blocks(title="Smart Aspect Ratio Adjuster") as demo:
-    gr.Markdown("# Smart Aspect Ratio Adjuster")
-    gr.Markdown("Upload an image, choose your target aspect ratio, and the AI will adjust it while intelligently preserving important content.")
+with gr.Blocks(title="Automatic Aspect Ratio Adjuster") as demo:
+    gr.Markdown("# Automatic Aspect Ratio Adjuster")
+    gr.Markdown("Upload an image, choose your target aspect ratio, and let the AI adjust it while preserving important content.")
 
     with gr.Row():
         with gr.Column():
@@ -219,7 +178,7 @@ with gr.Blocks(title="Smart Aspect Ratio Adjuster") as demo:
 
         with gr.Row():
             aspect_ratio = gr.Dropdown(
-                choices=["16:9", "4:3", "1:1", "9:16", "3:4", "2:1", "1:2"],
+                choices=["16:9", "4:3", "1:1", "9:16", "3:4"],
                 value="16:9",
                 label="Target Aspect Ratio"
             )
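Not shown in any hunk: the file's closing lines. The standard Gradio launcher, presumably unchanged by this commit, is:

```python
if __name__ == "__main__":
    demo.launch()
```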
@@ -242,9 +201,9 @@ with gr.Blocks(title="Smart Aspect Ratio Adjuster") as demo:
 
     gr.Markdown("""
     ## How it works
-    1. **Content Analysis**: SAM (Segment Anything Model) identifies important regions in your image
-    2. **Smart Placement**: The algorithm calculates optimal positioning to preserve key content
-    3. **AI Outpainting**: Stable Diffusion fills in new areas with matching content
+    1. SAM (Segment Anything Model) identifies important content in your image
+    2. The algorithm calculates how to adjust the aspect ratio while preserving this content
+    3. Stable Diffusion fills in the new areas with AI-generated content that matches the original image
 
     ## Tips
     - For best results, provide a descriptive prompt that matches the scene
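For reviewers wanting to smoke-test the new code path: note that `adjust_aspect_ratio` now accepts the SAM `mask` but never reads it, so placement is always centered regardless of where the detected content sits. A minimal local run, assuming the models defined in app.py are loaded and with `photo.jpg` as a placeholder input:

```python
import numpy as np
from PIL import Image

img = np.array(Image.open("photo.jpg").convert("RGB"))
sam_mask = get_sam_mask(img)                         # center-point SAM mask (currently unused downstream)
result = adjust_aspect_ratio(img, sam_mask, "16:9")  # pad to 16:9 and outpaint the borders
Image.fromarray(result).save("photo_16x9.jpg")
```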
 