# Hugging Face Space page status captured along with this file (not code):
# Spaces: Sleeping — last commit not found.
import os | |
import gradio as gr | |
import huggingface_hub | |
import pillow_avif | |
import spaces | |
import torch | |
import gc | |
from huggingface_hub import snapshot_download | |
from pillow_heif import register_heif_opener | |
from pipelines.pipeline_infu_flux import InfUFluxPipeline | |
# Register HEIF/HEIC support for Pillow so uploaded images in that format can be opened.
# NOTE(review): `pillow_avif` is imported above but never referenced; presumably it is
# imported only for its AVIF-registration side effect — confirm before removing it.
register_heif_opener()
class ModelVersion:
    """Names of the selectable InfiniteYou checkpoints.

    Each value doubles as the checkpoint sub-directory name under
    ``./models/InfiniteYou/infu_flux_v1.0/`` and as a Model Version
    dropdown choice in the UI.
    """

    # Stage 1 checkpoint; served by the pipeline's ``*_sim`` modules.
    STAGE_1 = "sim_stage1"
    # Stage 2 checkpoint; served by the pipeline's ``*_aes`` modules.
    STAGE_2 = "aes_stage2"

    # Version preselected in the UI and loaded at startup.
    DEFAULT_VERSION = STAGE_2
# Default states of the LoRA checkboxes, also used for the initial pipeline load.
ENABLE_ANTI_BLUR_DEFAULT = False
ENABLE_REALISM_DEFAULT = False

# Single shared pipeline plus the settings it is currently configured with.
# `prepare_pipeline` compares against these entries to skip redundant
# model/LoRA reconfiguration. The initial values are derived from the
# constants above so the cache can never disagree with the UI defaults.
loaded_pipeline_config = {
    "model_version": ModelVersion.DEFAULT_VERSION,
    "enable_realism": ENABLE_REALISM_DEFAULT,
    "enable_anti_blur": ENABLE_ANTI_BLUR_DEFAULT,
    "pipeline": None,
}
def download_models():
    """Download the InfiniteYou and FLUX.1-dev weights into ``./models``.

    The InfiniteYou snapshot is fetched first; any error there propagates to
    the caller. FLUX.1-dev requires accepting its license agreement, so a
    failure to fetch it prints access instructions and terminates the process
    with a non-zero exit status.

    Raises:
        SystemExit: with status 1 when the FLUX.1-dev download fails.
    """
    snapshot_download(repo_id='ByteDance/InfiniteYou', local_dir='./models/InfiniteYou', local_dir_use_symlinks=False)
    try:
        snapshot_download(repo_id='black-forest-labs/FLUX.1-dev', local_dir='./models/FLUX.1-dev', local_dir_use_symlinks=False)
    except Exception as e:
        print(e)
        print('\nYou are downloading `black-forest-labs/FLUX.1-dev` to `./models/FLUX.1-dev` but failed. '
              'Please accept the agreement and obtain access at https://huggingface.co/black-forest-labs/FLUX.1-dev. '
              'Then, use `huggingface-cli login` and your access tokens at https://huggingface.co/settings/tokens to authenticate. '
              'After that, run the code again.')
        # The base model is loaded from ./models/FLUX.1-dev (see init_pipeline),
        # so that is where a manual download must go.
        print('\nYou can also download it manually from HuggingFace and put it in `./models/FLUX.1-dev`, '
              'or you can modify `base_model_path` in `app.py` to specify the correct path.')
        # `exit()` is the interactive `site` helper (absent under `python -S`)
        # and reports success (status 0); signal the failure explicitly.
        raise SystemExit(1) from e
def init_pipeline(model_version, enable_realism, enable_anti_blur):
    """Build a fresh InfUFluxPipeline for ``model_version`` and apply LoRAs.

    Any previously cached pipeline is released first so its memory can be
    reclaimed before the new weights are loaded. The new pipeline and the
    settings it was built with are stored in ``loaded_pipeline_config`` only
    after construction and LoRA setup both succeed.

    Args:
        model_version: checkpoint name (``ModelVersion.STAGE_1`` or
            ``ModelVersion.STAGE_2``); selects the sub-directory under
            ``./models/InfiniteYou/infu_flux_v1.0/``.
        enable_realism: load the ``realism`` LoRA at weight 1.0.
        enable_anti_blur: load the ``anti_blur`` LoRA at weight 1.0.

    Returns:
        The newly created pipeline.
    """
    # Drop the cache's reference to the old pipeline *before* collecting;
    # while anything still points at it, gc/empty_cache cannot free it.
    loaded_pipeline_config['pipeline'] = None
    gc.collect()
    torch.cuda.empty_cache()

    model_path = f'./models/InfiniteYou/infu_flux_v1.0/{model_version}'
    print(f'loading model from {model_path}')
    pipeline = InfUFluxPipeline(
        base_model_path='./models/FLUX.1-dev',
        infu_model_path=model_path,
        insightface_root_path='./models/InfiniteYou/supports/insightface',
        image_proj_num_tokens=8,
        infu_flux_version='v1.0',
        model_version=model_version,
    )

    # Clear both known adapters, then load only the enabled ones.
    pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
    loras = []
    if enable_realism:
        loras.append(['realism', 1.0])
    if enable_anti_blur:
        loras.append(['anti_blur', 1.0])
    pipeline.load_loras_state_dict(loras)

    # Commit only once fully set up, so a failed load never leaves the cache
    # claiming a configuration it does not actually have.
    loaded_pipeline_config['pipeline'] = pipeline
    loaded_pipeline_config["enable_realism"] = enable_realism
    loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
    loaded_pipeline_config["model_version"] = model_version
    return pipeline
