Spaces:
Running
Running
File size: 10,787 Bytes
dbd510a 772ce93 dbd510a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
"""
Main application script for the Gradio interface.
This script initializes the application, loads prerequisite models via model_loader,
defines the user interface using Gradio Blocks, and orchestrates the multi-stage
image generation process by calling functions from the pipelines module.
"""
import gradio as gr
import gradio.themes as gr_themes
import time
import os
import random
# --- Imports from our custom modules ---
try:
from image_utils import prepare_image
from model_loader import load_models, are_models_loaded
from pipelines import run_pose_detection, run_low_res_generation, run_hires_tiling, cleanup_memory
print("Helper modules imported successfully.")
except ImportError as e:
print(f"ERROR: Failed to import required local modules: {e}")
print("Please ensure prompts.py, image_utils.py, model_loader.py, and pipelines.py are in the same directory.")
raise SystemExit(f"Module import failed: {e}")
# --- Constants & UI Configuration ---
# Generation settings are fixed; they are not exposed as UI controls.

# Seed shared by both stages. Set to -1 to draw a fresh random seed per run.
DEFAULT_SEED = 1024

# Stage 1: low-resolution generation.
DEFAULT_STEPS_LOWRES = 30        # inference steps
DEFAULT_GUIDANCE_LOWRES = 8.0    # guidance scale
DEFAULT_STRENGTH_LOWRES = 0.05   # img2img strength
DEFAULT_CN_SCALE_LOWRES = 1.0    # ControlNet scale

# Stage 2: high-resolution tiling.
DEFAULT_STEPS_HIRES = 20         # inference steps per tile
DEFAULT_GUIDANCE_HIRES = 8.0     # guidance scale
DEFAULT_STRENGTH_HIRES = 0.75    # img2img strength
DEFAULT_CN_SCALE_HIRES = 1.0     # ControlNet scale

# Optional directory for saving intermediate images (currently disabled).
# OUTPUT_DIR = "outputs"
# os.makedirs(OUTPUT_DIR, exist_ok=True)
# --- Load Prerequisite Models at Startup ---
# Runs at import time, before the UI is built. A failed load is reported
# but deliberately does not abort startup.
if are_models_loaded():
    print("Models were already loaded.")
else:
    print("Initial model loading required...")
    load_successful = load_models()
    if not load_successful:
        print("FATAL: Failed to load prerequisite models. The application may not work correctly.")
