import torch
from diffusers import StableDiffusionControlNetImg2ImgPipeline, ControlNetModel, UniPCMultistepScheduler
from PIL import Image
import gradio as gr
import cv2
# Use half precision only on the GPU; float16 weights are not supported on the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

# Load the pre-trained Stable Diffusion base model (frozen part)
model_id = "runwayml/stable-diffusion-v1-5"
controlnet_id = "lllyasviel/control_v11p_sd15_canny"  # ControlNet conditioned on Canny edge maps

# Load the ControlNet model (trainable part)
controlnet = ControlNetModel.from_pretrained(controlnet_id, torch_dtype=dtype)
# Load Stable Diffusion pipeline with ControlNet | |
pipe = StableDiffusionControlNetPipeline.from_pretrained( | |
model_id, controlnet=controlnet, torch_dtype=torch.float16 | |
) | |
# Use an efficient scheduler
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# Move the pipeline to the selected device
pipe.to(device)
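
# Optional (a hedged sketch, not required for this app): on GPUs with limited
# VRAM, diffusers provides memory helpers that can be enabled here, e.g.
#   pipe.enable_model_cpu_offload()  # requires the accelerate package;
#                                    # use instead of pipe.to(device)
#   pipe.enable_attention_slicing()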
# Build the control image (Canny edge map) from the input image
def generate_control_image(input_image_path):
    image = cv2.imread(input_image_path, cv2.IMREAD_GRAYSCALE)
    edges = cv2.Canny(image, 100, 200)  # Canny edge detection; thresholds are tunable
    control_image = Image.fromarray(edges).convert("L")
    control_image = control_image.resize((512, 512))  # Match the model's expected resolution
    control_image.save("control_image.png")  # PNG avoids JPEG artifacts in the edge map
    return "control_image.png"
# Apply the prompted color change while preserving the garment's structure
def apply_color_change(input_image, prompt):
    # Save the input image temporarily so OpenCV can read it from disk
    # (PNG handles RGBA uploads that JPEG would reject)
    input_image_path = "input_image.png"
    input_image.save(input_image_path)

    # Generate the control image (edge map)
    control_image_path = generate_control_image(input_image_path)

    # Load the processed input and control images
    input_image = Image.open(input_image_path).convert("RGB").resize((512, 512))
    control_image = Image.open(control_image_path).convert("L")

    # Generate the new image with the pipeline
    generator = torch.manual_seed(42)  # Fixed seed for reproducibility
    output_image = pipe(
        prompt=prompt,
        image=input_image,
        control_image=control_image,
        strength=0.8,  # How far the result may drift from the original image
        generator=generator,
        num_inference_steps=30,
    ).images[0]

    output_image.save("output_color_changed.png")
    return "output_color_changed.png"
# Gradio wrapper
def gradio_interface(input_image, prompt):
    return apply_color_change(input_image, prompt)
# Launch the Gradio interface with drag-and-drop upload
interface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Image(type="pil", label="Upload your image"),  # Supports drag and drop
        gr.Textbox(label="Enter prompt", placeholder="e.g. A hoodie with a blue and white design"),
    ],
    outputs=gr.Image(label="Color Changed Output"),
    title="AI-Powered Clothing Color Changer",
    description="Upload an image of clothing, enter a prompt, and get a redesigned color version.",
)

interface.launch()
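
# Note: launch(share=True) would create a temporary public link for the demo;
# when hosted on Hugging Face Spaces, the plain launch() above is sufficient.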