def prepare_pipeline(model_version, enable_realism, enable_anti_blur):
    """Return the shared pipeline configured for the requested settings.

    If the cached pipeline already matches, it is returned unchanged. If none
    is cached, a new one is built with ``init_pipeline``. Otherwise the
    InfuseNet and image-projection modules of the requested stage are moved to
    the GPU (the other stage's modules are moved to the CPU first) and the
    LoRAs are re-applied.

    Args:
        model_version: ``ModelVersion.STAGE_1`` or ``ModelVersion.STAGE_2``.
        enable_realism: whether the ``realism`` LoRA should be active.
        enable_anti_blur: whether the ``anti_blur`` LoRA should be active.

    Returns:
        The ready-to-use pipeline.
    """
    pipeline = loaded_pipeline_config['pipeline']
    if (
        pipeline is not None
        and loaded_pipeline_config["enable_realism"] == enable_realism
        and loaded_pipeline_config["enable_anti_blur"] == enable_anti_blur
        and loaded_pipeline_config["model_version"] == model_version
    ):
        return pipeline

    if pipeline is None:
        # Nothing cached (not initialised yet, or an earlier load failed):
        # there are no modules to swap, so build a pipeline from scratch.
        return init_pipeline(
            model_version=model_version,
            enable_realism=enable_realism,
            enable_anti_blur=enable_anti_blur,
        )

    if pipeline.model_version != model_version:
        print(f'Switching model to {model_version}')
        if model_version == ModelVersion.STAGE_2:
            # Move the stage-1 modules off the GPU before loading stage 2.
            pipeline.infusenet_sim.cpu()
            pipeline.image_proj_model_sim.cpu()
            torch.cuda.empty_cache()
            pipeline.infusenet_aes.to('cuda')
            pipeline.pipe.controlnet = pipeline.infusenet_aes
            pipeline.image_proj_model_aes.to('cuda')
            pipeline.image_proj_model = pipeline.image_proj_model_aes
        else:
            # Move the stage-2 modules off the GPU before loading stage 1.
            pipeline.infusenet_aes.cpu()
            pipeline.image_proj_model_aes.cpu()
            torch.cuda.empty_cache()
            pipeline.infusenet_sim.to('cuda')
            pipeline.pipe.controlnet = pipeline.infusenet_sim
            pipeline.image_proj_model_sim.to('cuda')
            pipeline.image_proj_model = pipeline.image_proj_model_sim
        # Mark the switch only after the modules were actually moved.
        pipeline.model_version = model_version

    # Clear both known adapters, then load only the enabled ones.
    pipeline.pipe.delete_adapters(['realism', 'anti_blur'])
    loras = []
    if enable_realism:
        loras.append(['realism', 1.0])
    if enable_anti_blur:
        loras.append(['anti_blur', 1.0])
    pipeline.load_loras_state_dict(loras)

    # Commit the configuration last: if any step above raised, the next call
    # will not take the fast path with a half-configured pipeline.
    loaded_pipeline_config["model_version"] = model_version
    loaded_pipeline_config["enable_realism"] = enable_realism
    loaded_pipeline_config["enable_anti_blur"] = enable_anti_blur
    return pipeline
