# img2img / app.py
import spaces
import gradio as gr
import re
from PIL import Image
import os
import numpy as np
import torch
from diffusers import FluxImg2ImgPipeline
# Set the torch data type and device
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# If the model requires authentication, pass token="YOUR_TOKEN" to from_pretrained below.
pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=dtype,
).to(device)
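
# FLUX.1-schnell is a timestep-distilled checkpoint: it is designed to run
# with guidance_scale=0 and very few inference steps (typically 1-4), which
# is why those values are fixed in the pipe() call further down.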


def sanitize_prompt(prompt):
    # Allow only alphanumeric characters, spaces, and basic punctuation
    allowed_chars = re.compile(r"[^a-zA-Z0-9\s.,!?-]")
    sanitized_prompt = allowed_chars.sub("", prompt)
    return sanitized_prompt
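
# For example, sanitize_prompt("a girl <script>") returns "a girl script";
# characters outside the allow-list are stripped before the prompt is used.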


def convert_to_fit_size(original_width_and_height, maximum_size=2048):
    width, height = original_width_and_height
    if width <= maximum_size and height <= maximum_size:
        return width, height
    if width > height:
        scaling_factor = maximum_size / width
    else:
        scaling_factor = maximum_size / height
    new_width = int(width * scaling_factor)
    new_height = int(height * scaling_factor)
    return new_width, new_height
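
# For example, a 4096x1024 input is scaled down to 2048x512, while images
# already within the 2048 limit are returned unchanged.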


def adjust_to_multiple_of_32(width: int, height: int):
    width = width - (width % 32)
    height = height - (height % 32)
    return width, height
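
# For example, (1000, 700) becomes (992, 672); rounding each dimension down
# to a multiple of 32 keeps the size compatible with the model's latent grid.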


@spaces.GPU(duration=120)
def process_images(image, prompt="a girl", strength=0.75, seed=0, inference_step=4, progress=gr.Progress(track_tqdm=True)):
    progress(0, desc="Starting")

    def process_img2img(image, prompt="a person", strength=0.75, seed=0, num_inference_steps=4):
        if image is None:
            print("No input image provided; returning None.")
            return None
        # Strip disallowed characters before the prompt reaches the pipeline
        prompt = sanitize_prompt(prompt)
        # Ensure the image is in RGB mode (handles formats like WebP and JFIF)
        if image.mode != "RGB":
            image = image.convert("RGB")
        generator = torch.Generator(device).manual_seed(seed)
        fit_width, fit_height = convert_to_fit_size(image.size)
        width, height = adjust_to_multiple_of_32(fit_width, fit_height)
        image = image.resize((width, height), Image.LANCZOS)
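        # guidance_scale is fixed at 0 and max_sequence_length at 256 below,
        # matching the schnell checkpoint's distillation setup and its T5
        # prompt-length limit.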
        output = pipe(
            prompt=prompt,
            image=image,
            generator=generator,
            strength=strength,
            width=width,
            height=height,
            guidance_scale=0,
            num_inference_steps=num_inference_steps,
            max_sequence_length=256,
        )
        pil_image = output.images[0]
        new_width, new_height = pil_image.size
        # Undo the multiple-of-32 rounding so the output matches the scaled input size
        if (new_width != fit_width) or (new_height != fit_height):
            resized_image = pil_image.resize((fit_width, fit_height), Image.LANCZOS)
            return resized_image
        return pil_image

    # Defensive casts: gr.Number components may deliver floats
    output = process_img2img(image, prompt, strength, int(seed), int(inference_step))
    return output
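
# A hypothetical direct call (outside the Gradio UI) would look like:
#   result = process_images(Image.open("input.jpg"), "a woman", strength=0.6, seed=42)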


def read_file(path: str) -> str:
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content
css = """
#col-left {
margin: 0 auto;
max-width: 640px;
}
#col-right {
margin: 0 auto;
max-width: 640px;
}
.grid-container {
display: flex;
align-items: center;
justify-content: center;
gap: 10px;
}
.image {
width: 128px;
height: 128px;
object-fit: cover;
}
.text {
font-size: 16px;
}
"""

with gr.Blocks(css=css, elem_id="demo-container") as demo:
    with gr.Column():
        gr.HTML(read_file("demo_header.html"))
        gr.HTML(read_file("demo_tools.html"))
        with gr.Row():
            with gr.Column():
                image = gr.Image(
                    height=800,
                    sources=['upload', 'clipboard'],
                    image_mode='RGB',
                    elem_id="image_upload",
                    type="pil",
                    label="Upload"
                )
                with gr.Row(elem_id="prompt-container", equal_height=False):
                    with gr.Row():
                        prompt = gr.Textbox(
                            label="Prompt",
                            value="a woman",
                            placeholder="Your prompt (describe how the output image should look)",
                            elem_id="prompt"
                        )
                        btn = gr.Button("Img2Img", elem_id="run_button", variant="primary")
                with gr.Accordion(label="Advanced Settings", open=False):
                    with gr.Row(equal_height=True):
                        strength = gr.Number(value=0.75, minimum=0, maximum=0.75, step=0.01, label="strength")
                        seed = gr.Number(value=100, minimum=0, step=1, label="seed")
                        inference_step = gr.Number(value=4, minimum=1, step=4, label="inference_step")
                id_input = gr.Text(label="Name", visible=False)
            with gr.Column():
                image_out = gr.Image(height=800, sources=[], label="Output", elem_id="output-img", format="jpg")
        gr.Examples(
            examples=[
                ["examples/draw_input.jpg", "examples/draw_output.jpg", "a woman, eyes closed, mouth open"],
                ["examples/draw-gimp_input.jpg", "examples/draw-gimp_output.jpg", "a woman, eyes closed, mouth open"],
                ["examples/gimp_input.jpg", "examples/gimp_output.jpg", "a woman, hand on neck"],
                ["examples/inpaint_input.jpg", "examples/inpaint_output.jpg", "a woman, hand on neck"]
            ],
            inputs=[image, image_out, prompt],
        )
        gr.HTML(read_file("demo_footer.html"))

    gr.on(
        triggers=[btn.click, prompt.submit],
        fn=process_images,
        inputs=[image, prompt, strength, seed, inference_step],
        outputs=[image_out]
    )

if __name__ == "__main__":
    demo.launch(share=True, show_error=True)