File size: 10,787 Bytes
dbd510a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
772ce93
 
dbd510a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
"""
Main application script for the Gradio interface.

This script initializes the application, loads prerequisite models via model_loader,
defines the user interface using Gradio Blocks, and orchestrates the multi-stage
image generation process by calling functions from the pipelines module.
"""

import gradio as gr
import gradio.themes as gr_themes
import time
import os
import random
# --- Imports from our custom modules ---
try:
    from image_utils import prepare_image
    from model_loader import load_models, are_models_loaded
    from pipelines import run_pose_detection, run_low_res_generation, run_hires_tiling, cleanup_memory
    print("Helper modules imported successfully.")
except ImportError as e:
    print(f"ERROR: Failed to import required local modules: {e}")
    print("Please ensure prompts.py, image_utils.py, model_loader.py, and pipelines.py are in the same directory.")
    raise SystemExit(f"Module import failed: {e}")

# --- Constants & UI Configuration ---
# Fixed generation parameters. The UI exposes no controls for these; they are
# read directly by generate_full_pipeline().

# Seed handed to both generation stages. generate_full_pipeline() treats the
# special value -1 as "pick a random seed for this run".
DEFAULT_SEED = 1024

# Low-resolution stage: forwarded to run_low_res_generation() as
# steps / guidance_scale / strength / controlnet_scale respectively.
DEFAULT_STEPS_LOWRES = 30
DEFAULT_GUIDANCE_LOWRES = 8.0
DEFAULT_STRENGTH_LOWRES = 0.05
DEFAULT_CN_SCALE_LOWRES = 1.0

# High-resolution stage: forwarded to run_hires_tiling() as
# steps / guidance_scale / strength / controlnet_scale respectively.
DEFAULT_STEPS_HIRES = 20
DEFAULT_GUIDANCE_HIRES = 8.0
DEFAULT_STRENGTH_HIRES = 0.75
DEFAULT_CN_SCALE_HIRES = 1.0

# OUTPUT_DIR = "outputs"
# os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Load Prerequisite Models at Startup ---
# Make sure the prerequisite models are available as soon as the module is
# imported. A failed load is only reported, not raised, so the UI can still
# come up (generation will then most likely fail at request time).
if are_models_loaded():
    print("Models were already loaded.")
else:
    print("Initial model loading required...")
    load_successful = load_models()
    if not load_successful:
        print("FATAL: Failed to load prerequisite models. The application may not work correctly.")


