File size: 3,294 Bytes
9235b7f
 
b3a0761
371bdca
50bfc5a
 
 
 
a7bee92
 
 
 
 
 
50bfc5a
c97fcf1
 
 
a7bee92
371bdca
50bfc5a
 
9235b7f
a7bee92
9235b7f
 
 
d033e91
ab0b470
6ef3309
50bfc5a
 
 
 
6ef3309
b3a0761
9235b7f
 
e6730cb
 
9235b7f
fc8037f
7802e94
fc8037f
9235b7f
 
 
 
 
114a69f
 
9235b7f
114a69f
 
 
9235b7f
 
 
 
114a69f
9235b7f
 
 
 
 
 
114a69f
9235b7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import os
import gradio as gr
from scepter.modules.utils.file_system import FS
from huggingface_hub import hf_hub_download, snapshot_download

def resolve_hf_path(path):
    """Resolve an ``hf://`` URI to a local path, downloading from the Hub if needed.

    Supported forms:
      * ``hf://<repo_id>`` -- download a snapshot of the whole repository
        (``snapshot_download``) and return its local directory.
      * ``hf://<repo_id>@<filename>`` -- download a single file
        (``hf_hub_download``) and return its local file path. An empty
        filename (``hf://<repo_id>@``) is treated like the whole-repo form.

    Any other value (non-strings, or strings without the ``hf://`` prefix,
    such as ``ms://`` URIs or plain local paths) is returned unchanged.

    The access token is read from ``HUGGINGFACE_HUB_TOKEN`` or, when that is
    unset/empty, from ``HF_TOKEN``.

    Raises:
        ValueError: if the URI contains more than one ``@``, has an empty
            repository id, or no non-empty token variable is set.
    """
    prefix = "hf://"
    if not (isinstance(path, str) and path.startswith(prefix)):
        return path

    parts = path[len(prefix):].split("@")
    if len(parts) > 2:
        raise ValueError(f"Invalid HF URI format: {path}")
    repo_id = parts[0]
    filename = parts[1] if len(parts) == 2 else None
    if not repo_id:
        # Fail here with a clear message instead of sending an empty repo id
        # to the Hub.
        raise ValueError(f"Invalid HF URI format (empty repo id): {path}")

    # Keep the variable this app has always used, but also accept HF_TOKEN,
    # the name documented by huggingface_hub. An empty value counts as unset.
    token = os.environ.get("HUGGINGFACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")
    if not token:
        raise ValueError(
            "Neither HUGGINGFACE_HUB_TOKEN nor HF_TOKEN environment variable is set!"
        )

    if filename:
        return hf_hub_download(repo_id=repo_id, filename=filename, token=token)
    return snapshot_download(repo_id=repo_id, token=token)

# Default model locations, exported as environment variables. Only the two
# ``hf://`` entries are resolved below; the ``ms://`` entries are left as given.
os.environ.update({
    "FLUX_FILL_PATH": "hf://black-forest-labs/FLUX.1-Fill-dev",
    "PORTRAIT_MODEL_PATH": "ms://iic/ACE_Plus@portrait/comfyui_portrait_lora64.safetensors",
    "SUBJECT_MODEL_PATH": "ms://iic/ACE_Plus@subject/comfyui_subject_lora16.safetensors",
    "LOCAL_MODEL_PATH": "ms://iic/ACE_Plus@local_editing/comfyui_local_lora16.safetensors",
    "ACE_PLUS_FFT_MODEL": "hf://ali-vilab/ACE_Plus@ace_plus_fft.safetensors",
})

# Fetch the Hub-hosted weights (FLUX fill base first, then the ACE++ FFT
# checkpoint) and remember where they landed locally.
flux_full = resolve_hf_path(os.environ["FLUX_FILL_PATH"])
ace_plus_fft_model_path = resolve_hf_path(os.environ["ACE_PLUS_FFT_MODEL"])

# Overwrite the hf:// URIs with the downloaded local paths; presumably the
# config/imports that follow read these variables.
os.environ.update(
    ACE_PLUS_FFT_MODEL=ace_plus_fft_model_path,
    FLUX_FILL_PATH=flux_full,
)

from inference.ace_plus_inference import ACEInference
from scepter.modules.utils.config import Config
from modules.flux import FluxMRModiACEPlus
from inference.registry import INFERENCES


# Build the ACE++ inference pipeline from the FFT config. The path is
# relative, so it presumably depends on the directory the app is launched from.
config_path = os.path.join("config", "ace_plus_fft.yaml")
cfg = Config(cfg_file=config_path, load=True)
ace_infer = ACEInference(cfg)

def face_swap_app(target_img, face_img):
    """Run the ACE++ face-swap edit behind the Gradio interface.

    Args:
        target_img: PIL image from the "Target Image" input.
        face_img: PIL image from the "Face Image" input.

    Returns:
        The first element of the ``ace_infer`` result, which is shown in the
        "Swapped Face Output" component. The remaining four elements are
        discarded.

    Raises:
        ValueError: if either image is missing.
    """
    if target_img is None or face_img is None:
        raise ValueError("Both a target image and a face image must be provided.")

    # Normalise both inputs to 3-channel RGB (PIL uploads can arrive in other
    # modes, e.g. RGBA).
    target_img = target_img.convert("RGB")
    face_img = face_img.convert("RGB")

    # NOTE(review): the target is passed as ``reference_image`` and the face as
    # ``edit_image``. Confirm against ACEInference that this role assignment
    # is what a face swap needs.
    output_img, _, _, _, _ = ace_infer(
        reference_image=target_img,
        edit_image=face_img,
        edit_mask=None,  # no explicit mask; handling is up to ACEInference
        prompt="Face swap",
        output_height=1024,
        output_width=1024,
        sampler='flow_euler',
        sample_steps=28,
        guide_scale=50,
        seed=-1,  # always -1; presumably ACEInference treats this as "random"
    )
    return output_img

# Gradio UI: two image uploads go into face_swap_app, and one image comes out.
iface = gr.Interface(
    face_swap_app,
    inputs=[
        gr.Image(label="Target Image", type="pil"),
        gr.Image(label="Face Image", type="pil"),
    ],
    outputs=gr.Image(label="Swapped Face Output", type="pil"),
    title="ACE++ Face Swap Demo",
    description=(
        "Upload a target image and a face image to swap the face "
        "using the ACE++ model."
    ),
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()