Mer-o committed on
Commit 4940090 · 2 Parent(s): f7ac47c 536503c

merging to main

.gitattributes CHANGED
@@ -1,35 +1,2 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ loras/*.safetensors filter=lfs diff=lfs merge=lfs -text
+ examples/*.jpg filter=lfs diff=lfs merge=lfs -text
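(Entries in this format are what `git lfs track "loras/*.safetensors"` and `git lfs track "examples/*.jpg"` append to `.gitattributes`; the commit swaps the stock Hugging Face LFS rules for the two patterns this repo actually stores via LFS.)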
 
.gitignore ADDED
@@ -0,0 +1,127 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ .python-version
+
+ # PEP 582; used by PDM, PEP 582 compatible tooling
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # OS generated files #
+ ######################
+ .DS_Store
+ .DS_Store?
+ ._*
+ .Spotlight-V100
+ .Trashes
+ ehthumbs.db
+ Thumbs.db
README.md CHANGED
@@ -1,14 +1,111 @@
- ---
- title: Pose Preserving Comicfier
- emoji: 📊
- colorFrom: green
- colorTo: green
- sdk: gradio
- sdk_version: 5.29.0
- app_file: app.py
- pinned: false
- license: mit
- short_description: 'Comicfier: Transforms photos into retro Western comic style'
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # Pose-Preserving Comicfier - Gradio App
+
+ [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Mer-o/Pose-Preserving-Comicfier)
+
+ This repository contains the code for a Gradio web application that transforms input images into a specific retro Western comic book style while preserving the original pose. It uses Stable Diffusion v1.5, ControlNet (OpenPose + Tile), and specific LoRAs.
+
+ This application refactors the workflow initially developed in a [Kaggle Notebook](https://github.com/mehran-khani/SD-Controlnet-Comic-Styler) into a deployable web app.
+
+ ## Features
+
+ * **Pose Preservation:** Uses ControlNet OpenPose to accurately maintain the pose from the input image.
+ * **Retro Comic Style Transfer:** Applies specific LoRAs (`night_comic_V06.safetensors` & `add_detail.safetensors`) for a 1940s Western comic aesthetic with enhanced details.
+ * **Tiled Upscaling:** Implements ControlNet Tile for 2x high-resolution output (1024x1024), improving detail consistency across large images.
+ * **Simplified UI:** Easy-to-use interface with only an image upload and a Generate button.
+ * **Fixed Parameters:** Generation uses pre-defined, optimized parameters (steps, guidance, strength, prompts) based on the original notebook implementation for consistent results.
+ * **Dynamic Backgrounds:** The background elements in the generated image are randomized for variety in the low-resolution stage.
+ * **Broad Image Support:** Accepts common formats like JPG, PNG, WEBP, and HEIC (requires `pillow-heif`).
+
+ ## Technology Stack
+
+ * **Python 3**
+ * **Gradio:** Web UI framework.
+ * **PyTorch:** Core ML framework.
+ * **Hugging Face Libraries:**
+     * `diffusers`: Stable Diffusion pipelines, ControlNet integration.
+     * `transformers`: Underlying model components.
+     * `accelerate`: Hardware acceleration utilities.
+     * `peft`: LoRA loading and management.
+ * **ControlNet:**
+     * OpenPose Detector (`controlnet_aux`)
+     * OpenPose ControlNet Model (`lllyasviel/sd-controlnet-openpose`)
+     * Tile ControlNet Model (`lllyasviel/control_v11f1e_sd15_tile`)
+ * **Base Model:** `runwayml/stable-diffusion-v1-5`
+ * **LoRAs Used:**
+     * Style: [Western Comics Style](https://civitai.com/models/1081588/western-comics-style) (using `night_comic_V06.safetensors`)
+     * Detail: [Detail Tweaker LoRA](https://civitai.com/models/58390/detail-tweaker-lora-lora) (using `add_detail.safetensors`)
+ * **Image Processing:** `Pillow`, `pillow-heif`, `numpy`, `opencv-python-headless`
+ * **Dependencies:** `matplotlib`, `mediapipe` (required by `controlnet_aux`)
+
+ ## Workflow Overview
+
+ 1. **Image Preparation (`image_utils.py`):** The input image is loaded (supports HEIC), converted to RGB, EXIF orientation is handled, and the result is force-resized to 512x512.
+ 2. **Pose Detection (`pipelines.py`):** An OpenPose map is extracted from the resized image using `controlnet_aux`.
+ 3. **Low-Resolution Generation (`pipelines.py`):**
+     * An SDv1.5 Img2Img pipeline with Pose ControlNet is dynamically loaded.
+     * Prompts are generated (`prompts.py`) with a fixed base/style and a *randomized* background element.
+     * Style and Detail LoRAs are applied.
+     * A 512x512 image is generated using fixed parameters.
+     * The pipeline is unloaded to conserve VRAM.
+ 4. **High-Resolution Tiling (`pipelines.py`):**
+     * The 512x512 image is upscaled 2x (to 1024x1024) using Lanczos resampling (creating a blurry base).
+     * An SDv1.5 Img2Img pipeline with Tile ControlNet is dynamically loaded.
+     * Tile-specific prompts (excluding the random background) are used.
+     * Style and Detail LoRAs are applied (potentially with different weights).
+     * The image is processed in overlapping 1024x1024 tiles.
+     * Processed tiles are blended back together using an alpha mask (`image_utils.py`).
+     * The pipeline is unloaded.
+ 5. **Output (`app.py`):** The final 1024x1024 image is displayed in the Gradio UI.
+
+ ## How to Run Locally
+
+ *(Requires sufficient RAM/CPU or a compatible GPU, Python 3.8+, and Git)*
+
+ 1. **Clone the repository:**
+     ```bash
+     git clone https://github.com/mehran-khani/Pose-Preserving-Comicfier.git
+     cd Pose-Preserving-Comicfier
+     ```
+ 2. **Create and activate a Python virtual environment:**
+     ```bash
+     python3 -m venv .venv
+     source .venv/bin/activate        # Linux / macOS
+     # .\.venv\Scripts\Activate.ps1   # Windows PowerShell
+     # .\.venv\Scripts\activate.bat   # Windows cmd
+     ```
+ 3. **Install dependencies:**
+     ```bash
+     pip install -r requirements.txt
+     ```
+     *(Note: PyTorch installation might require specific commands depending on your OS/CUDA setup if using a local GPU. See the PyTorch website.)*
+ 4. **Download LoRA files:**
+     * Create a folder named `loras` in the project root.
+     * Download `night_comic_V06.safetensors` (from the Civitai link above) and place it in the `loras` folder.
+     * Download `add_detail.safetensors` (from the Civitai link above) and place it in the `loras` folder.
+ 5. **Run the Gradio app:**
+     ```bash
+     python app.py
+     ```
+ 6. Open the local URL provided (e.g., `http://127.0.0.1:7860`) in your browser. *(Note: Execution will be very slow without a suitable GPU.)*
+
+ ## Deployment to Hugging Face Spaces
+
+ This app is designed for deployment on Hugging Face Spaces, ideally with GPU hardware.
+
+ 1. Ensure all code (`*.py`), `requirements.txt`, `.gitignore`, and the `loras` folder (containing the `.safetensors` files) are committed and pushed to this GitHub repository.
+ 2. Create a new Space on Hugging Face ([huggingface.co/new-space](https://huggingface.co/new-space)).
+ 3. Choose an owner, a Space name, and select "Gradio" as the Space SDK.
+ 4. Select the desired hardware (e.g., "T4 small" under GPU options). Note that compute costs may apply.
+ 5. Choose "Use existing GitHub repository".
+ 6. Enter the URL of this GitHub repository.
+ 7. Click "Create Space". The Space will build the environment from `requirements.txt` and run `app.py`. Monitor the build and runtime logs for any issues.
+
+ ## Limitations
+
+ * **Speed:** Generation requires significant time (minutes), especially on shared/free GPU hardware, due to the multi-stage process and dynamic model loading between stages. CPU execution is impractically slow.
+ * **VRAM:** While optimized with dynamic pipeline unloading, the process still requires considerable GPU VRAM (>10GB peak). Out-of-memory errors are possible on lower-VRAM GPUs.
+ * **Fixed Style:** The artistic style (prompts, LoRAs, parameters) is fixed in the code to replicate the notebook's specific output and cannot be changed via the UI.
+
+ ## License
+
+ MIT License
app.py ADDED
@@ -0,0 +1,272 @@
+ """
+ Main application script for the Gradio interface.
+
+ This script initializes the application, loads prerequisite models via model_loader,
+ defines the user interface using Gradio Blocks, and orchestrates the multi-stage
+ image generation process by calling functions from the pipelines module.
+ """
+
+ import gradio as gr
+ import gradio.themes as gr_themes
+ import time
+ import os
+ import random
+
+ # --- Imports from our custom modules ---
+ try:
+     from image_utils import prepare_image
+     from model_loader import load_models, are_models_loaded
+     from pipelines import run_pose_detection, run_low_res_generation, run_hires_tiling, cleanup_memory
+     print("Helper modules imported successfully.")
+ except ImportError as e:
+     print(f"ERROR: Failed to import required local modules: {e}")
+     print("Please ensure prompts.py, image_utils.py, model_loader.py, and pipelines.py are in the same directory.")
+     raise SystemExit(f"Module import failed: {e}")
+
+ # --- Constants & UI Configuration ---
+ DEFAULT_SEED = 1024
+ DEFAULT_STEPS_LOWRES = 30
+ DEFAULT_GUIDANCE_LOWRES = 8.0
+ DEFAULT_STRENGTH_LOWRES = 0.05
+ DEFAULT_CN_SCALE_LOWRES = 1.0
+
+ DEFAULT_STEPS_HIRES = 20
+ DEFAULT_GUIDANCE_HIRES = 8.0
+ DEFAULT_STRENGTH_HIRES = 0.75
+ DEFAULT_CN_SCALE_HIRES = 1.0
+
+ # OUTPUT_DIR = "outputs"
+ # os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+ # --- Load Prerequisite Models at Startup ---
+ if not are_models_loaded():
+     print("Initial model loading required...")
+     load_successful = load_models()
+     if not load_successful:
+         print("FATAL: Failed to load prerequisite models. The application may not work correctly.")
+ else:
+     print("Models were already loaded.")
+
+
+ # --- Main Processing Function ---
+ def generate_full_pipeline(
+     input_image_path,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     """
+     Orchestrates the entire image generation workflow.
+
+     This function is called when the user clicks the 'Generate' button in the UI.
+     It takes the uploaded image, calls the necessary processing steps in sequence
+     (prepare, detect pose, low-res gen, hi-res gen), updates the progress bar,
+     and returns the final generated image. All generation parameters (seed, steps,
+     guidance, strength, ControlNet scales) are fixed module-level constants.
+
+     Args:
+         input_image_path (str): Path to the uploaded input image file.
+         progress (gr.Progress): Gradio progress tracking object.
+
+     Returns:
+         PIL.Image.Image | None: The final generated high-resolution image,
+                                 or the low-resolution image as a fallback if
+                                 tiling fails, or None if critical errors occur early.
+
+     Raises:
+         gr.Error: If critical steps like image preparation or pose detection fail.
+             If hi-res tiling fails after a successful low-res pass, a gr.Warning
+             is shown instead and the low-res image is returned.
+     """
+     print("\n--- Starting New Generation Run ---")
+     run_start_time = time.time()
+
+     current_seed = DEFAULT_SEED
+     if current_seed == -1:
+         current_seed = random.randint(0, 9999999)
+         print(f"Using Random Seed: {current_seed}")
+     else:
+         print(f"Using Fixed Seed: {current_seed}")
+
+     low_res_image = None
+     final_image = None
+
+     try:
+         progress(0.05, desc="Preparing Input Image...")
+         resized_input_image = prepare_image(input_image_path, target_size=512)
+         if resized_input_image is None:
+             raise gr.Error("Failed to load or prepare the input image. Check format/corruption.")
+
+         progress(0.15, desc="Detecting Pose...")
+         pose_map = run_pose_detection(resized_input_image)
+         if pose_map is None:
+             raise gr.Error("Failed to detect pose from the input image.")
+         # try: pose_map.save(os.path.join(OUTPUT_DIR, f"pose_map_{current_seed}.png"))
+         # except Exception as save_e: print(f"Warning: Could not save pose map: {save_e}")
+
+         progress(0.25, desc="Starting Low-Res Generation...")
+         low_res_image = run_low_res_generation(
+             resized_input_image=resized_input_image,
+             pose_map=pose_map,
+             seed=int(current_seed),
+             steps=int(DEFAULT_STEPS_LOWRES),
+             guidance_scale=float(DEFAULT_GUIDANCE_LOWRES),
+             strength=float(DEFAULT_STRENGTH_LOWRES),
+             controlnet_scale=float(DEFAULT_CN_SCALE_LOWRES),
+             progress=progress
+         )
+         print("Low-res generation stage completed successfully.")
+         # try: low_res_image.save(os.path.join(OUTPUT_DIR, f"lowres_output_{current_seed}.png"))
+         # except Exception as save_e: print(f"Warning: Could not save low-res image: {save_e}")
+         progress(0.45, desc="Low-Res Generation Complete.")
+
+         progress(0.50, desc="Starting Hi-Res Tiling...")
+         final_image = run_hires_tiling(
+             low_res_image=low_res_image,
+             seed=int(current_seed),
+             steps=int(DEFAULT_STEPS_HIRES),
+             guidance_scale=float(DEFAULT_GUIDANCE_HIRES),
+             strength=float(DEFAULT_STRENGTH_HIRES),
+             controlnet_scale=float(DEFAULT_CN_SCALE_HIRES),
+             upscale_factor=2,
+             tile_size=1024,
+             tile_stride=1024,
+             progress=progress
+         )
+         print("Hi-res tiling stage completed successfully.")
+         # try: final_image.save(os.path.join(OUTPUT_DIR, f"hires_output_{current_seed}.png"))
+         # except Exception as save_e: print(f"Warning: Could not save final image: {save_e}")
+
+         progress(1.0, desc="Complete!")
+
+     except gr.Error as e:
+         print(f"Gradio Error occurred: {e}")
+         if final_image is None and low_res_image is not None and ("tiling" in str(e).lower() or "hi-res" in str(e).lower()):
+             gr.Warning(f"High-resolution upscaling failed ({e}). Returning low-resolution image.")
+             final_image = low_res_image
+         else:
+             raise e
+     except Exception as e:
+         print(f"An unexpected error occurred in generate_full_pipeline: {e}")
+         import traceback
+         traceback.print_exc()
+         raise gr.Error(f"An unexpected error occurred: {e}")
+     finally:
+         print("Running final cleanup check...")
+         cleanup_memory()
+         run_end_time = time.time()
+         print(f"--- Full Pipeline Run Finished in {run_end_time - run_start_time:.2f} seconds ---")
+
+     return final_image
+
+
+ # --- Gradio Interface Definition ---
+
+ theme = gr_themes.Soft(primary_hue=gr_themes.colors.blue, secondary_hue=gr_themes.colors.sky)
+
+ # New, improved Markdown description
+ DESCRIPTION = """
+ <div style="text-align: center;">
+     <h1 style="font-family: Impact, Charcoal, sans-serif; font-size: 280%; font-weight: 900; margin-bottom: 16px;">
+         Pose-Preserving Comicfier
+     </h1>
+     <p style="margin-bottom: 12px; font-size: 94%">
+         Transform your photos into the gritty style of a 1940s Western comic! This app uses Stable Diffusion + ControlNet
+         to apply the artistic look while keeping the original pose intact. Just upload your image and click Generate!
+     </p>
+     <p style="font-size: 85%;"><em>(Generation can take several minutes on shared hardware. Prompts & parameters are fixed.)</em></p>
+     <p style="font-size: 80%; color: grey;">
+         <a href="https://github.com/mehran-khani" target="_blank">[View Project on GitHub]</a> |
+         <a href="https://huggingface.co/spaces/.../discussions" target="_blank">[Report an Issue]</a>
+     </p>
+     <!-- Remember to replace placeholders above with your actual links -->
+ </div>
+ """
+
+ EXAMPLE_IMAGES_DIR = "examples"
+ EXAMPLE_IMAGES = [
+     os.path.join(EXAMPLE_IMAGES_DIR, "example1.jpg"),
+     os.path.join(EXAMPLE_IMAGES_DIR, "example2.jpg"),
+     os.path.join(EXAMPLE_IMAGES_DIR, "example3.jpg"),
+     os.path.join(EXAMPLE_IMAGES_DIR, "example4.jpg"),
+     os.path.join(EXAMPLE_IMAGES_DIR, "example5.jpg"),
+     os.path.join(EXAMPLE_IMAGES_DIR, "example6.jpg"),
+ ]
+ EXAMPLE_IMAGES = [img for img in EXAMPLE_IMAGES if os.path.exists(img)]
+
+ CUSTOM_CSS = """
+ /* Target the container div Gradio uses for the Image component */
+ .gradio-image {
+     width: 100%;       /* Ensure the container fills the column width */
+     height: 100%;      /* Ensure the container fills the height set by the component (e.g., height=400) */
+     overflow: hidden;  /* Hide any potential overflow before object-fit applies */
+ }
+
+ /* Target the actual <img> tag inside the container */
+ .gradio-image img {
+     display: block;    /* Remove potential bottom spacing */
+     width: 100%;       /* Force image width to match container */
+     height: 100%;      /* Force image height to match container */
+     object-fit: cover; /* Scale/crop image to cover this forced W/H */
+ }
+
+ footer { visibility: hidden }
+ """
+
+ with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Pose-Preserving Comicfier") as demo:
+     gr.HTML(DESCRIPTION)
+
+     with gr.Row():
+         # Input Column
+         with gr.Column(scale=1, min_width=350):
+             # REMOVED height=400
+             input_image = gr.Image(
+                 type="filepath",
+                 label="Upload Your Image Here"
+             )
+             generate_button = gr.Button("Generate Comic Image", variant="primary")
+
+         # Output Column
+         with gr.Column(scale=1, min_width=350):
+             # REMOVED height=400
+             output_image = gr.Image(
+                 type="pil",
+                 label="Generated Comic Image",
+                 interactive=False
+             )
+
+     # Examples Section
+     if EXAMPLE_IMAGES:
+         gr.Examples(
+             examples=EXAMPLE_IMAGES,
+             inputs=[input_image],
+             outputs=[output_image],
+             fn=generate_full_pipeline,
+             cache_examples=False
+         )
+
+     generate_button.click(
+         fn=generate_full_pipeline,
+         inputs=[input_image],
+         outputs=[output_image],
+         api_name="generate"
+     )
+
+
+ # --- Launch the Gradio App ---
+ if __name__ == "__main__":
+     if not are_models_loaded():
+         print("Attempting to load models before launch...")
+         if not load_models():
+             print("FATAL: Model loading failed on launch. App may not function.")
+
+     print("Attempting to launch Gradio demo...")
+     demo.queue().launch(debug=False, share=False)
+     print("Gradio app launched. Access it at the URL provided above.")
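Because the click handler above registers `api_name="generate"`, the Space can also be driven programmatically. A minimal client sketch (an illustration, not part of the commit; it assumes the Space id `Mer-o/Pose-Preserving-Comicfier` from the README badge and a `gradio_client` 1.x install):

```python
# Hypothetical remote-call sketch; requires `pip install gradio_client`.
from gradio_client import Client, handle_file

client = Client("Mer-o/Pose-Preserving-Comicfier")  # Space id assumed from README

# One positional input (the image filepath) mirrors inputs=[input_image];
# api_name matches the api_name="generate" registered on the click event.
result = client.predict(handle_file("my_photo.jpg"), api_name="/generate")
print(result)  # path to the downloaded comic image
```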
examples/example1.jpg ADDED

Git LFS Details

  • SHA256: 66c374f373be7ea050820ba736dde2cb6844e627d1095af8c68155388c5bc3ba
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
examples/example2.jpg ADDED

Git LFS Details

  • SHA256: f68db724f0bc96d82525d6948be57d5ad2d43fd444ff2b66d36e7c1bbaea2443
  • Pointer size: 132 Bytes
  • Size of remote file: 3.19 MB
examples/example3.jpg ADDED

Git LFS Details

  • SHA256: e3a212b1b7f0de6044731814c7fd02a8e77aecc355f155a3e5002d07c64db726
  • Pointer size: 131 Bytes
  • Size of remote file: 803 kB
examples/example4.jpg ADDED

Git LFS Details

  • SHA256: 7bbf81b4d1b4eb69b78a969a7364ebd280d4e918c27d2fb327e5624994e9f0f5
  • Pointer size: 132 Bytes
  • Size of remote file: 2.48 MB
examples/example5.jpg ADDED

Git LFS Details

  • SHA256: 2e5783394cf58ce5f5725b54701e0ae90ed6595e9ae4bb6b53fd7cb08666885f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.85 MB
examples/example6.jpg ADDED

Git LFS Details

  • SHA256: 6c663c8c65e8568be4241bc269204910a5c7c09878bcdb6872a523e2d6045889
  • Pointer size: 131 Bytes
  • Size of remote file: 480 kB
image_utils.py ADDED
@@ -0,0 +1,134 @@
+ """
+ Contains utility functions for image loading, preparation, and manipulation.
+ Includes HEIC image format support via the optional 'pillow-heif' library.
+ """
+
+ from PIL import Image, ImageOps, ImageDraw
+ import os
+
+ try:
+     from pillow_heif import register_heif_opener
+     register_heif_opener()
+     print("HEIC opener registered successfully using pillow-heif.")
+     _heic_support = True
+ except ImportError:
+     print("Warning: pillow-heif not installed. HEIC/HEIF support will be disabled.")
+     _heic_support = False
+
+
+ print("Loading Image Utils...")
+
+ def prepare_image(image_filepath, target_size=512):
+     """
+     Prepares an input image file for the diffusion pipeline.
+
+     Loads an image from the given filepath (supports standard formats like
+     JPG, PNG, WEBP, and HEIC/HEIF), ensures it's in RGB format, handles EXIF
+     orientation, and performs a forced resize to a square target_size,
+     ignoring the original aspect ratio.
+
+     Args:
+         image_filepath (str): The path to the image file.
+         target_size (int): The target dimension for both width and height.
+
+     Returns:
+         PIL.Image.Image | None: The prepared image as a PIL Image object in RGB format,
+                                 or None if loading or processing fails.
+     """
+     if image_filepath is None:
+         print("Warning: prepare_image received None filepath.")
+         return None
+
+     if not isinstance(image_filepath, str) or not os.path.exists(image_filepath):
+         print(f"Error: Invalid filepath provided to prepare_image: {image_filepath}")
+         if isinstance(image_filepath, Image.Image):
+             print("Warning: Received PIL Image instead of filepath, proceeding...")
+             image = image_filepath
+         else:
+             return None
+     else:
+         # --- Load Image from Filepath ---
+         print(f"Loading image from path: {image_filepath}")
+         try:
+             image = Image.open(image_filepath)
+         except ImportError as e:
+             print(f"ImportError during Image.open: {e}. Is pillow-heif installed?")
+             print("Cannot process image format.")
+             return None
+         except Exception as e:
+             print(f"Error opening image file {image_filepath} with PIL: {e}")
+             return None
+
+     # --- Process PIL Image ---
+     try:
+         image = ImageOps.exif_transpose(image)
+
+         image = image.convert("RGB")
+
+         original_width, original_height = image.size
+
+         final_width = target_size
+         final_height = target_size
+
+         resized_image = image.resize((final_width, final_height), Image.LANCZOS)
+
+         print(f"Original size: ({original_width}, {original_height}), FORCED Resized to: ({final_width}, {final_height})")
+         return resized_image
+     except Exception as e:
+         print(f"Error during PIL image processing steps: {e}")
+         return None
+
+ def create_blend_mask(tile_size=1024, overlap=256):
+     """
+     Creates a feathered blending mask (alpha mask) for smooth tile stitching.
+
+     Generates a square mask where the edges have a linear gradient ramp within
+     the specified overlap zone, and the central area is fully opaque.
+     Assumes overlap occurs equally on all four sides.
+
+     Args:
+         tile_size (int): The dimension (width and height) of the tiles being processed.
+         overlap (int): The number of pixels that overlap between adjacent tiles.
+
+     Returns:
+         PIL.Image.Image: The blending mask as a PIL Image object in 'L' (grayscale) mode.
+                          White (255) areas are fully opaque, black (0) are transparent,
+                          gray values provide blending.
+     """
+     if overlap >= tile_size // 2:
+         print("Warning: Overlap is large relative to tile size, mask generation might be suboptimal.")
+         overlap = tile_size // 2 - 1
+
+     mask = Image.new("L", (tile_size, tile_size), 0)
+     draw = ImageDraw.Draw(mask)
+
+     if overlap > 0:
+         for i in range(overlap):
+             alpha = int(255 * (i / float(overlap)))
+
+             # Left edge ramp
+             draw.line([(i, 0), (i, tile_size)], fill=alpha)
+             # Right edge ramp
+             draw.line([(tile_size - 1 - i, 0), (tile_size - 1 - i, tile_size)], fill=alpha)
+             # Top edge ramp
+             draw.line([(0, i), (tile_size, i)], fill=alpha)
+             # Bottom edge ramp
+             draw.line([(0, tile_size - 1 - i), (tile_size, tile_size - 1 - i)], fill=alpha)
+
+     center_start = overlap
+     center_end_x = tile_size - overlap
+     center_end_y = tile_size - overlap
+
+     if center_end_x > center_start and center_end_y > center_start:
+         draw.rectangle((center_start, center_start, center_end_x - 1, center_end_y - 1), fill=255)
+     else:
+         center_x, center_y = tile_size // 2, tile_size // 2
+         draw.point((center_x, center_y), fill=255)
+         if tile_size % 2 == 0:
+             draw.point((center_x - 1, center_y), fill=255)
+             draw.point((center_x, center_y - 1), fill=255)
+             draw.point((center_x - 1, center_y - 1), fill=255)
+
+     print(f"Blend mask created (Size: {tile_size}x{tile_size}, Overlap: {overlap})")
+     return mask
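To make the blending contract concrete, here is a small standalone sketch (illustration only, with placeholder tile content) of how a mask from `create_blend_mask` feathers a new tile over already-stitched pixels, mirroring the stitching step in `pipelines.py`:

```python
# Minimal sketch (not part of the app) of mask-based tile compositing.
from PIL import Image
from image_utils import create_blend_mask

tile_size, overlap = 1024, 256
mask = create_blend_mask(tile_size, overlap)    # 'L' mode: 255 keeps the new tile

canvas = Image.new("RGB", (2048, 2048), "black")            # stitched output so far
new_tile = Image.new("RGB", (tile_size, tile_size), "red")  # placeholder tile content

box = (0, 0, tile_size, tile_size)                   # region this tile covers
existing = canvas.crop(box)                          # pixels from earlier tiles
blended = Image.composite(new_tile, existing, mask)  # white->tile, black->canvas
canvas.paste(blended, (0, 0))
```

`Image.composite` takes the new tile where the mask is white, keeps the canvas where it is black, and mixes the two across the gray ramp, which is what hides the tile seams.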
loras/add_detail.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:47aaaf0d2945ca937151d61304946dd229b3f072140b85484bc93e38f2a6e2f7
+ size 37861176
loras/night_comic_V06.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:caf9280080bf4183a9064547a9374bc22575248867a9bcc54a883305b57a8ebb
+ size 14153788
model_loader.py ADDED
@@ -0,0 +1,133 @@
+ """
+ Handles the loading and management of necessary AI models from Hugging Face Hub.
+
+ Provides functions to load models once at startup and access them throughout
+ the application, managing device placement (CPU/GPU) and data types.
+ Optimized for typical Hugging Face Space GPU environments.
+ """
+
+ import torch
+ from diffusers import ControlNetModel
+ from controlnet_aux import OpenposeDetector
+ import gc
+
+ # --- Configuration ---
+ # Automatically detect CUDA availability and set appropriate device/dtype
+ if torch.cuda.is_available():
+     DEVICE = "cuda"
+     DTYPE = torch.float16
+     print(f"CUDA available. Using Device: {DEVICE}, Dtype: {DTYPE}")
+     try:
+         print(f"GPU Name: {torch.cuda.get_device_name(0)}")
+     except Exception as e:
+         print(f"Couldn't get GPU name: {e}")
+ else:
+     DEVICE = "cpu"
+     DTYPE = torch.float32
+     print(f"CUDA not available. Using Device: {DEVICE}, Dtype: {DTYPE}")
+
+
+ # Model IDs from Hugging Face Hub
+ # BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"  # Base SD model ID needed by pipelines
+ OPENPOSE_DETECTOR_ID = 'lllyasviel/ControlNet'  # Preprocessor model repo
+ CONTROLNET_POSE_MODEL_ID = "lllyasviel/sd-controlnet-openpose"  # OpenPose ControlNet weights
+ CONTROLNET_TILE_MODEL_ID = "lllyasviel/control_v11f1e_sd15_tile"  # Tile ControlNet weights
+
+ _openpose_detector = None
+ _controlnet_pose = None
+ _controlnet_tile = None
+ _models_loaded = False
+
+ # --- Loading Function ---
+
+ def load_models(force_reload=False):
+     """
+     Loads the OpenPose detector (to CPU) and ControlNet models (to configured DEVICE).
+
+     This function should typically be called once when the application starts.
+     It checks if models are already loaded to prevent redundant loading unless
+     `force_reload` is True.
+
+     Args:
+         force_reload (bool): If True, forces reloading even if models are already loaded.
+
+     Returns:
+         bool: True if all models were loaded successfully (or already were), False otherwise.
+     """
+     global _openpose_detector, _controlnet_pose, _controlnet_tile, _models_loaded
+
+     if _models_loaded and not force_reload:
+         print("Models already loaded.")
+         return True
+
+     print("--- Loading Models ---")
+     if DEVICE == "cuda":
+         print("Performing initial CUDA cache clear...")
+         gc.collect()
+         torch.cuda.empty_cache()
+
+     # 1. OpenPose Detector
+     try:
+         print(f"Loading OpenPose Detector from {OPENPOSE_DETECTOR_ID} to CPU...")
+         _openpose_detector = OpenposeDetector.from_pretrained(OPENPOSE_DETECTOR_ID)
+         print("OpenPose detector loaded successfully (on CPU).")
+     except Exception as e:
+         print(f"ERROR: Failed to load OpenPose Detector: {e}")
+         _models_loaded = False
+         return False
+
+     # 2. ControlNet Models
+     try:
+         print(f"Loading ControlNet Pose Model from {CONTROLNET_POSE_MODEL_ID} to {DEVICE} ({DTYPE})...")
+         _controlnet_pose = ControlNetModel.from_pretrained(
+             CONTROLNET_POSE_MODEL_ID, torch_dtype=DTYPE
+         )
+         _controlnet_pose.to(DEVICE)
+         print("ControlNet Pose model loaded successfully.")
+     except Exception as e:
+         print(f"ERROR: Failed to load ControlNet Pose Model: {e}")
+         _models_loaded = False
+         return False
+
+     try:
+         print(f"Loading ControlNet Tile Model from {CONTROLNET_TILE_MODEL_ID} to {DEVICE} ({DTYPE})...")
+         _controlnet_tile = ControlNetModel.from_pretrained(
+             CONTROLNET_TILE_MODEL_ID, torch_dtype=DTYPE
+         )
+         _controlnet_tile.to(DEVICE)
+         print("ControlNet Tile model loaded successfully.")
+     except Exception as e:
+         print(f"ERROR: Failed to load ControlNet Tile Model: {e}")
+         _models_loaded = False
+         return False
+
+     _models_loaded = True
+     print("--- All prerequisite models loaded successfully. ---")
+     if DEVICE == "cuda":
+         print("Performing post-load CUDA cache clear...")
+         gc.collect()
+         torch.cuda.empty_cache()
+     return True
+
+ # --- Getter Functions ---
+
+ def get_openpose_detector():
+     if not _models_loaded: load_models()
+     return _openpose_detector
+
+ def get_controlnet_pose():
+     if not _models_loaded: load_models()
+     return _controlnet_pose
+
+ def get_controlnet_tile():
+     if not _models_loaded: load_models()
+     return _controlnet_tile
+
+ def get_device():
+     return DEVICE
+
+ def get_dtype():
+     return DTYPE
+
+ def are_models_loaded():
+     return _models_loaded
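The getters double as a lazy-loading entry point: any call triggers `load_models()` if startup loading was skipped. A minimal usage sketch (illustrative; this is how `pipelines.py` below consumes the module):

```python
# Illustrative only: lazy access to the shared models.
from model_loader import get_controlnet_pose, get_device, get_dtype

controlnet = get_controlnet_pose()  # loads everything on first use if needed
print(f"Pose ControlNet ready on {get_device()} ({get_dtype()})")
```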
pipelines.py ADDED
@@ -0,0 +1,433 @@
+ """
+ Contains functions to execute the main image generation stages:
+ 1. OpenPose Detection: Extracts pose information.
+ 2. Low-Resolution Generation: Creates initial image using Pose ControlNet.
+ 3. High-Resolution Tiling: Upscales the low-res image using Tile ControlNet.
+
+ Manages dynamic loading/unloading of diffusion pipelines to conserve VRAM.
+ """
+
+ import torch
+ import gc
+ import time
+ import os
+ from PIL import Image
+ from tqdm.auto import tqdm
+ import gradio as gr
+ from diffusers import (
+     StableDiffusionControlNetImg2ImgPipeline,
+     UniPCMultistepScheduler,
+ )
+ from model_loader import (
+     get_openpose_detector,
+     get_controlnet_pose,
+     get_controlnet_tile,
+     get_device,
+     get_dtype,
+     are_models_loaded,
+ )
+ from image_utils import create_blend_mask
+ from prompts import get_prompts_for_run
+
+
+ # --- Configuration ---
+ BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
+ LORA_DIR = "loras"
+ LORA_FILES = {
+     "style": os.path.join(LORA_DIR, "night_comic_V06.safetensors"),
+     "detail": os.path.join(LORA_DIR, "add_detail.safetensors"),
+ }
+ LORA_WEIGHTS_LOWRES = [1, 1]
+ LORA_WEIGHTS_HIRES = [1, 2]
+ ACTIVE_ADAPTERS = ["style", "detail"]
+
+ def cleanup_memory():
+     """Forces garbage collection and clears CUDA cache."""
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+ # --- Stage 1: OpenPose Detection ---
+ def run_pose_detection(resized_input_image):
+     """
+     Detects human pose (body, hands, face) from the input image using OpenPose.
+
+     Temporarily moves the OpenPose detector model to the active GPU (if available)
+     for processing and then moves it back to the CPU to conserve VRAM.
+
+     Args:
+         resized_input_image (PIL.Image.Image): The input image, already resized
+                                                and in RGB format.
+
+     Returns:
+         PIL.Image.Image | None: A PIL Image representing the detected pose map,
+                                 or None if detection fails or models aren't loaded.
+     """
+     if not are_models_loaded():
+         print("Error: Cannot run pose detection, models not loaded.")
+         return None
+
+     detector = get_openpose_detector()
+     device = get_device()
+     control_image_openpose = None
+
+     if detector is None:
+         print("Error: OpenPose detector is None.")
+         return None
+
+     try:
+         detector.to(device)
+         cleanup_memory()
+
+         control_image_openpose = detector(
+             resized_input_image, include_face=True, include_hand=True
+         )
+
+     except Exception as e:
+         print(f"ERROR during OpenPose detection: {e}")
+         control_image_openpose = None
+     finally:
+         detector.to("cpu")
+         cleanup_memory()
+
+     return control_image_openpose
+
+ # --- Stage 2: Low-Resolution Generation ---
+ def run_low_res_generation(
+     resized_input_image,
+     pose_map,
+     seed,
+     steps,
+     guidance_scale,
+     strength,
+     controlnet_scale=0.8,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     """
+     Generates the initial low-resolution image using Img2Img with Pose ControlNet.
+
+     Dynamically loads the StableDiffusionControlNetImg2ImgPipeline, applies LoRAs,
+     runs inference, and then unloads the pipeline to free VRAM before returning.
+
+     Args:
+         resized_input_image (PIL.Image.Image): The resized input image.
+         pose_map (PIL.Image.Image): The pose map generated by run_pose_detection.
+         seed (int): The random seed for generation.
+         steps (int): Number of diffusion inference steps.
+         guidance_scale (float): Classifier-free guidance scale.
+         strength (float): Img2Img strength (0.0 to 1.0). How much noise to add.
+         controlnet_scale (float): Conditioning scale for the Pose ControlNet.
+         progress (gr.Progress): Gradio progress object for UI updates.
+
+     Returns:
+         PIL.Image.Image | None: The generated low-resolution PIL Image, or None if an error occurs.
+
+     Raises:
+         gr.Error: Raises a Gradio error if generation fails catastrophically.
+     """
+     if not are_models_loaded() or pose_map is None:
+         error_msg = "Cannot run low-res generation: "
+         if not are_models_loaded(): error_msg += "Models not loaded. "
+         if pose_map is None: error_msg += "Pose map is missing."
+         print(f"Error: {error_msg}")
+         return None
+
+     device = get_device()
+     dtype = get_dtype()
+     controlnet_pose = get_controlnet_pose()
+     output_image_low_res = None
+     pipe_lowres = None
+
+     positive_prompt, negative_prompt, _, _ = get_prompts_for_run()
+     generator = torch.Generator(device=device).manual_seed(int(seed))
+
+     progress(0, desc="Loading Low-Res Pipeline...")
+     try:
+         # 1. Load Pipeline
+         pipe_lowres = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+             BASE_MODEL_ID,
+             controlnet=controlnet_pose,
+             torch_dtype=dtype,
+             safety_checker=None
+         )
+         pipe_lowres.scheduler = UniPCMultistepScheduler.from_config(pipe_lowres.scheduler.config)
+         pipe_lowres.to(device)
+
+         cleanup_memory()
+
+         # 2. Load LoRAs
+         if os.path.exists(LORA_FILES["style"]) and os.path.exists(LORA_FILES["detail"]):
+             pipe_lowres.load_lora_weights(LORA_FILES["style"], adapter_name="style")
+             pipe_lowres.load_lora_weights(LORA_FILES["detail"], adapter_name="detail")
+             pipe_lowres.set_adapters(ACTIVE_ADAPTERS, adapter_weights=LORA_WEIGHTS_LOWRES)
+             print(f"Activated LoRAs: {ACTIVE_ADAPTERS} with weights {LORA_WEIGHTS_LOWRES}")
+         else:
+             print("Warning: One or both LoRA files not found. Skipping LoRA loading.")
+             raise gr.Error("Required LoRA files not found in loras/ directory.")
+
+         # 3. Run Inference
+         progress(0.3, desc="Generating Low-Res Image...")
+         output_image_low_res = pipe_lowres(
+             prompt=positive_prompt,
+             negative_prompt=negative_prompt,
+             image=resized_input_image,
+             control_image=pose_map,
+             num_inference_steps=int(steps),
+             strength=strength,
+             guidance_scale=guidance_scale,
+             controlnet_conditioning_scale=float(controlnet_scale),
+             generator=generator,
+         ).images[0]
+         progress(0.9, desc="Low-Res Complete")
+
+     except Exception as e:
+         print(f"ERROR during Low-Res Generation Pipeline: {e}")
+         import traceback
+         traceback.print_exc()
+         output_image_low_res = None
+         raise gr.Error(f"Failed during low-res generation: {e}")
+     finally:
+         # 4. Cleanup Pipeline
+         print("Cleaning up Low-Res pipeline...")
+         if pipe_lowres is not None:
+             try:
+                 if hasattr(pipe_lowres, 'get_active_adapters') and pipe_lowres.get_active_adapters():
+                     print("Unloading LoRAs...")
+                     pipe_lowres.unload_lora_weights()
+             except Exception as unload_e:
+                 print(f"Note: Error unloading LoRAs: {unload_e}")
+
+             print("Moving Low-Res pipe components to CPU before deleting...")
+             try: pipe_lowres.to('cpu')
+             except Exception as cpu_e: print(f"Note: Error moving pipe to CPU: {cpu_e}")
+
+             print("Deleting Low-Res pipeline object...")
+             del pipe_lowres
+             pipe_lowres = None
+
+         print("Running garbage collection and emptying CUDA cache after Low-Res...")
+         cleanup_memory()
+         # time.sleep(1)
+
+     print("--- Low-Res Generation Stage Finished ---")
+     return output_image_low_res
+
+ # --- Stage 3: High-Resolution Tiling Upscaling ---
+ def run_hires_tiling(
+     low_res_image,
+     seed,
+     steps,
+     guidance_scale,
+     strength,
+     controlnet_scale=1.0,
+     upscale_factor=2,
+     tile_size=1024,
+     tile_stride=1024,
+     progress=gr.Progress(track_tqdm=True)
+ ):
+     """
+     Upscales the low-resolution image using tiling with the Tile ControlNet.
+
+     Dynamically loads the StableDiffusionControlNetImg2ImgPipeline for tiling,
+     applies LoRAs, processes the image in overlapping tiles, blends the results,
+     and unloads the pipeline to free VRAM.
+
+     Args:
+         low_res_image (PIL.Image.Image): The low-resolution image from the previous stage.
+         seed (int): The random seed (should ideally match the low-res stage seed).
+         steps (int): Number of diffusion inference steps per tile.
+         guidance_scale (float): Classifier-free guidance scale for tiles.
+         strength (float): Img2Img strength for tiling (controls detail vs. original).
+         controlnet_scale (float): Conditioning scale for the Tile ControlNet.
+         upscale_factor (int): Factor by which to increase the image resolution.
+         tile_size (int): Size of the square tiles to process.
+         tile_stride (int): Step size between tiles. Overlap = tile_size - tile_stride.
+         progress (gr.Progress): Gradio progress object for UI updates.
+
+     Returns:
+         PIL.Image.Image | None: The generated high-resolution PIL Image, or None if an error occurs.
+
+     Raises:
+         gr.Error: Raises a Gradio error if tiling fails catastrophically.
+     """
+     if not are_models_loaded() or low_res_image is None:
+         error_msg = "Cannot run hi-res tiling: "
+         if not are_models_loaded(): error_msg += "Models not loaded. "
+         if low_res_image is None: error_msg += "Low-res image is missing."
+         print(f"Error: {error_msg}")
+         return None
+
+     device = get_device()
+     dtype = get_dtype()
+     controlnet_tile = get_controlnet_tile()
+     high_res_output_image = None
+     pipe_hires = None
+
+     _, _, positive_prompt_tile, negative_prompt_tile = get_prompts_for_run()
+
+     generator_tile = torch.Generator(device=device).manual_seed(int(seed))
+
+     print("\n--- Starting Hi-Res Tiling Stage ---")
+     progress(0, desc="Preparing for Tiling...")
+
+     try:
+         # --- Setup Tiling Parameters ---
+         target_width = low_res_image.width * upscale_factor
+         target_height = low_res_image.height * upscale_factor
+         if tile_size > min(target_width, target_height):
+             print(f"Warning: Tile size ({tile_size}) > target dimension ({target_width}x{target_height}). Clamping tile size.")
+             tile_size = min(target_width, target_height)
+             tile_stride = tile_size
+
+         overlap = tile_size - tile_stride
+         if overlap < 0:
+             print("Warning: Tile stride is larger than tile size. Setting stride = tile size.")
+             tile_stride = tile_size
+             overlap = 0
+
+         print(f"Target Res: {target_width}x{target_height}, Tile Size: {tile_size}, Stride: {tile_stride}, Overlap: {overlap}")
+
+         # 1. Load Pipeline
+         print(f"Loading Hi-Res Pipeline ({BASE_MODEL_ID} + Tile ControlNet)...")
+         progress(0.05, desc="Loading Hi-Res Pipeline...")
+         pipe_hires = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+             BASE_MODEL_ID,
+             controlnet=controlnet_tile,
+             torch_dtype=dtype,
+             safety_checker=None,
+         )
+         pipe_hires.scheduler = UniPCMultistepScheduler.from_config(pipe_hires.scheduler.config)
+         pipe_hires.to(device)
+         # pipe_hires.enable_model_cpu_offload()
+         # pipe_hires.enable_xformers_memory_efficient_attention()
+         print("Hi-Res Pipeline loaded to GPU.")
+         cleanup_memory()
+
+         # 2. Load LoRAs
+         print("Loading LoRAs for Hi-Res pipe...")
+         if os.path.exists(LORA_FILES["style"]) and os.path.exists(LORA_FILES["detail"]):
+             pipe_hires.load_lora_weights(LORA_FILES["style"], adapter_name="style")
+             pipe_hires.load_lora_weights(LORA_FILES["detail"], adapter_name="detail")
+             pipe_hires.set_adapters(ACTIVE_ADAPTERS, adapter_weights=LORA_WEIGHTS_HIRES)
+             print(f"Activated LoRAs: {ACTIVE_ADAPTERS} with weights {LORA_WEIGHTS_HIRES}")
+         else:
+             print("Warning: One or both LoRA files not found. Skipping LoRA loading.")
+             raise gr.Error("Required LoRA files not found in loras/ directory.")
+
+         # --- Prepare for Tiling Loop ---
+         print(f"Creating blurry base image ({target_width}x{target_height})...")
+         progress(0.15, desc="Preparing Base Image...")
+         blurry_high_res = low_res_image.resize((target_width, target_height), Image.LANCZOS)
+
+         final_image = Image.new("RGB", (target_width, target_height))
+         blend_mask = create_blend_mask(tile_size, overlap)
+
+         num_tiles_x = (target_width + tile_stride - 1) // tile_stride
+         num_tiles_y = (target_height + tile_stride - 1) // tile_stride
+         total_tiles = num_tiles_x * num_tiles_y
+         print(f"Processing {num_tiles_x}x{num_tiles_y} = {total_tiles} tiles...")
+
+         # --- Tiling Loop ---
+         progress(0.2, desc=f"Processing Tiles (0/{total_tiles})")
+         processed_tile_count = 0
+         with tqdm(total=total_tiles, desc="Tiling Upscale") as pbar:
+             for y in range(num_tiles_y):
+                 for x in range(num_tiles_x):
+                     tile_start_time = time.time()
+                     pbar.set_description(f"Tiling Upscale (Tile {processed_tile_count+1}/{total_tiles})")
+
+                     x_start = x * tile_stride
+                     y_start = y * tile_stride
+                     x_end = min(x_start + tile_size, target_width)
+                     y_end = min(y_start + tile_size, target_height)
+                     crop_box = (x_start, y_start, x_end, y_end)
+
+                     tile_image_blurry = blurry_high_res.crop(crop_box)
+                     current_tile_width, current_tile_height = tile_image_blurry.size
+
+                     if current_tile_width < tile_size or current_tile_height < tile_size:
+                         try: edge_color = tile_image_blurry.getpixel((0, 0))
+                         except IndexError: edge_color = (127, 127, 127)
+                         padded_tile = Image.new("RGB", (tile_size, tile_size), edge_color)
+                         padded_tile.paste(tile_image_blurry, (0, 0))
+                         tile_image_blurry = padded_tile
+                         print(f"Padded edge tile at ({x},{y})")
+
+                     # 3. Run Inference on the Tile
+                     with torch.inference_mode():
+                         output_tile = pipe_hires(
+                             prompt=positive_prompt_tile,
+                             negative_prompt=negative_prompt_tile,
+                             image=tile_image_blurry,
+                             control_image=tile_image_blurry,
+                             num_inference_steps=int(steps),
+                             strength=strength,
+                             guidance_scale=guidance_scale,
+                             controlnet_conditioning_scale=float(controlnet_scale),
+                             generator=generator_tile,
+                             output_type="pil"
+                         ).images[0]
+
+                     # --- Stitch Tile Back ---
+                     paste_x = x_start
+                     paste_y = y_start
+                     crop_w = x_end - x_start
+                     crop_h = y_end - y_start
+
+                     output_tile_region = output_tile.crop((0, 0, crop_w, crop_h))
+
+                     if overlap > 0:
+                         blend_mask_region = blend_mask.crop((0, 0, crop_w, crop_h))
+                         current_content_region = final_image.crop((paste_x, paste_y, paste_x + crop_w, paste_y + crop_h))
+                         blended_tile_region = Image.composite(output_tile_region, current_content_region, blend_mask_region)
+                         final_image.paste(blended_tile_region, (paste_x, paste_y))
+                     else:
+                         final_image.paste(output_tile_region, (paste_x, paste_y))
+
+                     processed_tile_count += 1
+                     pbar.update(1)
+
+                     # Update Gradio progress
+                     gradio_progress = 0.2 + 0.75 * (processed_tile_count / total_tiles)
+                     progress(gradio_progress, desc=f"Processing Tile {processed_tile_count}/{total_tiles}")
+
+                     tile_end_time = time.time()
+                     print(f"Tile ({x},{y}) processed in {tile_end_time - tile_start_time:.2f}s")
+                     # cleanup_memory()
+
+         print("Tile processing complete.")
+         high_res_output_image = final_image
+         progress(0.95, desc="Tiling Complete")
+
+     except Exception as e:
+         print(f"ERROR during Hi-Res Tiling Pipeline: {e}")
+         import traceback
+         traceback.print_exc()
+         high_res_output_image = None
+         raise gr.Error(f"Failed during hi-res tiling: {e}")
+     finally:
+         # 4. Cleanup Pipeline
+         print("Cleaning up Hi-Res pipeline...")
+         if pipe_hires is not None:
+             try:
+                 if hasattr(pipe_hires, 'get_active_adapters') and pipe_hires.get_active_adapters():
+                     print("Unloading LoRAs...")
+                     pipe_hires.unload_lora_weights()
+             except Exception as unload_e:
+                 print(f"Note: Error unloading LoRAs: {unload_e}")
+
+             print("Moving Hi-Res pipe components to CPU before deleting...")
+             try: pipe_hires.to('cpu')
+             except Exception as cpu_e: print(f"Note: Error moving pipe to CPU: {cpu_e}")
+
+             print("Deleting Hi-Res pipeline object...")
+             del pipe_hires
+             pipe_hires = None
+
+         print("Running garbage collection and emptying CUDA cache after Hi-Res...")
+         cleanup_memory()
+
+     print("--- Hi-Res Tiling Stage Finished ---")
+     return high_res_output_image
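A quick sanity check on the tiling arithmetic with the defaults `app.py` passes (`upscale_factor=2`, `tile_size=tile_stride=1024` on a 512x512 low-res image): the loop degenerates to a single tile with zero overlap, so the blend mask is bypassed and blending only engages when `tile_stride` is set below `tile_size`:

```python
# Worked check of the tile-count/overlap math used in run_hires_tiling.
low_res_side, upscale_factor = 512, 2
tile_size = tile_stride = 1024

target = low_res_side * upscale_factor                 # 1024
num_tiles = (target + tile_stride - 1) // tile_stride  # ceil(1024/1024) = 1
overlap = tile_size - tile_stride                      # 0 -> plain paste, no blend
print(num_tiles, overlap)                              # 1 0
```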
prompts.py ADDED
@@ -0,0 +1,53 @@
+ """
+ Defines fixed prompts and provides a function to generate
+ randomized prompts for each run, mirroring the original notebook behavior.
+ Used by the main pipeline functions.
+ """
+
+ import random
+
+ BASE_PROMPT = "detailed face portrait, accurate facial features, natural features, clear eyes, keep the gender same as the input image"
+
+ STYLE_PROMPT = r"((LIMITED PALETTE)), ((RETRO COMIC)), ((1940S \(STYLE\))), ((WESTERN COMICS \(STYLE\))), ((NIGHT COMIC)), detailed illustration, sharp lines, sfw"
+
+ BASE_NEGATIVE_PROMPT = (
+     "generic face, distorted features, unrealistic face, bad anatomy, extra limbs, fused fingers, poorly drawn hands, poorly drawn face, "
+     "text, signature, watermark, letters, words, username, artist name, speech bubble, multiple panels, "
+     "ugly, disfigured, deformed, low quality, worst quality, blurry, jpeg artifacts, noisy, "
+     "weapon, gun, knife, violence, gore, blood, injury, mutilated, horrific, nsfw, nude, naked, explicit, sexual, lingerie, bikini, suggestive, provocative, disturbing, scary, offensive, illegal, unlawful"
+ )
+
+ # --- Background Generation Elements ---
+ BG_SETTINGS = [
+     "on a futuristic city street at night", "in a retro sci-fi control room", "in a dusty western saloon",
+     "in front of an abstract energy field", "in a neon-lit alleyway", "in a stark cyberpunk cityscape",
+     "with speed lines background", "in a manga panel frame", "in a dimly lit laboratory",
+     "against a dramatic explosive background", "in a cluttered artist studio", "in a dynamic action scene"
+ ]
+
+ BG_DETAILS = [
+     "detailed background", "cinematic lighting", "dramatic shadows",
+     "high contrast", "low angle shot", "dynamic composition", "atmospheric perspective", "intricate details"
+ ]
+
+ def get_prompts_for_run():
+     """
+     Generates the prompts needed for one generation cycle,
+     including a newly randomized background for the low-res stage.
+     Returns prompts suitable for the low-res and hi-res stages.
+     """
+     # --- Low-Res Prompt Generation ---
+     chosen_bg_setting = random.choice(BG_SETTINGS)
+     chosen_bg_detail = random.choice(BG_DETAILS)
+     background_prompt = f"{chosen_bg_setting}, {chosen_bg_detail}"
+     positive_prompt_lowres = f"{BASE_PROMPT}, {STYLE_PROMPT}, {background_prompt}"
+
+     # --- Tile Prompt Generation ---
+     positive_prompt_tile = f"{BASE_PROMPT}, {STYLE_PROMPT}"
+
+     negative_prompt_tile = (
+         BASE_NEGATIVE_PROMPT +
+         ", blurry face, distorted face, mangled face, bad face, low quality, blurry"
+     )
+
+     return positive_prompt_lowres, BASE_NEGATIVE_PROMPT, positive_prompt_tile, negative_prompt_tile
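For reference, the four-tuple contract that `pipelines.py` relies on (the low-res stage unpacks the first two values, the tile stage the last two); an illustrative snippet, with output varying per call because of `random.choice`:

```python
# Illustrative only: inspecting the prompt tuple this module produces.
from prompts import get_prompts_for_run

pos_lowres, neg_base, pos_tile, neg_tile = get_prompts_for_run()
print(pos_lowres)  # base + style + randomized "<setting>, <detail>" suffix
print(pos_tile)    # base + style only; no random background at the tile stage
```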
requirements.txt ADDED
@@ -0,0 +1,22 @@
+ # Base ML
+ torch==2.7.0
+ torchvision==0.22.0
+ torchaudio==2.7.0
+ accelerate==1.6.0
+
+ # Diffusers & Transformers
+ diffusers==0.33.1
+ transformers==4.51.3
+ peft==0.15.2
+
+ # ControlNet & Auxiliaries
+ controlnet_aux==0.0.9
+ mediapipe
+ matplotlib
+ opencv-python-headless==4.11.0.86
+ Pillow==11.2.1
+ pillow-heif==0.22.0
+ numpy==2.2.5
+
+ # Web UI
+ gradio==5.29.0
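Note: plain `pip install -r requirements.txt` may resolve a CPU-only `torch==2.7.0` wheel on some platforms; for a local CUDA GPU you would typically install from a CUDA-specific index instead, e.g. `pip install torch==2.7.0 torchvision==0.22.0 torchaudio==2.7.0 --index-url https://download.pytorch.org/whl/cu126` (an assumption to verify against pytorch.org for your CUDA version), as the README's local-run section also cautions.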