Spaces:
Running
Running
""" | |
Main application script for the Gradio interface. | |
This script initializes the application, loads prerequisite models via model_loader, | |
defines the user interface using Gradio Blocks, and orchestrates the multi-stage | |
image generation process by calling functions from the pipelines module. | |
""" | |
import gradio as gr | |
import gradio.themes as gr_themes | |
import time | |
import os | |
import random | |
# --- Imports from our custom modules --- | |
try: | |
from image_utils import prepare_image | |
from model_loader import load_models, are_models_loaded | |
from pipelines import run_pose_detection, run_low_res_generation, run_hires_tiling, cleanup_memory | |
print("Helper modules imported successfully.") | |
except ImportError as e: | |
print(f"ERROR: Failed to import required local modules: {e}") | |
print("Please ensure prompts.py, image_utils.py, model_loader.py, and pipelines.py are in the same directory.") | |
raise SystemExit(f"Module import failed: {e}") | |
# --- Constants & UI Configuration ---
# Generation parameters are fixed; the UI does not expose any of them.

# Seed used for every run. Setting it to -1 makes generate_full_pipeline
# draw a fresh random seed per run instead.
DEFAULT_SEED = 1024

# Stage 1: low-resolution generation.
DEFAULT_STEPS_LOWRES = 30
DEFAULT_GUIDANCE_LOWRES = 8.0
DEFAULT_STRENGTH_LOWRES = 0.05
DEFAULT_CN_SCALE_LOWRES = 1.0

# Stage 2: high-resolution tiled refinement.
DEFAULT_STEPS_HIRES = 20
DEFAULT_GUIDANCE_HIRES = 8.0
DEFAULT_STRENGTH_HIRES = 0.75
DEFAULT_CN_SCALE_HIRES = 1.0
# --- Load Prerequisite Models at Startup ---
# Models are loaded once, when this module is imported. A failure is only
# reported, never raised, so the UI can still start.
if are_models_loaded():
    print("Models were already loaded.")
else:
    print("Initial model loading required...")
    load_successful = load_models()
    if not load_successful:
        print("FATAL: Failed to load prerequisite models. The application may not work correctly.")
# --- Main Processing Function ---
def generate_full_pipeline(
    input_image_path,
    progress=gr.Progress(track_tqdm=True)
):
    """
    Run the full pose-preserving comic generation for one uploaded image.

    Called by the 'Generate' button and by the examples gallery. The stages
    run in order:
      1. prepare (resize) the input image,
      2. detect the pose,
      3. run the low-resolution generation,
      4. refine it with high-resolution tiling.
    All generation parameters come from the module-level DEFAULT_* constants.

    The hi-res stage is best-effort: if it fails for any reason, or returns
    no image, a warning is shown and the low-resolution image is returned.

    Args:
        input_image_path (str): Path to the uploaded input image file
            (the input component uses type="filepath").
        progress (gr.Progress): Gradio progress tracker. It is also forwarded
            to the generation stages.

    Returns:
        PIL.Image.Image: The high-resolution image, or the low-resolution
        image when the hi-res stage fails or produces nothing.

    Raises:
        gr.Error: If image preparation, pose detection or low-res generation
            fails, or on any other unexpected error outside the hi-res stage.
    """
    import traceback  # only needed for error reporting

    print("\n--- Starting New Generation Run ---")
    run_start_time = time.time()

    # DEFAULT_SEED == -1 requests a fresh random seed for every run.
    current_seed = DEFAULT_SEED
    if current_seed == -1:
        current_seed = random.randint(0, 9999999)
        print(f"Using Random Seed: {current_seed}")
    else:
        print(f"Using Fixed Seed: {current_seed}")

    low_res_image = None
    final_image = None
    try:
        progress(0.05, desc="Preparing Input Image...")
        resized_input_image = prepare_image(input_image_path, target_size=512)
        if resized_input_image is None:
            raise gr.Error("Failed to load or prepare the input image. Check format/corruption.")

        progress(0.15, desc="Detecting Pose...")
        pose_map = run_pose_detection(resized_input_image)
        if pose_map is None:
            raise gr.Error("Failed to detect pose from the input image.")

        progress(0.25, desc="Starting Low-Res Generation...")
        low_res_image = run_low_res_generation(
            resized_input_image=resized_input_image,
            pose_map=pose_map,
            seed=int(current_seed),
            steps=int(DEFAULT_STEPS_LOWRES),
            guidance_scale=float(DEFAULT_GUIDANCE_LOWRES),
            strength=float(DEFAULT_STRENGTH_LOWRES),
            controlnet_scale=float(DEFAULT_CN_SCALE_LOWRES),
            progress=progress
        )
        # Same guard as the earlier stages: without a low-res image there is
        # nothing to upscale and nothing to fall back to.
        if low_res_image is None:
            raise gr.Error("Low-resolution generation failed to produce an image.")
        print("Low-res generation stage completed successfully.")
        progress(0.45, desc="Low-Res Generation Complete.")

        progress(0.50, desc="Starting Hi-Res Tiling...")
        # Any failure here falls back to the finished low-res image. The
        # fallback no longer depends on the exception type or on the wording
        # of its message.
        try:
            final_image = run_hires_tiling(
                low_res_image=low_res_image,
                seed=int(current_seed),
                steps=int(DEFAULT_STEPS_HIRES),
                guidance_scale=float(DEFAULT_GUIDANCE_HIRES),
                strength=float(DEFAULT_STRENGTH_HIRES),
                controlnet_scale=float(DEFAULT_CN_SCALE_HIRES),
                upscale_factor=2,
                tile_size=1024,
                tile_stride=1024,
                progress=progress
            )
        except Exception as hires_error:
            print(f"Hi-res tiling failed: {hires_error}")
            traceback.print_exc()
            # gr.Warning is called (not raised): it shows a UI notice and the run continues.
            gr.Warning(f"High-resolution upscaling failed ({hires_error}). Returning low-resolution image.")
            final_image = low_res_image
        else:
            if final_image is None:
                print("Hi-res tiling returned no image.")
                gr.Warning("High-resolution upscaling produced no image. Returning low-resolution image.")
                final_image = low_res_image
            else:
                print("Hi-res tiling stage completed successfully.")
        progress(1.0, desc="Complete!")
    except gr.Error as e:
        # Already user-facing; log it and re-raise unchanged.
        print(f"Gradio Error occurred: {e}")
        raise
    except Exception as e:
        print(f"An unexpected error occurred in generate_full_pipeline: {e}")
        traceback.print_exc()
        raise gr.Error(f"An unexpected error occurred: {e}") from e
    finally:
        # Runs on success and on error alike.
        print("Running final cleanup check...")
        cleanup_memory()
        run_end_time = time.time()
        print(f"--- Full Pipeline Run Finished in {run_end_time - run_start_time:.2f} seconds ---")
    # Kept outside `finally` so that raised errors are never swallowed.
    return final_image
