# Gradio demo app for 1Prompt1Story: consistent multi-frame image generation.
import gc
import json
import os
import random
from datetime import datetime

import diffusers
import gradio as gr
import numpy as np
import torch
from PIL import Image

from main import load_unet_controller
from unet import utils
from unet.unet_controller import UNetController

# Silence diffusers' progress/info chatter; only errors are reported.
diffusers.utils.logging.set_verbosity_error()
# Module-level flag polled by main_gradio() between expensive steps; the
# "Interrupt" button sets it to cooperatively cancel an in-progress run.
interrupt_flag = False
def _placeholder_image():
    """Return a black 512x512 PIL image, yielded when a run is interrupted
    before any frame was produced so the UI still gets an output."""
    return Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))


def _make_save_dir(seed):
    """Create and return ./result/<YYYYMMDDHH>/<MMSS>_gradio_seed<seed>.

    Uses a single datetime.now() snapshot so the hour and minute/second
    components cannot straddle a boundary between two calls.
    """
    now = datetime.now()
    save_dir = os.path.join(
        ".", f'result/{now.strftime("%Y%m%d%H")}/{now.strftime("%M%S")}_gradio_seed{seed}'
    )
    os.makedirs(save_dir, exist_ok=True)
    return save_dir


def main_gradio(model_path, id_prompt, frame_prompt_list, precision, seed, window_length,
                alpha_weaken, beta_weaken, alpha_enhance, beta_enhance, ipca_drop_out,
                use_freeu, use_same_init_noise, device="cuda:1"):
    """Generate a 'story' of identity-consistent images, one frame at a time.

    This is a generator used as a Gradio event handler: after each frame it
    yields the horizontally concatenated strip of all frames so far, letting
    the UI update progressively.  The module-level ``interrupt_flag`` is
    checked between expensive steps to allow cooperative cancellation.

    Parameters mirror the Gradio inputs; ``frame_prompt_list`` arrives as a
    comma-separated string.  ``device`` is new and defaults to the previously
    hard-coded "cuda:1", so existing callers are unaffected.
    """
    global interrupt_flag
    interrupt_flag = False  # reset so a stale click from a previous run is ignored

    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    frame_prompt_list = frame_prompt_list.split(",")

    dtype = torch.float16 if precision == "fp16" else torch.float32
    pipe, _ = utils.load_pipe_from_path(model_path, device, dtype, precision)

    # Last strip yielded; an explicit sentinel replaces the original's fragile
    # "'story_image' not in locals()" probe.
    story_image = None

    if interrupt_flag:
        print("Generation interrupted")
        del pipe
        torch.cuda.empty_cache()
        if story_image is None:
            yield _placeholder_image()
        return

    unet_controller = load_unet_controller(pipe, device)
    unet_controller.Alpha_enhance = alpha_enhance
    unet_controller.Beta_enhance = beta_enhance
    unet_controller.Alpha_weaken = alpha_weaken
    unet_controller.Beta_weaken = beta_weaken
    unet_controller.Ipca_dropout = ipca_drop_out
    unet_controller.Is_freeu_enabled = use_freeu
    unet_controller.Use_same_init_noise = use_same_init_noise

    save_dir = _make_save_dir(seed)
    generate = torch.Generator().manual_seed(seed)

    # Warm-up pass on the identity prompt alone: store its q/k/v tensors so
    # IPCA can reuse them while generating the individual frames.
    if unet_controller.Use_ipca is True:
        unet_controller.Store_qkv = True
        original_prompt_embeds_mode = unet_controller.Prompt_embeds_mode
        unet_controller.Prompt_embeds_mode = "original"
        _ = pipe(id_prompt, generator=generate, unet_controller=unet_controller).images
        unet_controller.Prompt_embeds_mode = original_prompt_embeds_mode
        unet_controller.Store_qkv = False

    max_window_length = utils.get_max_window_length(unet_controller, id_prompt, frame_prompt_list)
    window_length = min(window_length, max_window_length)
    if window_length < len(frame_prompt_list):
        # Too many frame prompts for one window: use circular sliding windows.
        movement_lists = utils.circular_sliding_windows(frame_prompt_list, window_length)
    else:
        movement_lists = list(frame_prompt_list)

    story_image_list = []
    generate = torch.Generator().manual_seed(seed)
    unet_controller.id_prompt = id_prompt

    for index, movement in enumerate(frame_prompt_list):
        if interrupt_flag:
            print("Generation interrupted")
            del pipe
            torch.cuda.empty_cache()
            if story_image is None:
                yield _placeholder_image()
            return

        if unet_controller is not None:
            if window_length < len(frame_prompt_list):
                # Express the first prompt of this window, suppress the rest.
                unet_controller.frame_prompt_suppress = movement_lists[index][1:]
                unet_controller.frame_prompt_express = movement_lists[index][0]
                gen_prompts = [f'{id_prompt} {" ".join(movement_lists[index])}']
            else:
                # Express this frame's prompt, suppress all the others.
                unet_controller.frame_prompt_suppress = movement_lists[:index] + movement_lists[index + 1:]
                unet_controller.frame_prompt_express = movement_lists[index]
                gen_prompts = [f'{id_prompt} {" ".join(movement_lists)}']
            # NOTE: these prints lived outside the guard and would have raised
            # AttributeError in the controller-less branch; moved inside it.
            print(f"suppress: {unet_controller.frame_prompt_suppress}")
            print(f"express: {unet_controller.frame_prompt_express}")
        else:
            gen_prompts = f'{id_prompt} {movement}'

        print(f'id_prompt: {id_prompt}')
        print(f"gen_prompts: {gen_prompts}")

        if unet_controller is not None and unet_controller.Use_same_init_noise is True:
            # Re-seed so every frame starts from the same initial noise.
            generate = torch.Generator().manual_seed(seed)

        images = pipe(gen_prompts, generator=generate, unet_controller=unet_controller).images
        story_image_list.append(images[0])

        # Concatenate all frames so far into one horizontal strip and stream
        # it to the UI.
        story_image = Image.fromarray(
            np.concatenate(story_image_list, axis=1).astype(np.uint8)
        )
        yield story_image

        images[0].save(os.path.join(save_dir, f'{id_prompt} {unet_controller.frame_prompt_express}.jpg'))
        story_image.save(os.path.join(save_dir, 'story_image.jpg'))

    # Release GPU memory once the full story has been generated.
    del pipe
    gc.collect()
    torch.cuda.empty_cache()
