import spaces  # import first (before torch / diffusers)

import random
from collections.abc import Sequence

import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr

from tkg import apply_tkg_noise, ColorSet, COLOR_SET_MAP

# Allow TF32 kernels for matmul / cuDNN: faster on supporting GPUs at a small precision cost.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = "cuda"
model_repo_id = "cagliostrolab/animagine-xl-4.0"  # Replace with the model you would like to use

pipe = StableDiffusionXLPipeline.from_pretrained(
    model_repo_id,  # previously hard-coded, so editing model_repo_id had no effect
    torch_dtype=torch.bfloat16,
    custom_pipeline="lpw_stable_diffusion_xl",
    add_watermarker=False,
)
pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

# Channel pattern used by `infer` when no chroma-key colour is specified.
# A tuple, so the shared default can never be mutated between calls.
DEFAULT_TKG_CHANNELS = (0, 1, 1, 0)


@spaces.GPU
def infer(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    tkg_channels: Sequence[int] = DEFAULT_TKG_CHANNELS,
    progress=gr.Progress(track_tqdm=True),
):
    """Run the SDXL pipeline on TKG-shifted initial noise.

    Args:
        prompt: Positive text prompt.
        negative_prompt: Negative text prompt.
        seed: RNG seed; ignored (replaced by a random one) when
            ``randomize_seed`` is true.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]``.
        width, height: Output image size in pixels (used as ``// 8`` for
            the latent grid).
        guidance_scale: Classifier-free guidance scale.
        num_inference_steps: Number of denoising steps.
        tkg_channels: Per-latent-channel pattern forwarded to
            ``apply_tkg_noise`` as ``channels``.
        progress: Gradio progress tracker (tqdm tracking).

    Returns:
        ``(image_with_tkg, image_without_tkg, seed_used)``.
    """
    # Slider values may arrive as floats; the generator seed and the latent
    # shape below need ints.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    generator = torch.Generator(device=device).manual_seed(seed)

    latents = torch.randn(
        (
            1,
            4,  # 4 channels
            height // 8,
            width // 8,
        ),
        generator=generator,
        device=device,
        dtype=torch.bfloat16,
    )
    latents = apply_tkg_noise(
        latents,
        shift=0.11,
        delta_shift=0.1,
        std_dev=0.5,
        factor=8,
        # Fresh list each call so neither the default nor a shared
        # COLOR_SET_MAP preset can be mutated by the helper.
        channels=list(tkg_channels),
    )

    images = pipe(
        latents=latents,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images

    # NOTE(review): this unpack assumes apply_tkg_noise returns a batch of two
    # latents (with TKG first, without TKG second) so the pipeline yields two
    # images -- confirm against tkg.py.
    w_tkg, wo_tkg = images
    return w_tkg, wo_tkg, seed


def color_name_to_channels(color_name: str) -> list[int]:
    """Return the TKG channel pattern registered for ``color_name``.

    Raises:
        ValueError: If ``color_name`` is not a key of ``COLOR_SET_MAP``.
    """
    try:
        color_set = COLOR_SET_MAP[color_name]
    except KeyError:
        raise ValueError(f"Unknown color name: {color_name}") from None
    return color_set.channels


def on_generate(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    color_name: str,
    # Gradio only injects/tracks progress for the event function itself, so it
    # must be declared here (positional, before *args, so Gradio's insertion
    # index lines up) and forwarded to `infer`.
    progress=gr.Progress(track_tqdm=True),
    *args,
    **kwargs,
):
    """Gradio event handler: map the colour dropdown to channels and run `infer`.

    Returns:
        ``(image_with_tkg, image_without_tkg, seed_used)`` for the two image
        outputs and the seed slider.

    Raises:
        ValueError: Unknown ``color_name``.
        TypeError: Extra positional arguments were supplied (`infer` has no
            positional slot for them after ``num_inference_steps``).
    """
    if args:
        # Previously these were forwarded positionally and collided with the
        # `tkg_channels` keyword, failing with an obscure "multiple values" error.
        raise TypeError(
            f"on_generate got {len(args)} unexpected extra positional argument(s)"
        )

    tkg_channels = color_name_to_channels(color_name)  # TODO: custom channels

    w_tkg, wo_tkg, seed = infer(
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        tkg_channels=tkg_channels,
        progress=progress,
        **kwargs,
    )
    return w_tkg, wo_tkg, seed


examples = [
    # "1girl, arima kana, oshi no ko, hoshimachi suisei, hoshimachi suisei \(1st costume\), cosplay, looking at viewer, smile, outdoors, night, v, masterpiece, high score, great score, absurdres",
    "1girl, solo, upper body, looking at viewer, straight-on, masterpiece, best quality",
]

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(
            """
            # TKG Chroma-Key with AnimagineXL 4.0

            TKG-DM🥚🍚: Training-free Chroma Key Content Generation Diffusion Model

            - arXiv: https://arxiv.org/abs/2411.15580
            - GitHub: https://github.com/ryugo417/TKG-DM
            """
        )

        with gr.Row():
            # Left column: inputs.
            with gr.Column():
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=4,
                    placeholder="Enter your prompt",
                )
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    max_lines=4,
                    placeholder="Enter a negative prompt",
                    value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                )
                color_set = gr.Dropdown(
                    label="Chroma key color",
                    choices=list(COLOR_SET_MAP.keys()),
                    value="green",
                )

                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

                    with gr.Row():
                        # Step of 32 keeps sizes divisible by 8 for the latent grid.
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=832,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1152,
                        )

                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                        )

            # Right column: trigger and outputs.
            with gr.Column():
                run_button = gr.Button("Generate", variant="primary")
                result_w_tkg = gr.Image(label="with TKG")
                result_wo_tkg = gr.Image(label="without TKG")

        gr.Examples(examples=examples, inputs=[prompt])

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=on_generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            color_set,
        ],
        # The seed slider is updated so a randomized seed can be reproduced.
        outputs=[result_w_tkg, result_wo_tkg, seed],
    )

if __name__ == "__main__":
    demo.launch()