Spaces: Running on Zero

import spaces  # import first: on ZeroGPU this must come before any CUDA initialization

import random

import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline
import gradio as gr

from tkg import apply_tkg_noise, ColorSet, COLOR_SET_MAP

# Allow TF32 matmuls/convolutions for faster inference on supported GPUs.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

device = "cuda"
model_repo_id = "cagliostrolab/animagine-xl-4.0"  # Replace with the model you would like to use

pipe = StableDiffusionXLPipeline.from_pretrained(
    model_repo_id,
    torch_dtype=torch.bfloat16,
    custom_pipeline="lpw_stable_diffusion_xl",
    add_watermarker=False,
)
pipe = pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048
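
# On a ZeroGPU ("Running on Zero") Space, GPU hardware is attached on demand rather
# than reserved for the whole process: the function that actually runs the model is
# wrapped with the `@spaces.GPU` decorator below and only holds a device for the
# duration of each call.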

@spaces.GPU
def infer(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    tkg_channels: list[int] = [0, 1, 1, 0],
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    generator = torch.Generator(device=device).manual_seed(seed)

    # Initial Gaussian noise in SDXL latent space (4 channels, 1/8 of the image resolution).
    latents = torch.randn(
        (
            1,
            4,  # 4 latent channels
            height // 8,
            width // 8,
        ),
        generator=generator,
        device=device,
        dtype=torch.bfloat16,
    )

    # Apply the TKG-DM noise shift to the selected latent channels.
    latents = apply_tkg_noise(
        latents,
        shift=0.11,
        delta_shift=0.1,
        std_dev=0.5,
        factor=8,
        channels=tkg_channels,
    )

    images = pipe(
        latents=latents,
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images

    # Two images are expected back: one generated with the TKG noise shift and one without.
    w_tkg, wo_tkg = images
    return w_tkg, wo_tkg, seed

def color_name_to_channels(color_name: str) -> list[int]:
    if color_name in COLOR_SET_MAP:
        return COLOR_SET_MAP[color_name].channels
    else:
        raise ValueError(f"Unknown color name: {color_name}")
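
# Illustrative only: COLOR_SET_MAP (defined in the tkg module, not shown here) maps a
# color name to a ColorSet whose `channels` field flags which of the four latent
# channels receive the noise shift, e.g. roughly
#     {"green": ColorSet(channels=[0, 1, 1, 0]), ...}
# which would match the default tkg_channels used by infer(); the exact contents
# live in tkg.py.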

def on_generate(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    color_name: str,
    *args,
    **kwargs,
):
    tkg_channels = color_name_to_channels(color_name)
    # TODO: custom channels
    w_tkg, wo_tkg, seed = infer(
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        *args,
        tkg_channels=tkg_channels,
        **kwargs,
    )
    return w_tkg, wo_tkg, seed


examples = [
    # "1girl, arima kana, oshi no ko, hoshimachi suisei, hoshimachi suisei \(1st costume\), cosplay, looking at viewer, smile, outdoors, night, v, masterpiece, high score, great score, absurdres",
    "1girl, solo, upper body, looking at viewer, straight-on, masterpiece, best quality",
]

with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown(
            """
# TKG Chroma-Key with AnimagineXL 4.0

TKG-DM: Training-free Chroma Key Content Generation Diffusion Model

- arXiv: https://arxiv.org/abs/2411.15580
- GitHub: https://github.com/ryugo417/TKG-DM
"""
        )
        with gr.Row():
            with gr.Column():
                prompt = gr.Text(
                    label="Prompt",
                    max_lines=4,
                    placeholder="Enter your prompt",
                )
                negative_prompt = gr.Textbox(
                    label="Negative prompt",
                    max_lines=4,
                    placeholder="Enter a negative prompt",
                    value="lowres, bad anatomy, bad hands, text, error, missing finger, extra digits, fewer digits, cropped, worst quality, low quality, low score, bad score, average score, signature, watermark, username, blurry",
                )
                color_set = gr.Dropdown(
                    label="Chroma key color",
                    choices=list(COLOR_SET_MAP.keys()),
                    value="green",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        step=1,
                        value=0,
                    )
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    with gr.Row():
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=832,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=1152,
                        )
                    with gr.Row():
                        guidance_scale = gr.Slider(
                            label="Guidance scale",
                            minimum=0.0,
                            maximum=10.0,
                            step=0.1,
                            value=5.0,
                        )
                        num_inference_steps = gr.Slider(
                            label="Number of inference steps",
                            minimum=1,
                            maximum=50,
                            step=1,
                            value=25,
                        )
            with gr.Column():
                run_button = gr.Button("Generate", variant="primary")
                result_w_tkg = gr.Image(label="with TKG")
                result_wo_tkg = gr.Image(label="without TKG")

        gr.Examples(examples=examples, inputs=[prompt])

    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=on_generate,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            color_set,
        ],
        outputs=[result_w_tkg, result_wo_tkg, seed],
    )

if __name__ == "__main__":
    demo.launch()
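
The chroma-key behaviour itself comes from apply_tkg_noise in the tkg module, which is not shown above. As a rough illustration of the idea (not the repository's implementation), the sketch below mirrors the call signature used in infer and applies the core TKG-DM trick: shifting the mean of the initial Gaussian noise in the selected latent channels while leaving a centred region untouched, so the background drifts toward a flat chroma colour while the prompted subject is generated from ordinary noise. The mask shape and the roles of delta_shift and factor are assumptions here; treat every name and constant as illustrative.

import torch

def apply_tkg_noise_sketch(
    latents: torch.Tensor,                     # (1, 4, H/8, W/8) initial Gaussian noise
    shift: float = 0.11,                       # mean shift added to the selected channels
    delta_shift: float = 0.1,                  # accepted only to mirror the app's call; unused here
    std_dev: float = 0.5,                      # width of the centred "foreground" window
    factor: int = 8,                           # accepted only to mirror the app's call; unused here
    channels: tuple[int, ...] = (0, 1, 1, 0),  # which latent channels get shifted
) -> torch.Tensor:
    _, _, h, w = latents.shape

    # Centred 2D Gaussian window: ~1 in the middle (foreground keeps plain noise),
    # falling toward 0 at the borders (background noise gets its mean shifted).
    ys = torch.linspace(-1.0, 1.0, h, device=latents.device)
    xs = torch.linspace(-1.0, 1.0, w, device=latents.device)
    yy, xx = torch.meshgrid(ys, xs, indexing="ij")
    foreground = torch.exp(-(xx ** 2 + yy ** 2) / (2 * std_dev ** 2))
    background = (1.0 - foreground).to(latents.dtype)

    shifted = latents.clone()
    for i, flag in enumerate(channels):
        if flag:
            # Bias this channel's background noise toward the chroma-key colour.
            shifted[:, i] += shift * background
    return shifted

Note that infer above unpacks two images (with and without TKG), so the real helper and/or the custom pipeline must arrange for a pair of images to be generated from one call; this sketch only covers the noise shift itself.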