# Author: p1atdev
# Commit: e63593c ("Update app.py", verified)
import spaces # import first
import random
import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr
from tkg import apply_tkg_noise, ColorSet, COLOR_SET_MAP
# Permit TF32 math for CUDA matmuls and cuDNN convolutions: faster on GPUs
# that support it, at slightly reduced float32 precision.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = "cuda"
model_repo_id = "cagliostrolab/animagine-xl-4.0"  # Replace with the model you would like to use

# Load the checkpoint named by ``model_repo_id`` (previously the id was
# hard-coded a second time here, so editing ``model_repo_id`` had no effect).
# ``custom_pipeline`` selects the diffusers community pipeline
# ``lpw_stable_diffusion_xl`` (long-prompt weighting); weights are kept in
# bfloat16, matching the dtype of the latents built in ``infer``.
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch.bfloat16,
    custom_pipeline="lpw_stable_diffusion_xl",
    add_watermarker=False,
)
pipe = pipe.to(device)

# Largest seed offered by the UI / random seed draw (2**31 - 1).
MAX_SEED = np.iinfo(np.int32).max
# Upper bound for the width/height sliders, in pixels.
MAX_IMAGE_SIZE = 2048
@spaces.GPU
def infer(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    tkg_channels: list[int] | None = None,
    chroma_key_shift: float = 0.11,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a TKG chroma-key image and a baseline image from the same noise.

    One random latent is drawn; a copy of it is transformed by
    ``apply_tkg_noise`` and the original is kept untouched. Both are denoised in
    a single batched pipeline call, so the only input difference between the
    two outputs is the TKG noise transform.

    Args:
        prompt: Positive prompt.
        negative_prompt: Negative prompt.
        seed: RNG seed; replaced by a random one when ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width in pixels; the latent width is ``width // 8``.
        height: Output height in pixels; the latent height is ``height // 8``.
        guidance_scale: Guidance scale passed to the pipeline.
        num_inference_steps: Number of denoising steps.
        tkg_channels: Channel pattern forwarded to ``apply_tkg_noise`` as
            ``channels`` (presumably one entry per latent channel -- see
            ``tkg``). Defaults to ``[0, 1, 1, 0]``.
        chroma_key_shift: Latent mean shift forwarded as ``shift``.
        progress: Gradio progress tracker (unused in the body).

    Returns:
        ``(image_with_tkg, image_without_tkg, seed)`` where ``seed`` is the
        seed that was actually used.
    """
    if tkg_channels is None:
        # Build a fresh list per call instead of sharing a mutable default.
        tkg_channels = [0, 1, 1, 0]
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may hand back floats; ``manual_seed`` and tensor shapes need ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator(device=device).manual_seed(seed)
    # Base noise at latent resolution (1/8 of the pixel size), 4 channels.
    latents = torch.randn(
        (1, 4, height // 8, width // 8),
        generator=generator,
        device=device,
        dtype=torch.bfloat16,
    )
    tkg_latents = apply_tkg_noise(
        latents,
        shift=chroma_key_shift,
        delta_shift=0.1,
        std_dev=0.5,
        factor=8,
        channels=tkg_channels,
    ).to(torch.bfloat16)
    # Batch order matters: index 0 is the TKG latent, index 1 the untouched
    # baseline -- the unpacking of ``images`` below relies on it.
    latents = torch.cat([tkg_latents, latents], dim=0)

    images = pipe(
        latents=latents,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        num_images_per_prompt=2,  # one image per latent in the batch
        # Same generator that drew the noise; the pipeline continues from its state.
        generator=generator,
    ).images
    w_tkg, wo_tkg = images
    return w_tkg, wo_tkg, seed
def color_name_to_channels(color_name: str) -> list[int]:
    """Look up the TKG channel pattern for a background color name.

    Args:
        color_name: A key of ``COLOR_SET_MAP`` (the choices of the
            "Background color" dropdown).

    Returns:
        The ``channels`` attribute of the matching ``COLOR_SET_MAP`` entry.

    Raises:
        ValueError: If ``color_name`` is not a key of ``COLOR_SET_MAP``.
    """
    if color_name not in COLOR_SET_MAP:
        raise ValueError(f"Unknown color name: {color_name}")
    return COLOR_SET_MAP[color_name].channels
def on_generate(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    color_name: str,
    chroma_key_shift: float,
    progress=gr.Progress(track_tqdm=True),
    *args,
    **kwargs,
):
    """Gradio event handler: resolve the background color, then run ``infer``.

    Args:
        prompt .. num_inference_steps: Forwarded unchanged to ``infer``.
        color_name: Background color chosen in the UI; converted to TKG
            channels via ``color_name_to_channels``.
        chroma_key_shift: Forwarded to ``infer``.
        progress: Declared only so Gradio enables tqdm progress tracking
            for this event (see the comment below); not used directly.
        *args, **kwargs: Forwarded to ``infer``.

    Returns:
        ``(image_with_tkg, image_without_tkg, seed_used)`` from ``infer``.

    Raises:
        ValueError: If ``color_name`` is unknown.
    """
    # Gradio only turns on tqdm tracking when the *registered handler* declares
    # a ``gr.Progress(track_tqdm=True)`` default; this function (not ``infer``)
    # is that handler. Keep it before ``*args`` -- Gradio appears to scan only
    # the positional parameters for it.
    tkg_channels = color_name_to_channels(color_name)
    # TODO: custom channels
    w_tkg, wo_tkg, seed = infer(
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        tkg_channels=tkg_channels,
        chroma_key_shift=chroma_key_shift,
        *args,
        **kwargs,
    )
    return w_tkg, wo_tkg, seed
# Example prompts listed by ``gr.Examples`` to populate the prompt textbox.
examples = [
    # Disabled example, kept for reference:
    # "1girl, arima kana, oshi no ko, hoshimachi suisei, hoshimachi suisei \(1st costume\), cosplay, looking at viewer, smile, outdoors, night, v, masterpiece, high score, great score, absurdres",
    (
        "1girl, solo, school uniform, cat ears, full body, looking at viewer, "
        "straight-on, chibi, simple background, best quality"
    ),
    (
        "1girl, solo, hand up, waving, long hair, sideways glance, "
        "upper body, cropped torso, simple background, best quality"
    ),
]
# Gradio UI: prompt + background color + settings on the left, results on the
# right. A single run renders the same noise with and without TKG.
with gr.Blocks() as demo:
    with gr.Column():
        # The header previously contained mojibake ("πŸ₯šπŸš"); it is the
        # misdecoded UTF-8 for the 🥚🍚 emoji.
        gr.Markdown(
            """
            # TKG Chroma-Key with AnimagineXL 4.0
            TKG-DM🥚🍚: Training-free Chroma Key Content Generation Diffusion Model
            - arXiv: https://arxiv.org/abs/2411.15580
            - GitHub: https://github.com/ryugo417/TKG-DM
            """
        )
        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    max_lines=4,
                    placeholder="Enter your prompt",
                )
                color_set = gr.Dropdown(
                    label="Background color",
                    choices=list(COLOR_SET_MAP.keys()),
                    value="green",
                )
                with gr.Accordion("TKG Settings", open=False):
                    chroma_key_shift = gr.Slider(
                        label="Latent mean shift for chroma key",
                        minimum=0.0,
                        maximum=0.2,
                        step=0.005,
                        value=0.11,
                    )
                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt = gr.Textbox(
                        label="Negative prompt",
                        max_lines=4,
                        placeholder="Enter a negative prompt",
                        value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    # Sizes step by 32 px, so they always divide evenly by the
                    # factor of 8 used to size the latents.
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=832,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1152,
                        )
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                        )
            with gr.Column():
                run_button = gr.Button("Generate", variant="primary")
                with gr.Row():
                    result_w_tkg = gr.Image(label="with TKG")
                    result_wo_tkg = gr.Image(label="without TKG")
        gr.Examples(examples=examples, inputs=[prompt])

    # The button and submitting the prompt box both start a generation; the
    # seed slider is written back with the seed actually used. Input order
    # must match the positional parameters of ``on_generate``.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=on_generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            color_set,
            chroma_key_shift,
        ],
        outputs=[result_w_tkg, result_wo_tkg, seed],
    )

if __name__ == "__main__":
    demo.launch()