def generate_image(
    input_image,
    control_image,
    prompt,
    seed,
    width,
    height,
    guidance_scale,
    num_steps,
    infusenet_conditioning_scale,
    infusenet_guidance_start,
    infusenet_guidance_end,
    enable_realism,
    enable_anti_blur,
    model_version
):
    """Gradio click handler: generate an image with the current UI settings.

    A ``seed`` of 0 is replaced with a random 32-bit seed; the seed actually
    used is shown in the output image's label.

    Returns:
        A ``gr.update`` setting the output image and its label.

    Raises:
        gr.Error: on any failure, so Gradio displays the message to the user.
    """
    try:
        pipeline = prepare_pipeline(model_version=model_version, enable_realism=enable_realism, enable_anti_blur=enable_anti_blur)
        # gr.Number can deliver floats; the seed is bit-masked below, and the
        # size/step values are coerced too (presumably the pipeline expects
        # ints for them — confirm against InfUFluxPipeline).
        seed = int(seed)
        if seed == 0:
            seed = torch.seed() & 0xFFFFFFFF
        image = pipeline(
            id_image=input_image,
            prompt=prompt,
            control_image=control_image,
            seed=seed,
            width=int(width),
            height=int(height),
            guidance_scale=guidance_scale,
            num_steps=int(num_steps),
            infusenet_conditioning_scale=infusenet_conditioning_scale,
            infusenet_guidance_start=infusenet_guidance_start,
            infusenet_guidance_end=infusenet_guidance_end,
        )
    except Exception as e:
        print(e)
        # gr.Error only reaches the user when raised; merely constructing it
        # (as before) silently discarded the message.
        raise gr.Error(f"An error occurred: {e}") from e
    return gr.update(value=image, label=f"Generated Image, seed = {seed}")
# This function is unused, so it can be removed.
# def generate_examples(id_image, control_image, prompt_text, seed, enable_realism, enable_anti_blur, model_version): | |
# return generate_image(id_image, control_image, prompt_text, seed, 864, 1152, 3.5, 30, 1.0, 0.0, 1.0, enable_realism, enable_anti_blur, model_version) | |
# Build the Gradio UI. Left column: inputs and settings; right column: result.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=3):
            with gr.Row():
                ui_id_image = gr.Image(label="Identity Image", type="pil", scale=3, height=370, min_width=100)
                with gr.Column(scale=2, min_width=100):
                    ui_control_image = gr.Image(label="Control Image [Optional]", type="pil", height=370, min_width=100)
            ui_prompt_text = gr.Textbox(label="Prompt", value="Portrait, 4K, high quality, cinematic")
            ui_model_version = gr.Dropdown(
                label="Model Version",
                choices=[ModelVersion.STAGE_1, ModelVersion.STAGE_2],
                value=ModelVersion.DEFAULT_VERSION,
            )
            ui_btn_generate = gr.Button("Generate")
            with gr.Accordion("Advanced", open=False):
                with gr.Row():
                    # precision=0 makes Gradio round and deliver these
                    # integer-valued settings as ints rather than floats.
                    ui_num_steps = gr.Number(label="num steps", value=30, precision=0)
                    ui_seed = gr.Number(label="seed (0 for random)", value=0, precision=0)
                with gr.Row():
                    ui_width = gr.Number(label="width", value=864, precision=0)
                    ui_height = gr.Number(label="height", value=1152, precision=0)
                ui_guidance_scale = gr.Number(label="guidance scale", value=3.5, step=0.5)
                ui_infusenet_conditioning_scale = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="infusenet conditioning scale")
                with gr.Row():
                    ui_infusenet_guidance_start = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.05, label="infusenet guidance start")
                    ui_infusenet_guidance_end = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, label="infusenet guidance end")
            with gr.Accordion("LoRAs [Optional]", open=True):
                with gr.Row():
                    ui_enable_realism = gr.Checkbox(label="Enable realism LoRA", value=ENABLE_REALISM_DEFAULT)
                    ui_enable_anti_blur = gr.Checkbox(label="Enable anti-blur LoRA", value=ENABLE_ANTI_BLUR_DEFAULT)
        with gr.Column(scale=2):
            image_output = gr.Image(label="Generated Image", interactive=False, height=550, format='png')

    # The gr.Examples section has been removed entirely.

    # Input order must match generate_image's positional parameters.
    ui_btn_generate.click(
        generate_image,
        inputs=[
            ui_id_image,
            ui_control_image,
            ui_prompt_text,
            ui_seed,
            ui_width,
            ui_height,
            ui_guidance_scale,
            ui_num_steps,
            ui_infusenet_conditioning_scale,
            ui_infusenet_guidance_start,
            ui_infusenet_guidance_end,
            ui_enable_realism,
            ui_enable_anti_blur,
            ui_model_version,
        ],
        outputs=[image_output],
        concurrency_id="gpu",
    )
# Startup: authenticate, fetch weights, load the default pipeline, serve the UI.
# Log in only when a token is configured: `login(None)` falls back to an
# interactive prompt, which cannot be answered in a headless Space.
_hf_token = os.getenv('PRIVATE_HF_TOKEN')
if _hf_token:
    huggingface_hub.login(_hf_token)
download_models()
init_pipeline(model_version=ModelVersion.DEFAULT_VERSION, enable_realism=ENABLE_REALISM_DEFAULT, enable_anti_blur=ENABLE_ANTI_BLUR_DEFAULT)
# demo.queue()
demo.launch()
# demo.launch(server_name='0.0.0.0')  # IPv4
# demo.launch(server_name='[::]')  # IPv6