import gradio as gr
import torch
import numpy as np
from PIL import Image
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline
# Set up device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load SAM model for segmentation
print("Loading SAM model...")
sam_model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

# Load Stable Diffusion for outpainting
print("Loading Stable Diffusion model...")
inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)
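# Optional sketch (not part of the original app): on memory-constrained GPUs,
# attention slicing lowers peak VRAM during inpainting at a small speed cost.
# inpaint_model.enable_attention_slicing()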
def get_sam_mask(image, points=None):
    """Get a segmentation mask for the image using SAM."""
    # Convert to PIL if needed
    if not isinstance(image, Image.Image):
        image_pil = Image.fromarray(image)
    else:
        image_pil = image

    if points is None:
        # If no points are provided, prompt SAM with the image center
        width, height = image_pil.size
        points = [[[width // 2, height // 2]]]

    # Process the image and point prompts
    inputs = sam_processor(
        images=image_pil,
        input_points=points,
        return_tensors="pt",
    ).to(device)

    # Generate masks
    with torch.no_grad():
        outputs = sam_model(**inputs)

    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )

    # SAM returns three candidate masks per point; keep the one with the
    # highest predicted IoU score
    best_idx = outputs.iou_scores[0, 0].argmax().item()
    mask = masks[0][0][best_idx].numpy()
    return mask
def adjust_aspect_ratio(image, mask, target_ratio, prompt=""):
    """Adjust the image to the target aspect ratio by outpainting new regions.

    The SAM mask is accepted for future content-aware placement but is not
    used yet: the original image is always kept centered on the new canvas.
    """
    # Convert PIL to numpy if needed
    if isinstance(image, Image.Image):
        image_np = np.array(image)
    else:
        image_np = image

    h, w = image_np.shape[:2]
    current_ratio = w / h

    # Parse a "W:H" string such as "16:9" into a numeric ratio
    ratio_w, ratio_h = (float(x) for x in target_ratio.split(":"))
    target_ratio_value = ratio_w / ratio_h

    if current_ratio < target_ratio_value:
        # Image is too narrow: add width (outpaint left/right)
        new_width = int(h * target_ratio_value)
        new_height = h

        # Calculate padding
        pad_width = new_width - w
        pad_left = pad_width // 2

        # Create canvas with the original image centered
        result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
        result[:, pad_left:pad_left + w, :] = image_np

        # Mask for inpainting: 255 where new content must be generated
        inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
        inpaint_mask[:, pad_left:pad_left + w] = 0
    else:
        # Image is too wide (or already matches): add height (outpaint top/bottom)
        new_width = w
        new_height = int(w / target_ratio_value)

        # Calculate padding
        pad_height = new_height - h
        pad_top = pad_height // 2

        # Create canvas with the original image centered
        result = np.zeros((new_height, new_width, 3), dtype=np.uint8)
        result[pad_top:pad_top + h, :, :] = image_np

        # Mask for inpainting: 255 where new content must be generated
        inpaint_mask = np.ones((new_height, new_width), dtype=np.uint8) * 255
        inpaint_mask[pad_top:pad_top + h, :] = 0

    # Skip the diffusion pass if the image already has the target ratio
    if not inpaint_mask.any():
        return result

    # Perform outpainting using Stable Diffusion
    result = outpaint_regions(result, inpaint_mask, prompt)
    return result
def outpaint_regions(image, mask, prompt):
    """Use Stable Diffusion inpainting to fill the masked regions."""
    # Convert to PIL images
    image_pil = Image.fromarray(image)
    mask_pil = Image.fromarray(mask)

    # If no prompt is given, fall back to a generic one
    if not prompt or prompt.strip() == "":
        prompt = "seamless extension of the image, same style, same scene"

    # Generate the outpainting
    output = inpaint_model(
        prompt=prompt,
        image=image_pil,
        mask_image=mask_pil,
        guidance_scale=7.5,
        num_inference_steps=25,
    ).images[0]

    # The pipeline may round the output to dimensions divisible by 8,
    # so resize back to the requested canvas size
    if output.size != image_pil.size:
        output = output.resize(image_pil.size)

    return np.array(output)
def process_image(input_image, target_ratio="16:9", prompt=""):
    """Main processing function for the Gradio interface."""
    try:
        # Unpack the Gradio payload if it arrives as a dict (e.g. from an editor component)
        if isinstance(input_image, dict) and 'image' in input_image:
            image = input_image['image']
        else:
            image = input_image

        # Convert PIL to numpy if needed
        if isinstance(image, Image.Image):
            image_np = np.array(image)
        else:
            image_np = image

        # Get SAM mask to identify important regions
        mask = get_sam_mask(image_np)

        # Adjust aspect ratio while preserving content
        result = adjust_aspect_ratio(image_np, mask, target_ratio, prompt)

        # Convert result back to PIL for display
        return Image.fromarray(result)
    except Exception as e:
        print(f"Error processing image: {e}")
        return None
# Create the Gradio interface
with gr.Blocks(title="Automatic Aspect Ratio Adjuster") as demo:
    gr.Markdown("# Automatic Aspect Ratio Adjuster")
    gr.Markdown("Upload an image, choose your target aspect ratio, and let the AI adjust it while preserving important content.")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Input Image")
            with gr.Row():
                aspect_ratio = gr.Dropdown(
                    choices=["16:9", "4:3", "1:1", "9:16", "3:4"],
                    value="16:9",
                    label="Target Aspect Ratio",
                )
                prompt = gr.Textbox(
                    label="Outpainting Prompt (optional)",
                    placeholder="Describe the scene for better outpainting",
                )
            submit_btn = gr.Button("Process Image")
        with gr.Column():
            output_image = gr.Image(label="Processed Image")

    submit_btn.click(
        process_image,
        inputs=[input_image, aspect_ratio, prompt],
        outputs=output_image,
    )
gr.Markdown(""" | |
## How it works | |
1. SAM (Segment Anything Model) identifies important content in your image | |
2. The algorithm calculates how to adjust the aspect ratio while preserving this content | |
3. Stable Diffusion fills in the new areas with AI-generated content that matches the original image | |
## Tips | |
- For best results, provide a descriptive prompt that matches the scene | |
- Try different aspect ratios to see what works best | |
- The model works best with clear, well-lit images | |
""") | |
# Launch the app
if __name__ == "__main__":
    demo.launch()
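# A minimal sketch of using the pipeline programmatically instead of through the
# UI. The file names and prompt below are purely illustrative assumptions.
#
#   img = Image.open("photo.jpg").convert("RGB")
#   out = process_image(img, target_ratio="1:1", prompt="a sunny beach scene")
#   if out is not None:
#       out.save("photo_1x1.png")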