import os
import gradio as gr
import huggingface_hub
import pillow_avif  # imported for its side effect: registers AVIF image support in Pillow
import spaces
import torch
import gc
from huggingface_hub import snapshot_download
from pillow_heif import register_heif_opener
from pipelines.pipeline_infu_flux import InfUFluxPipeline
# Register HEIF support for Pillow
register_heif_opener()
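
# InfiniteYou checkpoint variants: stage-1 ("sim_stage1") and stage-2 ("aes_stage2", the default)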
class ModelVersion:
    STAGE_1 = "sim_stage1"
    STAGE_2 = "aes_stage2"
    DEFAULT_VERSION = STAGE_2
ENABLE_ANTI_BLUR_DEFAULT = False
ENABLE_REALISM_DEFAULT = False
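
# Cache of the currently loaded pipeline and the settings it was built with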
loaded_pipeline_config = {
"model_version": "aes_stage2",
"enable_realism": False,
"enable_anti_blur": False,
'pipeline': None
}
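
# Download model weights from the Hugging Face Hub into ./models/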
def download_models():
    snapshot_download(repo_id='ByteDance/InfiniteYou', local_dir='./models/InfiniteYou', local_dir_use_symlinks=False)
    try:
        snapshot_download(repo_id='black-forest-labs/FLUX.1-dev', local_dir='./models/FLUX.1-dev', local_dir_use_symlinks=False)
    except Exception as e:
        print(e)
        print('\nDownloading `black-forest-labs/FLUX.1-dev` to `./models/FLUX.1-dev` failed. '
              'Please accept the license agreement and obtain access at https://huggingface.co/black-forest-labs/FLUX.1-dev. '
              'Then authenticate with `huggingface-cli login` using an access token from https://huggingface.co/settings/tokens, '
              'and run the code again.')
        print('\nAlternatively, download it manually from Hugging Face and put it in `./models/FLUX.1-dev`, '
              'or modify `base_model_path` in `app.py` to point to the correct path.')
        exit()
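
# Build a fresh InfUFluxPipeline for the requested model version and apply the selected LoRAs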
def init_pipeline(model_version, enable_realism, enable_anti_blur):
loaded_pipeline_config["enable_realism"] = enable_realism
loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
loaded_pipeline_config["model_version"] = model_version
pipeline = loaded_pipeline_config['pipeline']
gc.collect()
torch.cuda.empty_cache()
model_path = f'./models/InfiniteYou/infu_flux_v1.0/{model_version}'
print(f'loading model from {model_path}')
pipeline = InfUFluxPipeline(
base_model_path='./models/FLUX.1-dev',
infu_model_path=model_path,
insightface_root_path='./models/InfiniteYou/supports/insightface',
image_proj_num_tokens=8,
infu_flux_version='v1.0',
model_version=model_version,
)
loaded_pipeline_config['pipeline'] = pipeline
pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
loras = []
if enable_realism: loras.append(['realism', 1.0])
if enable_anti_blur: loras.append(['anti_blur', 1.0])
pipeline.load_loras_state_dict(loras)
return pipeline
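
# Return the cached pipeline when its configuration matches; otherwise switch the
# model version and/or LoRAs on the already-loaded pipeline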
def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
    # Reuse the cached pipeline if the requested configuration matches
    if (
        loaded_pipeline_config['pipeline'] is not None
        and loaded_pipeline_config["enable_realism"] == enable_realism
        and loaded_pipeline_config["enable_anti_blur"] == enable_anti_blur
        and model_version == loaded_pipeline_config["model_version"]
    ):
        return loaded_pipeline_config['pipeline']

    loaded_pipeline_config["enable_realism"] = enable_realism
    loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
    loaded_pipeline_config["model_version"] = model_version

    pipeline = loaded_pipeline_config['pipeline']
    if pipeline is None:
        # Nothing loaded yet; init_pipeline also applies the requested LoRAs
        return init_pipeline(model_version, enable_realism, enable_anti_blur)

    if pipeline.model_version != model_version:
        # Move the unused InfuseNet / image projector to CPU and the requested one to GPU
        print(f'Switching model to {model_version}')
        pipeline.model_version = model_version
        if model_version == 'aes_stage2':
            pipeline.infusenet_sim.cpu()
            pipeline.image_proj_model_sim.cpu()
            torch.cuda.empty_cache()
            pipeline.infusenet_aes.to('cuda')
            pipeline.pipe.controlnet = pipeline.infusenet_aes
            pipeline.image_proj_model_aes.to('cuda')
            pipeline.image_proj_model = pipeline.image_proj_model_aes
        else:
            pipeline.infusenet_aes.cpu()
            pipeline.image_proj_model_aes.cpu()
            torch.cuda.empty_cache()
            pipeline.infusenet_sim.to('cuda')
            pipeline.pipe.controlnet = pipeline.infusenet_sim
            pipeline.image_proj_model_sim.to('cuda')
            pipeline.image_proj_model = pipeline.image_proj_model_sim

    loaded_pipeline_config['pipeline'] = pipeline

    # Reset optional LoRAs and load only the requested ones
    pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
    loras = []
    if enable_realism:
        loras.append(['realism', 1.0])
    if enable_anti_blur:
        loras.append(['anti_blur', 1.0])
    pipeline.load_loras_state_dict(loras)

    return pipeline
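
# Run identity-preserving generation on the GPU (ZeroGPU allocation capped at 120 s)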
@spaces.GPU(duration=120)
def generate_image(
    input_image,
    control_image,
    prompt,
    seed,
    width,
    height,
    guidance_scale,
    num_steps,
    infusenet_conditioning_scale,
    infusenet_guidance_start,
    infusenet_guidance_end,
    enable_realism,
    enable_anti_blur,
    model_version,
):
    try:
        pipeline = prepare_pipeline(model_version=model_version, enable_realism=enable_realism, enable_anti_blur=enable_anti_blur)

        if seed == 0:
            seed = torch.seed() & 0xFFFFFFFF

        image = pipeline(
            id_image=input_image,
            prompt=prompt,
            control_image=control_image,
            seed=seed,
            width=width,
            height=height,
            guidance_scale=guidance_scale,
            num_steps=num_steps,
            infusenet_conditioning_scale=infusenet_conditioning_scale,
            infusenet_guidance_start=infusenet_guidance_start,
            infusenet_guidance_end=infusenet_guidance_end,
        )
    except Exception as e:
        print(e)
        # gr.Error must be raised (not just constructed) for Gradio to display it
        raise gr.Error(f"An error occurred: {e}") from e

    return gr.update(value=image, label=f"Generated Image, seed = {seed}")
# This function is unused, so it can be removed:
# def generate_examples(id_image, control_image, prompt_text, seed, enable_realism, enable_anti_blur, model_version):
#     return generate_image(id_image, control_image, prompt_text, seed, 864, 1152, 3.5, 30, 1.0, 0.0, 1.0, enable_realism, enable_anti_blur, model_version)
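
# Gradio UI: identity / optional control image inputs, prompt, advanced sampling options, and the generated image output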
with gr.Blocks() as demo:
    session_state = gr.State({})
    default_model_version = "v1.0"

    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                ui_id_image = gr.Image(label="Identity Image", type="pil", scale=3, height=370, min_width=100)
                with gr.Column(scale=2, min_width=100):
                    ui_control_image = gr.Image(label="Control Image [Optional]", type="pil", height=370, min_width=100)

            ui_prompt_text = gr.Textbox(label="Prompt", value="Portrait, 4K, high quality, cinematic")
            ui_model_version = gr.Dropdown(
                label="Model Version",
                choices=[ModelVersion.STAGE_1, ModelVersion.STAGE_2],
                value=ModelVersion.DEFAULT_VERSION,
            )
            ui_btn_generate = gr.Button("Generate")

            with gr.Accordion("Advanced", open=False):
                with gr.Row():
                    ui_num_steps = gr.Number(label="num steps", value=30)
                    ui_seed = gr.Number(label="seed (0 for random)", value=0)
                with gr.Row():
                    ui_width = gr.Number(label="width", value=864)
                    ui_height = gr.Number(label="height", value=1152)
                ui_guidance_scale = gr.Number(label="guidance scale", value=3.5, step=0.5)
                ui_infusenet_conditioning_scale = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="infusenet conditioning scale")
                with gr.Row():
                    ui_infusenet_guidance_start = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.05, label="infusenet guidance start")
                    ui_infusenet_guidance_end = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="infusenet guidance end")

            with gr.Accordion("LoRAs [Optional]", open=True):
                with gr.Row():
                    ui_enable_realism = gr.Checkbox(label="Enable realism LoRA", value=ENABLE_REALISM_DEFAULT)
                    ui_enable_anti_blur = gr.Checkbox(label="Enable anti-blur LoRA", value=ENABLE_ANTI_BLUR_DEFAULT)

        with gr.Column(scale=2):
            image_output = gr.Image(label="Generated Image", interactive=False, height=550, format='png')
    # The gr.Examples section has been removed entirely
    ui_btn_generate.click(
        generate_image,
        inputs=[
            ui_id_image,
            ui_control_image,
            ui_prompt_text,
            ui_seed,
            ui_width,
            ui_height,
            ui_guidance_scale,
            ui_num_steps,
            ui_infusenet_conditioning_scale,
            ui_infusenet_guidance_start,
            ui_infusenet_guidance_end,
            ui_enable_realism,
            ui_enable_anti_blur,
            ui_model_version,
        ],
        outputs=[image_output],
        concurrency_id="gpu",
    )
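
# Start-up: authenticate to the Hub (required for the gated FLUX.1-dev repo), download the
# model weights, and pre-load the default pipeline before launching the demo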
huggingface_hub.login(os.getenv('PRIVATE_HF_TOKEN'))
download_models()
init_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT)
# demo.queue()
demo.launch()
# demo.launch(server_name='0.0.0.0') # IPv4
# demo.launch(server_name='[::]') # IPv6