# --- Main Processing Function ---
def generate_full_pipeline(
    input_image_path,
    progress=gr.Progress(track_tqdm=True)
    ):
    """
    Run the complete photo-to-comic workflow for a single input image.

    This is the callback wired to the 'Generate' button. It runs four stages
    in order (prepare image, detect pose, low-res generation, hi-res tiling),
    reporting progress between them. All generation parameters are taken
    from the module-level DEFAULT_* constants; the image is the only
    user-supplied input.

    Args:
        input_image_path (str | None): Path to the uploaded input image file.
            A missing/empty value is rejected with a gr.Error.
        progress (gr.Progress): Gradio progress tracker; it is called
            directly here and also forwarded to the generation stages.

    Returns:
        The value returned by run_hires_tiling (presumably a PIL image, since
        the output component is declared with type="pil"). If a gr.Error
        whose message mentions "tiling" or "hi-res" is raised after the
        low-res stage produced a result, the low-res image is returned
        instead and a gr.Warning is shown to the user.

    Raises:
        gr.Error: If no image was supplied, if image preparation or pose
            detection returns None, if a stage raises a gr.Error that does
            not qualify for the low-res fallback, or (wrapping the original
            exception as its cause) if any other exception occurs.
    """
    print("\n--- Starting New Generation Run ---")
    run_start_time = time.time()

    # A seed of -1 is the sentinel for "choose a random seed for this run".
    current_seed = DEFAULT_SEED
    if current_seed == -1:
        current_seed = random.randint(0, 9999999)
        print(f"Using Random Seed: {current_seed}")
    else:
        print(f"Using Fixed Seed: {current_seed}")

    low_res_image = None
    final_image = None

    try:
        # Clicking Generate without an upload delivers no path; fail with a
        # clear message instead of relying on prepare_image's handling.
        if not input_image_path:
            raise gr.Error("Please upload an input image before clicking Generate.")

        progress(0.05, desc="Preparing Input Image...")
        resized_input_image = prepare_image(input_image_path, target_size=512)
        if resized_input_image is None:
            raise gr.Error("Failed to load or prepare the input image. Check format/corruption.")

        progress(0.15, desc="Detecting Pose...")
        pose_map = run_pose_detection(resized_input_image)
        if pose_map is None:
            raise gr.Error("Failed to detect pose from the input image.")

        progress(0.25, desc="Starting Low-Res Generation...")
        low_res_image = run_low_res_generation(
            resized_input_image=resized_input_image,
            pose_map=pose_map,
            seed=int(current_seed),
            steps=int(DEFAULT_STEPS_LOWRES),
            guidance_scale=float(DEFAULT_GUIDANCE_LOWRES),
            strength=float(DEFAULT_STRENGTH_LOWRES),
            controlnet_scale=float(DEFAULT_CN_SCALE_LOWRES),
            progress=progress
        )
        print("Low-res generation stage completed successfully.")
        progress(0.45, desc="Low-Res Generation Complete.")

        progress(0.50, desc="Starting Hi-Res Tiling...")
        final_image = run_hires_tiling(
            low_res_image=low_res_image,
            seed=int(current_seed),
            steps=int(DEFAULT_STEPS_HIRES),
            guidance_scale=float(DEFAULT_GUIDANCE_HIRES),
            strength=float(DEFAULT_STRENGTH_HIRES),
            controlnet_scale=float(DEFAULT_CN_SCALE_HIRES),
            upscale_factor=2,
            tile_size=1024,
            tile_stride=1024,
            progress=progress
        )
        print("Hi-res tiling stage completed successfully.")

        progress(1.0, desc="Complete!")

    except gr.Error as e:
        print(f"Gradio Error occurred: {e}")
        # The hi-res stage is identified by keywords in the error message;
        # only such failures may degrade to the low-res result.
        error_text = str(e).lower()
        hires_stage_failed = "tiling" in error_text or "hi-res" in error_text
        if final_image is None and low_res_image is not None and hires_stage_failed:
            gr.Warning(f"High-resolution upscaling failed ({e}). Returning low-resolution image.")
            final_image = low_res_image
        else:
            # Bare raise keeps the original traceback intact.
            raise
    except Exception as e:
        print(f"An unexpected error occurred in generate_full_pipeline: {e}")
        import traceback
        traceback.print_exc()
        # Chain the original exception so the root cause stays attached.
        raise gr.Error(f"An unexpected error occurred: {e}") from e
    finally:
        # Always release memory and log the duration, even on failure.
        print("Running final cleanup check...")
        cleanup_memory()
        run_end_time = time.time()
        print(f"--- Full Pipeline Run Finished in {run_end_time - run_start_time:.2f} seconds ---")

    return final_image


# --- Gradio Interface Definition ---

# Visual theme passed to gr.Blocks below: Gradio's "Soft" theme with a blue
# primary hue and a sky secondary hue.
theme = gr_themes.Soft(primary_hue=gr_themes.colors.blue, secondary_hue=gr_themes.colors.sky)

# HTML header rendered at the top of the page via gr.HTML: app title, a short
# explanation of what the app does, a note about run time, and project links.
# Plain string (no interpolation is needed, so no f-prefix).
DESCRIPTION = """
<div style="text-align: center;">
    <h1 style="font-family: Impact, Charcoal, sans-serif; font-size: 280%; font-weight: 900; margin-bottom: 16px;"> 
    Pose-Preserving Comicfier
    </h1>
    <p style="margin-bottom: 12px; font-size: 94%">
    Transform your photos into the gritty style of a 1940s Western comic! This app uses (Stable Diffusion + ControlNet)
    to apply the artistic look while keeping the original pose intact. Just upload your image and click Generate!
    </p>
    <p style="font-size: 85%;"><em>(Generation can take several minutes on shared hardware. Prompts & parameters are fixed.)</em></p>
    <p style="font-size: 80%; color: grey;">
    <a href="https://github.com/mehran-khani/Pose-Preserving-Comicfier" target="_blank">[View Project on GitHub]</a> | 
    <a href="https://huggingface.co/spaces/Mer-o/Pose-Preserving-Comicfier/discussions" target="_blank">[Report an Issue]</a> 
    </p> 
    <!-- Remember to replace placeholders above with your actual links -->
</div>
"""