# --- Main Processing Function ---
def generate_full_pipeline(
    input_image_path,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Orchestrates the entire image generation workflow.

    Called when the user clicks the 'Generate' button in the UI. Runs the
    processing steps in sequence (prepare input, detect pose, low-res
    generation, hi-res tiling), updates the progress bar, and returns the
    final image. Generation settings are fixed: they come from the
    module-level DEFAULT_* constants plus the tiling values hard-coded below.

    Args:
        input_image_path (str | None): Path to the uploaded input image file
            (the UI's Image component uses type="filepath"); None when nothing
            has been uploaded.
        progress (gr.Progress): Gradio progress tracking object.

    Returns:
        PIL.Image.Image: The final high-resolution image, or the
            low-resolution image as a fallback when the hi-res tiling stage
            raises or returns None.

    Raises:
        gr.Error: If no image was uploaded; if image preparation, pose
            detection or low-res generation fail; or if any other unexpected
            error occurs before a low-res image is available.

    Notes:
        When falling back to the low-res image the user is notified through
        gr.Warning, which shows a UI notification rather than raising.
    """
    import traceback  # Only used for diagnostics on failure paths.

    print("\n--- Starting New Generation Run ---")
    run_start_time = time.time()

    # DEFAULT_SEED == -1 is the sentinel for "use a fresh random seed each run".
    current_seed = DEFAULT_SEED
    if current_seed == -1:
        current_seed = random.randint(0, 9999999)
        print(f"Using Random Seed: {current_seed}")
    else:
        print(f"Using Fixed Seed: {current_seed}")

    low_res_image = None
    final_image = None
    try:
        # Fail fast with a clear message instead of handing None to the loader.
        if not input_image_path:
            raise gr.Error("Please upload an image before clicking Generate.")

        progress(0.05, desc="Preparing Input Image...")
        resized_input_image = prepare_image(input_image_path, target_size=512)
        if resized_input_image is None:
            raise gr.Error("Failed to load or prepare the input image. Check format/corruption.")

        progress(0.15, desc="Detecting Pose...")
        pose_map = run_pose_detection(resized_input_image)
        if pose_map is None:
            raise gr.Error("Failed to detect pose from the input image.")

        progress(0.25, desc="Starting Low-Res Generation...")
        low_res_image = run_low_res_generation(
            resized_input_image=resized_input_image,
            pose_map=pose_map,
            seed=int(current_seed),
            steps=int(DEFAULT_STEPS_LOWRES),
            guidance_scale=float(DEFAULT_GUIDANCE_LOWRES),
            strength=float(DEFAULT_STRENGTH_LOWRES),
            controlnet_scale=float(DEFAULT_CN_SCALE_LOWRES),
            progress=progress
        )
        # Same None check as the earlier stages: never feed None into tiling.
        if low_res_image is None:
            raise gr.Error("Low-resolution generation failed.")
        print("Low-res generation stage completed successfully.")
        progress(0.45, desc="Low-Res Generation Complete.")

        progress(0.50, desc="Starting Hi-Res Tiling...")
        # The hi-res stage is an enhancement: any failure here (an exception
        # or a None result) degrades to the low-res image instead of losing
        # the whole run.
        hires_error = None
        try:
            final_image = run_hires_tiling(
                low_res_image=low_res_image,
                seed=int(current_seed),
                steps=int(DEFAULT_STEPS_HIRES),
                guidance_scale=float(DEFAULT_GUIDANCE_HIRES),
                strength=float(DEFAULT_STRENGTH_HIRES),
                controlnet_scale=float(DEFAULT_CN_SCALE_HIRES),
                upscale_factor=2,
                tile_size=1024,
                # NOTE(review): stride equals tile size, so tiles presumably
                # do not overlap; confirm seams between tiles are acceptable.
                tile_stride=1024,
                progress=progress
            )
        except Exception as tiling_error:
            hires_error = tiling_error
            print(f"Hi-res tiling failed: {tiling_error}")
            traceback.print_exc()

        if final_image is None:
            reason = f" ({hires_error})" if hires_error is not None else ""
            gr.Warning(f"High-resolution upscaling failed{reason}. Returning low-resolution image.")
            final_image = low_res_image
        else:
            print("Hi-res tiling stage completed successfully.")
        progress(1.0, desc="Complete!")
    except gr.Error as e:
        # Already user-facing; re-raise unchanged (bare raise keeps the traceback).
        print(f"Gradio Error occurred: {e}")
        raise
    except Exception as e:
        print(f"An unexpected error occurred in generate_full_pipeline: {e}")
        traceback.print_exc()
        raise gr.Error(f"An unexpected error occurred: {e}") from e
    finally:
        # Always run cleanup, even when a stage failed (cleanup_memory
        # presumably frees cached model memory; confirm in pipelines).
        print("Running final cleanup check...")
        cleanup_memory()
        run_end_time = time.time()
        print(f"--- Full Pipeline Run Finished in {run_end_time - run_start_time:.2f} seconds ---")
    return final_image
# --- Gradio Interface Definition ---
# Soft base theme with blue as the primary hue and sky as the secondary hue.
theme = gr_themes.Soft(
    primary_hue=gr_themes.colors.blue,
    secondary_hue=gr_themes.colors.sky,
)
# HTML header rendered at the top of the page (passed to gr.HTML in the UI).
# A plain string rather than an f-string, since it has no placeholders.
DESCRIPTION = """
<div style="text-align: center;">
<h1 style="font-family: Impact, Charcoal, sans-serif; font-size: 280%; font-weight: 900; margin-bottom: 16px;">
Pose-Preserving Comicfier
</h1>
<p style="margin-bottom: 12px; font-size: 94%;">
Transform your photos into the gritty style of a 1940s Western comic! This app uses (Stable Diffusion + ControlNet)
to apply the artistic look while keeping the original pose intact. Just upload your image and click Generate!
</p>
<p style="font-size: 85%;"><em>(Generation can take several minutes on shared hardware. Prompts & parameters are fixed.)</em></p>
<p style="font-size: 80%; color: grey;">
<a href="https://github.com/mehran-khani/Pose-Preserving-Comicfier" target="_blank">[View Project on GitHub]</a> |
<a href="https://huggingface.co/spaces/Mer-o/Pose-Preserving-Comicfier/discussions" target="_blank">[Report an Issue]</a>
</p>
</div>
"""
# Example inputs offered below the UI. Paths are relative to the current
# working directory, and only files that actually exist are kept, so a
# missing examples folder yields an empty list.
EXAMPLE_IMAGES_DIR = "examples"
EXAMPLE_IMAGES = [
    candidate
    for candidate in (
        os.path.join(EXAMPLE_IMAGES_DIR, f"example{index}.jpg")
        for index in range(1, 7)
    )
    if os.path.exists(candidate)
]
# Extra stylesheet passed to gr.Blocks(css=...). It stretches the <img>
# inside each `.gradio-image` container to the container's full box and uses
# `object-fit: cover`, which crops (rather than letterboxes) the image when
# aspect ratios differ. The last rule hides the page footer.
# NOTE(review): the inline CSS comment mentions `height=400`, but the Image
# components no longer pass a height; confirm the 100% heights still behave
# as intended. Also verify that `.gradio-image` matches a class emitted by the
# installed Gradio version.
CUSTOM_CSS = """
/* Target the container div Gradio uses for the Image component */
.gradio-image {
width: 100%; /* Ensure the container fills the column width */
height: 100%; /* Ensure the container fills the height set by the component (e.g., height=400) */
overflow: hidden; /* Hide any potential overflow before object-fit applies */
}
/* Target the actual <img> tag inside the container */
.gradio-image img {
display: block; /* Remove potential bottom spacing */
width: 100%; /* Force image width to match container */
height: 100%; /* Force image height to match container */
object-fit: cover; /* Scale/crop image to cover this forced W/H */
}
footer { visibility: hidden }
"""
# Build the page. Component creation order inside the Blocks context defines
# the layout: the HTML header, a row holding an input column (upload plus
# Generate button) and an output column (result image), then the optional
# examples and the Generate button's event wiring.
with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Pose-Preserving Comicfier") as demo:
gr.HTML(DESCRIPTION)
with gr.Row():
# Input Column
with gr.Column(scale=1, min_width=350):
# No explicit height is set (an earlier height=400 was removed).
# type="filepath": the handler receives the upload as a file path string.
input_image = gr.Image(
type="filepath",
label="Upload Your Image Here"
)
generate_button = gr.Button("Generate Comic Image", variant="primary")
# Output Column
with gr.Column(scale=1, min_width=350):
# No explicit height is set (an earlier height=400 was removed).
# type="pil" with interactive=False: a display-only slot for the returned image.
output_image = gr.Image(
type="pil",
label="Generated Comic Image",
interactive=False
)
# Examples Section: only rendered when at least one example file exists.
# NOTE(review): with cache_examples=False, fn/outputs are presumably unused
# (clicking an example only fills the input); confirm for the installed Gradio version.
if EXAMPLE_IMAGES:
gr.Examples(
examples=EXAMPLE_IMAGES,
inputs=[input_image],
outputs=[output_image],
fn=generate_full_pipeline,
cache_examples=False
)
# Run the full pipeline on click; api_name also exposes it as the "generate" API endpoint.
generate_button.click(
fn=generate_full_pipeline,
inputs=[input_image],
outputs=[output_image],
api_name="generate"
)
# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Retry model loading if the import-time attempt did not leave the
    # models loaded.
    if not are_models_loaded():
        print("Attempting to load models before launch...")
        if not load_models():
            print("FATAL: Model loading failed on launch. App may not function.")
    print("Attempting to launch Gradio demo...")
    # queue() enables Gradio's request queue, useful for these long-running
    # jobs. When run as a script, launch() blocks the main thread until the
    # server stops and prints the access URL itself, so the line after it
    # only runs at shutdown.
    demo.queue().launch(debug=False, share=False)
    print("Gradio app has shut down.")