import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline

# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load SAM model for segmentation
print("Loading SAM model...")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Load Stable Diffusion for outpainting
print("Loading Stable Diffusion model...")
inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32
).to(device)
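
# Optional: on GPUs with limited VRAM, attention slicing trades a little speed
# for a smaller memory footprint (a standard diffusers pipeline method).
# inpaint_model.enable_attention_slicing()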

def get_sam_mask(image, points=None):
    """Get a segmentation mask for the most prominent object using SAM"""
    # Convert to PIL first so the size lookup below works for both input types
    if not isinstance(image, Image.Image):
        image_pil = Image.fromarray(image)
    else:
        image_pil = image

    if points is None:
        # No point prompt provided: fall back to the image center
        width, height = image_pil.size
        points = [[[width // 2, height // 2]]]
        
    # Process the image and point prompts
    inputs = sam_processor(
        images=image_pil,
        input_points=points,
        return_tensors="pt"
    ).to(device)
    
    # Generate mask
    with torch.no_grad():
        outputs = sam_model(**inputs)
        masks = sam_processor.image_processor.post_process_masks(
            outputs.pred_masks.cpu(),
            inputs["original_sizes"].cpu(),
            inputs["reshaped_input_sizes"].cpu()
        )
    
    # post_process_masks returns one (point_batch, num_masks, H, W) tensor per
    # image; keep the candidate mask with the highest predicted IoU score
    best_idx = outputs.iou_scores[0, 0].argmax().item()
    mask = masks[0][0, best_idx].numpy()
    return mask
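
# Example usage (hypothetical coordinates): segment whatever sits under a
# user-supplied click instead of the image center:
#   mask = get_sam_mask(image_np, points=[[[320, 180]]])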

def adjust_aspect_ratio(image, mask, target_ratio, prompt=""):
    """Adjust image to target aspect ratio while preserving important content"""
    # Convert PIL to numpy if needed
    if isinstance(image, Image.Image):
        image_np = np.array(image)
    else:
        image_np = image
        
    h, w = image_np.shape[:2]
    current_ratio = w / h
    # Parse "W:H" into a float ratio instead of using eval()
    ratio_w, ratio_h = (float(v) for v in target_ratio.split(":"))
    target_ratio_value = ratio_w / ratio_h
    
    # Nothing to do if the ratios already match
    if abs(current_ratio - target_ratio_value) < 1e-3:
        return image_np

    # Determine whether to add width or height
    if current_ratio < target_ratio_value:
        # Need to add width (outpaint left/right)
        new_width = int(h * target_ratio_value)
        new_height = h
        
        # Split the padding so the SAM-detected content lands as close to the
        # center of the new canvas as the available padding allows
        pad_width = new_width - w
        ys, xs = np.nonzero(mask)
        content_cx = int(xs.mean()) if xs.size else w // 2
        pad_left = int(np.clip(new_width // 2 - content_cx, 0, pad_width))
        
        # Create canvas with padding
        result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
        # Place original image in the center
        result[:, pad_left:pad_left+w, :] = image_np
        
        # Inpainting mask: 255 = generate new pixels, 0 = keep the original
        inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
        inpaint_mask[:, pad_left:pad_left+w] = 0
        
        # Perform outpainting using Stable Diffusion
        result = outpaint_regions(result, inpaint_mask, prompt)
        
    else:
        # Need to add height (outpaint top/bottom)
        new_width = w
        new_height = int(w / target_ratio_value)
        
        # Bias the padding so the SAM-detected content stays near the center
        pad_height = new_height - h
        ys, xs = np.nonzero(mask)
        content_cy = int(ys.mean()) if ys.size else h // 2
        pad_top = int(np.clip(new_height // 2 - content_cy, 0, pad_height))
        
        # Create canvas with padding
        result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
        # Place original image in the center
        result[pad_top:pad_top+h, :, :] = image_np
        
        # Inpainting mask: 255 = generate new pixels, 0 = keep the original
        inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
        inpaint_mask[pad_top:pad_top+h, :] = 0
        
        # Perform outpainting using Stable Diffusion
        result = outpaint_regions(result, inpaint_mask, prompt)
    
    return result

def outpaint_regions(image, mask, prompt):
    """Use Stable Diffusion to outpaint masked regions"""
    # Convert to PIL images
    image_pil = Image.fromarray(image)
    mask_pil = Image.fromarray(mask)
    
    # If prompt is empty, use a generic one
    if not prompt or prompt.strip() == "":
        prompt = "seamless extension of the image, same style, same scene"
    
    # The SD2 inpainting pipeline needs dimensions divisible by 8 and otherwise
    # defaults to 512x512, which would break the target aspect ratio
    w, h = image_pil.size
    gen_w, gen_h = (w // 8) * 8, (h // 8) * 8

    output = inpaint_model(
        prompt=prompt,
        image=image_pil,
        mask_image=mask_pil,
        width=gen_w,
        height=gen_h,
        guidance_scale=7.5,
        num_inference_steps=25
    ).images[0]

    # Resize back to the exact canvas size if the rounding changed it
    if output.size != (w, h):
        output = output.resize((w, h), Image.LANCZOS)

    return np.array(output)
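
# Example usage (hypothetical values): regenerate the right half of a canvas,
# where `canvas` is an HxWx3 uint8 array:
#   right_mask = np.zeros(canvas.shape[:2], dtype=np.uint8)
#   right_mask[:, canvas.shape[1] // 2:] = 255
#   extended = outpaint_regions(canvas, right_mask, "a sandy beach at sunset")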

def process_image(input_image, target_ratio="16:9", prompt=""):
    """Main processing function for the Gradio interface"""
    try:
        # Convert from Gradio format
        if isinstance(input_image, dict) and 'image' in input_image:
            image = input_image['image']
        else:
            image = input_image
            
        # Convert PIL to numpy if needed
        if isinstance(image, Image.Image):
            image_np = np.array(image)
        else:
            image_np = image
            
        # Get SAM mask to identify important regions
        mask = get_sam_mask(image_np)
        
        # Adjust aspect ratio while preserving content
        result = adjust_aspect_ratio(image_np, mask, target_ratio, prompt)
        
        # Convert result to PIL for visualization
        result_pil = Image.fromarray(result)
        
        return result_pil
    
    except Exception as e:
        print(f"Error processing image: {e}")
        # Surface the failure in the UI instead of silently returning nothing
        raise gr.Error(f"Error processing image: {e}")

# Create the Gradio interface
with gr.Blocks(title="Automatic Aspect Ratio Adjuster") as demo:
    gr.Markdown("# Automatic Aspect Ratio Adjuster")
    gr.Markdown("Upload an image, choose your target aspect ratio, and let the AI adjust it while preserving important content.")
    
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image")
            
            with gr.Row():
                aspect_ratio = gr.Dropdown(
                    choices=["16:9", "4:3", "1:1", "9:16", "3:4"], 
                    value="16:9",
                    label="Target Aspect Ratio"
                )
                
            prompt = gr.Textbox(
                label="Outpainting Prompt (optional)", 
                placeholder="Describe the scene for better outpainting"
            )
            
            submit_btn = gr.Button("Process Image")
        
        with gr.Column():
            output_image = gr.Image(label="Processed Image")
    
    submit_btn.click(
        process_image, 
        inputs=[input_image, aspect_ratio, prompt], 
        outputs=output_image
    )
    
    gr.Markdown("""
    ## How it works
    1. SAM (Segment Anything Model) identifies important content in your image
    2. The algorithm calculates how to adjust the aspect ratio while preserving this content
    3. Stable Diffusion fills in the new areas with AI-generated content that matches the original image
    
    ## Tips
    - For best results, provide a descriptive prompt that matches the scene
    - Try different aspect ratios to see what works best
    - The model works best with clear, well-lit images
    """)

# Launch the app
if __name__ == "__main__":
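    # share=True gives a temporary public link; server_name="0.0.0.0" listens
    # on all interfaces (both standard Gradio launch options)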
    demo.launch()