def gradio_interface():
    """Build and return the Gradio Blocks UI.

    Wires prompt/model dropdowns (pre-populated from ./resource/example.json),
    the tunable controller hyper-parameters, a reset button, and Generate /
    Interrupt buttons.  Generation streams into a single gr.Image via the
    main_gradio generator; interruption is signalled through the module-level
    ``interrupt_flag``.
    """
    with gr.Blocks() as demo:
        gr.Markdown("### Consistent Image Generation with 1Prompt1Story")

        # Example prompt combinations shipped with the repo seed the dropdowns.
        with open('./resource/example.json', 'r', encoding='utf-8') as f:
            data = json.load(f)
        id_prompts = [item['id_prompt'] for item in data['combinations']]
        frame_prompts = [", ".join(item['frame_prompt_list']) for item in data['combinations']]

        # Prompt and model inputs (custom values allowed in every dropdown).
        id_prompt = gr.Dropdown(
            label="ID Prompt",
            choices=id_prompts,
            value=id_prompts[0],
            allow_custom_value=True,
        )
        frame_prompt_list = gr.Dropdown(
            label="Frame Prompts (comma-separated)",
            choices=frame_prompts,
            value=frame_prompts[0],
            allow_custom_value=True,
        )
        model_path = gr.Dropdown(
            label="Model Path",
            choices=["stabilityai/stable-diffusion-xl-base-1.0", "RunDiffusion/Juggernaut-X-v10", "playgroundai/playground-v2.5-1024px-aesthetic", "SG161222/RealVisXL_V4.0", "RunDiffusion/Juggernaut-XI-v11", "SG161222/RealVisXL_V5.0"],
            value="playgroundai/playground-v2.5-1024px-aesthetic",
            allow_custom_value=True,
        )

        with gr.Row():
            seed = gr.Slider(label="Seed (set -1 for random seed)", minimum=-1, maximum=10000, value=-1, step=1)
            window_length = gr.Slider(label="Window Length", minimum=1, maximum=20, value=10, step=1)

        # Controller hyper-parameters, defaulted from the UNetController class.
        with gr.Row():
            alpha_weaken = gr.Number(label="Alpha Weaken", value=UNetController.Alpha_weaken, interactive=True, step=0.01)
            beta_weaken = gr.Number(label="Beta Weaken", value=UNetController.Beta_weaken, interactive=True, step=0.01)
            alpha_enhance = gr.Number(label="Alpha Enhance", value=UNetController.Alpha_enhance, interactive=True, step=0.001)
            beta_enhance = gr.Number(label="Beta Enhance", value=UNetController.Beta_enhance, interactive=True, step=0.1)
        with gr.Row():
            ipca_drop_out = gr.Number(label="Ipca Dropout", value=UNetController.Ipca_dropout, interactive=True, step=0.1, minimum=0, maximum=1)
            precision = gr.Dropdown(label="Precision", choices=["fp16", "fp32"], value="fp16")
            use_freeu = gr.Dropdown(label="Use FreeU", choices=[False, True], value=UNetController.Is_freeu_enabled)
            use_same_init_noise = gr.Dropdown(label="Use Same Init Noise", choices=[True, False], value=UNetController.Use_same_init_noise)

        reset_button = gr.Button("Reset to Default")

        def reset_values():
            # Restore every tunable to its UNetController class default.
            return UNetController.Alpha_weaken, UNetController.Beta_weaken, UNetController.Alpha_enhance, UNetController.Beta_enhance, UNetController.Ipca_dropout, "fp16", UNetController.Is_freeu_enabled, UNetController.Use_same_init_noise

        reset_button.click(
            fn=reset_values,
            inputs=[],
            outputs=[alpha_weaken, beta_weaken, alpha_enhance, beta_enhance, ipca_drop_out, precision, use_freeu, use_same_init_noise],
        )

        # Streaming output image and the action buttons.
        output_gallery = gr.Image()
        generate_button = gr.Button("Generate Images")
        interrupt_button = gr.Button("Interrupt")

        def interrupt_generation():
            # Signal the running main_gradio generator to stop at its next check.
            global interrupt_flag
            interrupt_flag = True

        interrupt_button.click(
            fn=interrupt_generation,
            inputs=[],
            outputs=[],
        )
        generate_button.click(
            fn=main_gradio,
            inputs=[
                model_path, id_prompt, frame_prompt_list, precision, seed, window_length, alpha_weaken, beta_weaken, alpha_enhance, beta_enhance, ipca_drop_out, use_freeu, use_same_init_noise
            ],
            outputs=output_gallery,
        )

    return demo
# Launch the demo with a public share link when run as a script.
if __name__ == "__main__":
    gradio_interface().launch(share=True)