# --- Gradio Interface Definition ---
# Soft base theme with a blue primary palette and a sky secondary palette.
theme = gr_themes.Soft(
    primary_hue=gr_themes.colors.blue,
    secondary_hue=gr_themes.colors.sky,
)
# HTML header banner, rendered with gr.HTML at the top of the page.
DESCRIPTION = """
<div style="text-align: center;">
<h1 style="font-family: Impact, Charcoal, sans-serif; font-size: 280%; font-weight: 900; margin-bottom: 16px;">
Pose-Preserving Comicfier
</h1>
<p style="margin-bottom: 12px; font-size: 94%">
Transform your photos into the gritty style of a 1940s Western comic! This app uses (Stable Diffusion + ControlNet)
to apply the artistic look while keeping the original pose intact. Just upload your image and click Generate!
</p>
<p style="font-size: 85%;"><em>(Generation can take several minutes on shared hardware. Prompts & parameters are fixed.)</em></p>
<p style="font-size: 80%; color: grey;">
<a href="https://github.com/mehran-khani/Pose-Preserving-Comicfier" target="_blank">[View Project on GitHub]</a> |
<a href="https://huggingface.co/spaces/Mer-o/Pose-Preserving-Comicfier/discussions" target="_blank">[Report an Issue]</a>
</p>
</div>
"""
EXAMPLE_IMAGES_DIR = "examples"
# Candidates are example1.jpg ... example6.jpg, in order. Only files present
# on disk are kept, so gr.Examples never references a missing image.
EXAMPLE_IMAGES = [
    candidate
    for candidate in (
        os.path.join(EXAMPLE_IMAGES_DIR, f"example{number}.jpg")
        for number in range(1, 7)
    )
    if os.path.exists(candidate)
]
# Stylesheet passed to gr.Blocks(css=...). It makes elements matched by
# ".gradio-image" fill their container, stretches and crops the inner <img>
# to that box (object-fit: cover), and hides the page footer.
# NOTE(review): ".gradio-image" relies on Gradio's internal DOM class names;
# confirm it matches the installed Gradio version. The "(e.g., height=400)"
# remark inside the CSS is stale: the Image components no longer set a height.
CUSTOM_CSS = """
/* Target the container div Gradio uses for the Image component */
.gradio-image {
width: 100%; /* Ensure the container fills the column width */
height: 100%; /* Ensure the container fills the height set by the component (e.g., height=400) */
overflow: hidden; /* Hide any potential overflow before object-fit applies */
}
/* Target the actual <img> tag inside the container */
.gradio-image img {
display: block; /* Remove potential bottom spacing */
width: 100%; /* Force image width to match container */
height: 100%; /* Force image height to match container */
object-fit: cover; /* Scale/crop image to cover this forced W/H */
}
footer { visibility: hidden }
"""
with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Pose-Preserving Comicfier") as demo:
    # Page header.
    gr.HTML(DESCRIPTION)

    # Two equal-width columns: upload + button on the left, result on the right.
    with gr.Row():
        with gr.Column(scale=1, min_width=350):
            input_image = gr.Image(type="filepath", label="Upload Your Image Here")
            generate_button = gr.Button("Generate Comic Image", variant="primary")

        with gr.Column(scale=1, min_width=350):
            output_image = gr.Image(type="pil", label="Generated Comic Image", interactive=False)

    # Example gallery, shown only when at least one example file exists.
    if EXAMPLE_IMAGES:
        gr.Examples(
            examples=EXAMPLE_IMAGES,
            inputs=[input_image],
            outputs=[output_image],
            fn=generate_full_pipeline,
            cache_examples=False,
        )

    # Button handler; api_name also registers it under the "generate" API name.
    generate_button.click(
        fn=generate_full_pipeline,
        inputs=[input_image],
        outputs=[output_image],
        api_name="generate",
    )
# --- Launch the Gradio App ---
if __name__ == "__main__":
    # Retry model loading in case the import-time attempt above failed.
    if not are_models_loaded():
        print("Attempting to load models before launch...")
        if not load_models():
            print("FATAL: Model loading failed on launch. App may not function.")
    print("Attempting to launch Gradio demo...")
    # launch() blocks the main thread while the server runs and prints the
    # access URL itself, so the line below only appears once the server stops.
    demo.queue().launch(debug=False, share=False)
    print("Gradio server has shut down.")