mostlycached's picture
Update app.py
2698a3f verified
raw
history blame contribute delete
7.23 kB
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
from transformers import SamModel, SamProcessor
from diffusers import StableDiffusionInpaintPipeline
import requests
from io import BytesIO
# Choose the compute device once; every model below is moved onto it.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# SAM (Segment Anything) is used to segment the uploaded image.
print("Loading SAM model...")
_SAM_CHECKPOINT = "facebook/sam-vit-base"
sam_model = SamModel.from_pretrained(_SAM_CHECKPOINT)
sam_model = sam_model.to(device)
sam_processor = SamProcessor.from_pretrained(_SAM_CHECKPOINT)

# Stable Diffusion inpainting fills the padding added around the image.
# Half precision is only requested on CUDA; the CPU path stays float32.
print("Loading Stable Diffusion model...")
_INPAINT_DTYPE = torch.float16 if device == "cuda" else torch.float32
inpaint_model = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=_INPAINT_DTYPE,
)
inpaint_model = inpaint_model.to(device)
def get_sam_mask(image, points=None):
    """Segment ``image`` with SAM using point prompts.

    Args:
        image: Image as a NumPy array or a PIL ``Image``.
        points: Point prompts in the nested-list layout passed to
            ``SamProcessor(input_points=...)``. When ``None``, a single point
            at the centre of the image is used.

    Returns:
        NumPy array ``masks[0][0]`` from ``post_process_masks``: the mask
        output for the first image and first prompt, upscaled to the image's
        original size. NOTE(review): with SAM's default multimask output this
        is presumably a stack of candidate masks rather than one 2-D mask —
        confirm before relying on its shape.
    """
    # Normalise to PIL first so the default prompt can be computed for both
    # NumPy and PIL inputs (PIL images have no ``.shape`` attribute).
    image_pil = image if isinstance(image, Image.Image) else Image.fromarray(image)

    if points is None:
        width, height = image_pil.size
        points = [[[width // 2, height // 2]]]

    inputs = sam_processor(
        images=image_pil,
        input_points=points,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        outputs = sam_model(**inputs)

    # Map the low-resolution predicted masks back to the input image size.
    masks = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )
    return masks[0][0].numpy()
def _parse_aspect_ratio(ratio):
    """Return the width/height value of an aspect ratio such as ``"16:9"``.

    Accepts ``"W:H"``, ``"W/H"`` or a single number (``"1.5"`` or ``1.5``).
    Raises ``ValueError`` for anything else, including zero, negative or
    non-finite values. This replaces ``eval`` on the raw string, which would
    execute arbitrary code sent to the Gradio endpoint.
    """
    parts = str(ratio).strip().replace(":", "/").split("/")
    if len(parts) not in (1, 2):
        raise ValueError(f"Invalid aspect ratio: {ratio!r}")
    try:
        numbers = [float(part) for part in parts]
    except ValueError:
        raise ValueError(f"Invalid aspect ratio: {ratio!r}") from None
    if not all(np.isfinite(x) and x > 0 for x in numbers):
        raise ValueError(f"Invalid aspect ratio: {ratio!r}")
    return numbers[0] / numbers[1] if len(numbers) == 2 else numbers[0]


def _as_rgb_array(image):
    """Return ``image`` as an ``(H, W, 3)`` NumPy array.

    PIL images are converted with ``convert("RGB")`` (handles palette, alpha
    and greyscale modes). NumPy greyscale ``(H, W)`` / ``(H, W, 1)`` arrays are
    replicated to three channels and a 4th (alpha) channel is dropped, so the
    result fits the 3-channel canvas built by ``adjust_aspect_ratio``.
    """
    if not isinstance(image, np.ndarray) and hasattr(image, "convert"):
        image = image.convert("RGB")
    image_np = np.asarray(image)
    if image_np.ndim == 2:
        image_np = image_np[:, :, None]
    if image_np.ndim != 3:
        raise ValueError(f"Expected a 2-D or 3-D image array, got shape {image_np.shape}")
    channels = image_np.shape[2]
    if channels == 1:
        image_np = np.repeat(image_np, 3, axis=2)
    elif channels == 4:
        image_np = image_np[:, :, :3]
    elif channels != 3:
        raise ValueError(f"Unsupported number of image channels: {channels}")
    return image_np


def adjust_aspect_ratio(image, mask, target_ratio, prompt=""):
    """Pad ``image`` to ``target_ratio`` and outpaint the added border.

    The original pixels are centred on a wider or taller canvas (never
    cropped) and the new strips are filled by ``outpaint_regions``.

    Args:
        image: Image as a NumPy array or PIL ``Image`` (RGB, RGBA or greyscale).
        mask: Not used by the current implementation; accepted so callers can
            pass the SAM mask.
        target_ratio: Aspect ratio as ``"W:H"`` (e.g. ``"16:9"``), ``"W/H"`` or
            a single number.
        prompt: Optional text prompt forwarded to ``outpaint_regions``.

    Returns:
        The array returned by ``outpaint_regions`` for the padded canvas, or a
        copy of the (RGB-normalised) input when it already has the target
        ratio and nothing needs to be generated.

    Raises:
        ValueError: If ``target_ratio`` cannot be parsed or the image is empty.
    """
    image_np = _as_rgb_array(image)
    h, w = image_np.shape[:2]
    if h == 0 or w == 0:
        raise ValueError("Cannot adjust the aspect ratio of an empty image")
    target = _parse_aspect_ratio(target_ratio)

    if w / h < target:
        # Too narrow: widen the canvas and keep the height.
        new_width, new_height = max(w, round(h * target)), h
    else:
        # Too wide (or already at the ratio): make the canvas taller and keep
        # the width. round() + max() avoid a canvas one pixel smaller than the
        # image when float error puts w / target just below h.
        new_width, new_height = w, max(h, round(w / target))

    if (new_width, new_height) == (w, h):
        # Nothing to outpaint; skip the expensive diffusion pass.
        return image_np.copy()

    pad_left = (new_width - w) // 2
    pad_top = (new_height - h) // 2

    canvas = np.zeros((new_height, new_width, 3), dtype=np.uint8)
    canvas[pad_top:pad_top + h, pad_left:pad_left + w] = image_np

    # 255 marks the padding to be generated; the original area is 0 (kept).
    inpaint_mask = np.full((new_height, new_width), 255, dtype=np.uint8)
    inpaint_mask[pad_top:pad_top + h, pad_left:pad_left + w] = 0

    return outpaint_regions(canvas, inpaint_mask, prompt)
def _diffusion_size(width, height, target_area, multiple=8):
    """Scale (width, height) to about ``target_area`` pixels, keeping the
    aspect ratio and rounding each side to a positive multiple of ``multiple``."""
    scale = (target_area / float(width * height)) ** 0.5
    gen_width = max(multiple, int(round(width * scale / multiple)) * multiple)
    gen_height = max(multiple, int(round(height * scale / multiple)) * multiple)
    return gen_width, gen_height


def outpaint_regions(image, mask, prompt, *, guidance_scale=7.5,
                     num_inference_steps=25, target_area=512 * 512):
    """Fill the masked regions of ``image`` with Stable Diffusion inpainting.

    Args:
        image: RGB canvas, presumably ``(H, W, 3)`` uint8 as built by
            ``adjust_aspect_ratio``.
        mask: ``(H, W)`` array; non-zero pixels are generated, zero pixels
            keep the original image.
        prompt: Text prompt; empty or whitespace-only prompts are replaced by a
            generic "continue the scene" prompt.
        guidance_scale: Classifier-free guidance scale for the pipeline.
        num_inference_steps: Number of denoising steps.
        target_area: Approximate pixel count of the diffusion pass. The canvas
            is resized to this area (keeping its aspect ratio) for generation
            and the result is resized back.

    Returns:
        ``(H, W, 3)`` array with the original pixels where ``mask`` is zero and
        generated pixels elsewhere. If ``mask`` has no non-zero pixels, a copy
        of ``image`` is returned without running the model.
    """
    image = np.asarray(image)
    mask = np.asarray(mask)
    if not mask.any():
        # Nothing to fill; avoid a full diffusion pass.
        return image.copy()

    height, width = image.shape[:2]

    if not prompt or not prompt.strip():
        prompt = "seamless extension of the image, same style, same scene"

    # Without explicit height/width the inpaint pipeline renders at its
    # default unet.config.sample_size * vae_scale_factor square, which would
    # throw away the new aspect ratio. Render at an aspect-preserving size
    # whose sides are multiples of 8, as the pipeline requires.
    gen_width, gen_height = _diffusion_size(width, height, target_area)
    image_pil = Image.fromarray(image).resize((gen_width, gen_height), Image.LANCZOS)
    mask_pil = Image.fromarray(mask).resize((gen_width, gen_height), Image.NEAREST)

    output = inpaint_model(
        prompt=prompt,
        image=image_pil,
        mask_image=mask_pil,
        height=gen_height,
        width=gen_width,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    ).images[0]

    generated = np.array(output.convert("RGB").resize((width, height), Image.LANCZOS))
    # Only take generated pixels inside the mask so the original content is
    # preserved exactly (the VAE round-trip and resizing would otherwise alter it).
    keep_original = (mask == 0)[..., None]
    return np.where(keep_original, image, generated)
def process_image(input_image, target_ratio="16:9", prompt=""):
    """Gradio handler: adjust ``input_image`` to ``target_ratio``.

    Args:
        input_image: Value from the ``gr.Image`` component — a NumPy array, a
            PIL ``Image``, or a dict carrying the image under ``"image"``.
        target_ratio: Aspect ratio string such as ``"16:9"``.
        prompt: Optional outpainting prompt.

    Returns:
        The adjusted image as a PIL ``Image``.

    Raises:
        gr.Error: If no image was supplied or processing fails. The message is
            then shown in the UI instead of an unexplained empty output.
    """
    # Editor-style components deliver a dict; take the image out of it.
    if isinstance(input_image, dict):
        input_image = input_image.get("image")
    if input_image is None:
        raise gr.Error("Please upload an image first.")

    try:
        if isinstance(input_image, Image.Image):
            image_np = np.array(input_image)
        else:
            image_np = input_image

        # SAM mask of the salient region (passed along to adjust_aspect_ratio).
        mask = get_sam_mask(image_np)
        result = adjust_aspect_ratio(image_np, mask, target_ratio, prompt)
        return Image.fromarray(result)
    except Exception as e:
        import traceback

        # Full traceback goes to the server log; a short message goes to the UI.
        print(f"Error processing image: {e}")
        traceback.print_exc()
        raise gr.Error(f"Error processing image: {e}") from e
# Aspect ratios offered in the dropdown.
_RATIO_CHOICES = ["16:9", "4:3", "1:1", "9:16", "3:4"]

# Help text rendered below the controls.
_HELP_MARKDOWN = """
## How it works
1. SAM (Segment Anything Model) identifies important content in your image
2. The algorithm calculates how to adjust the aspect ratio while preserving this content
3. Stable Diffusion fills in the new areas with AI-generated content that matches the original image
## Tips
- For best results, provide a descriptive prompt that matches the scene
- Try different aspect ratios to see what works best
- The model works best with clear, well-lit images
"""

# Build the UI: inputs on the left, result on the right, help text underneath.
with gr.Blocks(title="Automatic Aspect Ratio Adjuster") as demo:
    gr.Markdown("# Automatic Aspect Ratio Adjuster")
    gr.Markdown(
        "Upload an image, choose your target aspect ratio, and let the AI "
        "adjust it while preserving important content."
    )

    with gr.Row():
        # Left column: source image and processing options.
        with gr.Column():
            input_image = gr.Image(label="Input Image")
            with gr.Row():
                aspect_ratio = gr.Dropdown(
                    label="Target Aspect Ratio",
                    value="16:9",
                    choices=_RATIO_CHOICES,
                )
                prompt = gr.Textbox(
                    placeholder="Describe the scene for better outpainting",
                    label="Outpainting Prompt (optional)",
                )
            submit_btn = gr.Button("Process Image")
        # Right column: processed result.
        with gr.Column():
            output_image = gr.Image(label="Processed Image")

    submit_btn.click(
        process_image,
        inputs=[input_image, aspect_ratio, prompt],
        outputs=output_image,
    )

    gr.Markdown(_HELP_MARKDOWN)

# Start the web server only when run as a script.
if __name__ == "__main__":
    demo.launch()