# Example inputs offered in the UI: examples/example1.jpg .. example6.jpg.
# Only files that actually exist on disk are kept, so a missing examples
# folder simply yields an empty list.
EXAMPLE_IMAGES_DIR = "examples"
EXAMPLE_IMAGES = [
    candidate
    for candidate in (
        os.path.join(EXAMPLE_IMAGES_DIR, f"example{index}.jpg")
        for index in range(1, 7)
    )
    if os.path.exists(candidate)
]

# Extra stylesheet handed to gr.Blocks(css=...). It stretches images to fill
# their container (cropping via object-fit: cover) and hides the page footer.
# NOTE(review): the `.gradio-image` selector is assumed to match the DOM of
# the installed Gradio version — confirm, since class names vary by release.
CUSTOM_CSS = """
/* Target the container div Gradio uses for the Image component */
.gradio-image {
    width: 100%;   /* Ensure the container fills the column width */
    height: 100%;  /* Ensure the container fills the height set by the component (e.g., height=400) */
    overflow: hidden; /* Hide any potential overflow before object-fit applies */
}

/* Target the actual <img> tag inside the container */
.gradio-image img { 
    display: block;    /* Remove potential bottom spacing */
    width: 100%;       /* Force image width to match container */
    height: 100%;      /* Force image height to match container */
    object-fit: cover; /* Scale/crop image to cover this forced W/H */
} 

footer { visibility: hidden } 
"""

# Build the page. Components are laid out in the order they are declared
# inside the Blocks / Row / Column context managers, so statement order here
# is the visual order on screen.
with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Pose-Preserving Comicfier") as demo:
    # Title and introduction banner.
    gr.HTML(DESCRIPTION)

    with gr.Row():
        # Input Column
        with gr.Column(scale=1, min_width=350):
            # No explicit height is set on the component.
            # type="filepath": the callback receives the upload as a file path.
            input_image = gr.Image(
                type="filepath",
                label="Upload Your Image Here"
            )
            generate_button = gr.Button("Generate Comic Image", variant="primary")

        # Output Column
        with gr.Column(scale=1, min_width=350):
            # No explicit height is set on the component.
            # Read-only display of the image returned by the pipeline.
            output_image = gr.Image(
                type="pil",
                label="Generated Comic Image",
                interactive=False
            )


    # Examples Section
    # Shown only when at least one example file was found on disk.
    # cache_examples=False: example outputs are not pre-computed at startup.
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=EXAMPLE_IMAGES,
            inputs=[input_image],
            outputs=[output_image],
            fn=generate_full_pipeline,
            cache_examples=False
        )

    # Clicking the button runs the full pipeline on the uploaded image and
    # shows the result; the event is registered under the API name "generate".
    generate_button.click(
        fn=generate_full_pipeline,
        inputs=[input_image],
        outputs=[output_image],
        api_name="generate"
    )


# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Second chance to load the models in case the import-time attempt did
    # not succeed; a failure is reported but does not prevent the launch.
    models_ready = are_models_loaded()
    if not models_ready:
        print("Attempting to load models before launch...")
        models_ready = load_models()
        if not models_ready:
            print("FATAL: Model loading failed on launch. App may not function.")

    print("Attempting to launch Gradio demo...")
    # Enable request queuing, then start the server without a public share
    # link and without debug mode.
    queued_demo = demo.queue()
    queued_demo.launch(debug=False, share=False)
    # NOTE(review): launch() presumably blocks when run as a script, so this
    # message is likely printed only after the server stops — confirm.
    print("Gradio app launched. Access it at the URL provided above.")