# Spaces: Running on Zero
import spaces # import first | |
import random | |
import numpy as np | |
import torch | |
from diffusers import StableDiffusionXLPipeline | |
import gradio as gr | |
from tkg import apply_tkg_noise, ColorSet, COLOR_SET_MAP | |
# Permit TF32 in CUDA matmul and cuDNN kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = "cuda"
model_repo_id = "cagliostrolab/animagine-xl-4.0"  # Replace to the model you would like to use

# Load the pipeline from `model_repo_id` so that editing the variable above
# is enough to switch models (the repo id used to be duplicated here).
pipe = StableDiffusionXLPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch.bfloat16,
    custom_pipeline="lpw_stable_diffusion_xl",
    add_watermarker=False,
)
pipe = pipe.to(device)

# Upper bound for the seed slider / random seed, and for the image size sliders.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
def infer(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    tkg_channels: "list[int] | None" = None,
    chroma_key_shift: float = 0.11,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with TKG-shifted initial noise and one without.

    Both images start from the same seeded Gaussian latent; the first is
    passed through ``apply_tkg_noise`` before sampling, the second is used
    unchanged, so the pair can be compared side by side.

    Args:
        prompt: Positive prompt forwarded to the pipeline.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed for the torch generator; ignored when ``randomize_seed``.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of sampling steps.
        tkg_channels: Per-latent-channel selector handed to
            ``apply_tkg_noise``; defaults to ``[0, 1, 1, 0]``.
        chroma_key_shift: ``shift`` value handed to ``apply_tkg_noise``.
        progress: Gradio progress tracker (tqdm tracking enabled); it is not
            referenced in the body.

    Returns:
        A tuple ``(image_with_tkg, image_without_tkg, seed)`` where ``seed``
        is the seed that was actually used.
    """
    # A fresh list per call: a list literal in the signature would be shared
    # between calls.
    if tkg_channels is None:
        tkg_channels = [0, 1, 1, 0]

    # Values coming from UI sliders may be floats; the generator seed and the
    # tensor shape below both require plain ints.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator(device=device).manual_seed(seed)

    # One 4-channel latent whose spatial size is the pixel size divided by 8.
    latent_shape = (1, 4, height // 8, width // 8)
    latents = torch.randn(
        latent_shape,
        generator=generator,
        device=device,
        dtype=torch.bfloat16,
    )

    tkg_latents = apply_tkg_noise(
        latents,
        shift=chroma_key_shift,
        delta_shift=0.1,
        std_dev=0.5,
        factor=8,
        channels=tkg_channels,
    ).to(torch.bfloat16)

    # Batch of two: index 0 is the TKG-shifted noise, index 1 the plain noise.
    latents = torch.cat([tkg_latents, latents], dim=0)

    images = pipe(
        latents=latents,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        num_images_per_prompt=2,
        generator=generator,
    ).images

    w_tkg, wo_tkg = images
    return w_tkg, wo_tkg, seed
def color_name_to_channels(color_name: str) -> list[int]:
    """Look up the channel list registered for ``color_name``.

    Args:
        color_name: A key of ``COLOR_SET_MAP``.

    Returns:
        The ``channels`` attribute of the matching ``COLOR_SET_MAP`` entry.

    Raises:
        ValueError: If ``color_name`` is not a key of ``COLOR_SET_MAP``.
    """
    # Reject unknown names up front so the lookup below cannot miss.
    if color_name not in COLOR_SET_MAP:
        raise ValueError(f"Unknown color name: {color_name}")

    color_set = COLOR_SET_MAP[color_name]
    return color_set.channels
def on_generate(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    color_name: str,
    chroma_key_shift: float,
    *args,
    **kwargs
):
    """UI callback: resolve the background color and run ``infer``.

    Args:
        prompt: Positive prompt.
        negative_prompt: Negative prompt.
        seed: Requested seed.
        randomize_seed: Whether ``infer`` should draw a random seed instead.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Guidance scale.
        num_inference_steps: Number of sampling steps.
        color_name: Background color name, resolved to channels through
            ``color_name_to_channels``.
        chroma_key_shift: Shift value forwarded to ``infer``.
        *args: Extra positional arguments, forwarded to ``infer`` after
            ``chroma_key_shift``.
        **kwargs: Extra keyword arguments, forwarded to ``infer``.

    Returns:
        The ``(image_with_tkg, image_without_tkg, seed)`` tuple from ``infer``.

    Raises:
        ValueError: If ``color_name`` is unknown (raised by
            ``color_name_to_channels``).
    """
    tkg_channels = color_name_to_channels(color_name)
    # TODO: custom channels

    # `tkg_channels` and `chroma_key_shift` are passed positionally on
    # purpose: unpacking `*args` after keyword forms of them would bind the
    # first extra positional to `tkg_channels` and fail with
    # "got multiple values for argument".
    w_tkg, wo_tkg, seed = infer(
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        tkg_channels,
        chroma_key_shift,
        *args,
        **kwargs,
    )
    return w_tkg, wo_tkg, seed
# Prompts offered as one-click examples in the UI.
examples = [
    # Disabled example, kept for reference:
    # "1girl, arima kana, oshi no ko, hoshimachi suisei, hoshimachi suisei \(1st costume\), cosplay, looking at viewer, smile, outdoors, night, v, masterpiece, high score, great score, absurdres",
    (
        "1girl, solo, school uniform, cat ears, full body, looking at viewer, "
        "straight-on, chibi, simple background, best quality"
    ),
    (
        "1girl, solo, hand up, waving, long hair, sideways glance, "
        "upper body, cropped torso, simple background, best quality"
    ),
]
# Build the Gradio UI: settings on the left, the generate button and the two
# result images (with / without TKG) on the right, example prompts below.
with gr.Blocks() as demo:
    with gr.Column():
        # Header text. The emoji after "TKG-DM" were mis-decoded in the
        # previous revision and are restored here.
        gr.Markdown(
            "# TKG Chroma-Key with AnimagineXL 4.0\n"
            "TKG-DM🥚🍚: Training-free Chroma Key Content Generation Diffusion Model\n"
            "- arXiv: https://arxiv.org/abs/2411.15580\n"
            "- GitHub: https://github.com/ryugo417/TKG-DM\n"
        )

        with gr.Row():
            with gr.Column():
                prompt = gr.Textbox(
                    label="Prompt",
                    max_lines=4,
                    placeholder="Enter your prompt",
                )
                # Choices are the keys of COLOR_SET_MAP; on_generate maps the
                # selected name back to channels.
                color_set = gr.Dropdown(
                    label="Background color",
                    choices=list(COLOR_SET_MAP.keys()),
                    value="green",
                )

                with gr.Accordion("TKG Settings", open=False):
                    chroma_key_shift = gr.Slider(
                        label="Latent mean shift for chroma key",
                        minimum=0.0,
                        maximum=0.2,
                        step=0.005,
                        value=0.11,
                    )

                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt = gr.Textbox(
                        label="Negative prompt",
                        max_lines=4,
                        placeholder="Enter a negative prompt",
                        value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                    )
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=832,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1152,
                        )

                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                        )

            with gr.Column():
                run_button = gr.Button("Generate", variant="primary")
                with gr.Row():
                    result_w_tkg = gr.Image(label="with TKG")
                    result_wo_tkg = gr.Image(label="without TKG")

        gr.Examples(examples=examples, inputs=[prompt])

    # Run generation on button click or when the prompt is submitted. The
    # input order must match on_generate's positional parameters; the seed
    # actually used is written back to the seed slider.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=on_generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            color_set,
            chroma_key_shift,
        ],
        outputs=[result_w_tkg, result_wo_tkg, seed],
    )

if __name__ == "__main__":
    demo.launch()