merging to main
Files changed:

- .gitattributes +2 -35
- .gitignore +127 -0
- README.md +111 -14
- app.py +272 -0
- examples/example1.jpg +3 -0
- examples/example2.jpg +3 -0
- examples/example3.jpg +3 -0
- examples/example4.jpg +3 -0
- examples/example5.jpg +3 -0
- examples/example6.jpg +3 -0
- image_utils.py +134 -0
- loras/add_detail.safetensors +3 -0
- loras/night_comic_V06.safetensors +3 -0
- model_loader.py +133 -0
- pipelines.py +433 -0
- prompts.py +53 -0
- requirements.txt +22 -0
.gitattributes
CHANGED
@@ -1,35 +1,2 @@
-
-
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
+loras/*.safetensors filter=lfs diff=lfs merge=lfs -text
+examples/*.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,127 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# PEP 582; used by PDM and other PEP 582 compatible tooling
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# OS generated files #
+######################
+.DS_Store
+.DS_Store?
+._*
+.Spotlight-V100
+.Trashes
+ehthumbs.db
+Thumbs.db
README.md
CHANGED
@@ -1,14 +1,111 @@
+# Pose-Preserving Comicfier - Gradio App
+
+[Live Demo on Hugging Face Spaces](https://huggingface.co/spaces/Mer-o/Pose-Preserving-Comicfier) · [Author profile](https://huggingface.co/Mer-o)
+
+This repository contains the code for a Gradio web application that transforms input images into a specific retro Western comic-book style while preserving the original pose. It uses Stable Diffusion v1.5, ControlNet (OpenPose + Tile), and specific LoRAs.
+
+This application refactors the workflow initially developed in a [Kaggle Notebook](https://github.com/mehran-khani/SD-Controlnet-Comic-Styler) into a deployable web app.
+
+## Features
+
+* **Pose Preservation:** Uses ControlNet OpenPose to accurately maintain the pose from the input image.
+* **Retro Comic Style Transfer:** Applies specific LoRAs (`night_comic_V06.safetensors` & `add_detail.safetensors`) for a 1940s Western comic aesthetic with enhanced details.
+* **Tiling Upscaling:** Implements ControlNet Tile for 2x high-resolution output (1024x1024), improving detail consistency across large images.
+* **Simplified UI:** Easy-to-use interface with only an image upload and a generate button.
+* **Fixed Parameters:** Generation uses pre-defined, optimized parameters (steps, guidance, strength, prompts) based on the original notebook implementation for consistent results.
+* **Dynamic Backgrounds:** The background elements in the generated image are randomized for variety in the low-resolution stage.
+* **Broad Image Support:** Accepts common formats like JPG, PNG, WEBP, and HEIC (requires `pillow-heif`).
+
+## Technology Stack
+
+* **Python 3**
+* **Gradio:** Web UI framework.
+* **PyTorch:** Core ML framework.
+* **Hugging Face Libraries:**
+    * `diffusers`: Stable Diffusion pipelines, ControlNet integration.
+    * `transformers`: Underlying model components.
+    * `accelerate`: Hardware acceleration utilities.
+    * `peft`: LoRA loading and management.
+* **ControlNet:**
+    * OpenPose Detector (`controlnet_aux`)
+    * OpenPose ControlNet Model (`lllyasviel/sd-controlnet-openpose`)
+    * Tile ControlNet Model (`lllyasviel/control_v11f1e_sd15_tile`)
+* **Base Model:** `runwayml/stable-diffusion-v1-5`
+* **LoRAs Used:**
+    * Style: [Western Comics Style](https://civitai.com/models/1081588/western-comics-style) (using `night_comic_V06.safetensors`)
+    * Detail: [Detail Tweaker LoRA](https://civitai.com/models/58390/detail-tweaker-lora-lora) (using `add_detail.safetensors`)
+* **Image Processing:** `Pillow`, `pillow-heif`, `numpy`, `opencv-python-headless`
+* **Dependencies:** `matplotlib`, `mediapipe` (required by `controlnet_aux`)
+
+## Workflow Overview
+
+1. **Image Preparation (`image_utils.py`):** The input image is loaded (supports HEIC), converted to RGB, EXIF orientation is handled, and the image is force-resized to 512x512.
+2. **Pose Detection (`pipelines.py`):** An OpenPose map is extracted from the resized image using `controlnet_aux`.
+3. **Low-Resolution Generation (`pipelines.py`):**
+    * An SDv1.5 Img2Img pipeline with the Pose ControlNet is dynamically loaded.
+    * Prompts are generated (`prompts.py`) with a fixed base/style and a *randomized* background element.
+    * Style and Detail LoRAs are applied.
+    * A 512x512 image is generated using fixed parameters.
+    * The pipeline is unloaded to conserve VRAM.
+4. **High-Resolution Tiling (`pipelines.py`):**
+    * The 512x512 image is upscaled 2x (to 1024x1024) using Lanczos interpolation, creating a blurry base.
+    * An SDv1.5 Img2Img pipeline with the Tile ControlNet is dynamically loaded.
+    * Tile-specific prompts (excluding the random background) are used.
+    * Style and Detail LoRAs are applied (with different weights than the low-res stage).
+    * The image is processed in 1024x1024 tiles; the loop supports overlapping tiles, though the app's default stride equals the tile size.
+    * When tiles overlap, they are blended back together using an alpha mask (`image_utils.py`).
+    * The pipeline is unloaded.
+5. **Output (`app.py`):** The final 1024x1024 image is displayed in the Gradio UI. A condensed sketch of the whole flow follows below.
+
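+Condensed, the orchestration amounts to the following sketch (simplified; the real handler in `app.py` adds error handling, progress reporting, and memory cleanup, and the helper name `comicfy` is purely illustrative):
+
+```python
+# Minimal sketch of the two-stage flow wired up in app.py
+from image_utils import prepare_image
+from pipelines import run_pose_detection, run_low_res_generation, run_hires_tiling
+
+def comicfy(image_path: str, seed: int = 1024):
+    img_512 = prepare_image(image_path, target_size=512)   # force-resize to 512x512 RGB
+    pose_map = run_pose_detection(img_512)                  # OpenPose body/hand/face map
+    low_res = run_low_res_generation(                       # Stage 1: pose-guided img2img
+        resized_input_image=img_512, pose_map=pose_map,
+        seed=seed, steps=30, guidance_scale=8.0,
+        strength=0.05, controlnet_scale=1.0,
+    )
+    return run_hires_tiling(                                # Stage 2: tile-guided 2x upscale
+        low_res_image=low_res, seed=seed, steps=20,
+        guidance_scale=8.0, strength=0.75, controlnet_scale=1.0,
+        upscale_factor=2, tile_size=1024, tile_stride=1024,
+    )
+```
+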
+## How to Run Locally
+
+*(Requires sufficient RAM/CPU or a compatible GPU, Python 3.8+, and Git)*
+
+1. **Clone the repository:**
+    ```bash
+    git clone https://github.com/mehran-khani/Pose-Preserving-Comicfier.git
+    cd Pose-Preserving-Comicfier
+    ```
+2. **Create and activate a Python virtual environment:**
+    ```bash
+    python3 -m venv .venv
+    source .venv/bin/activate
+    # Windows (PowerShell): .\.venv\Scripts\Activate.ps1
+    # Windows (cmd):        .\.venv\Scripts\activate.bat
+    ```
+3. **Install dependencies:**
+    ```bash
+    pip install -r requirements.txt
+    ```
+    *(Note: PyTorch installation might require specific commands depending on your OS/CUDA setup if using a local GPU. See the PyTorch website.)*
+4. **Download LoRA files:**
+    * Create a folder named `loras` in the project root.
+    * Download `night_comic_V06.safetensors` (from the Civitai link above) and place it in the `loras` folder.
+    * Download `add_detail.safetensors` (from the Civitai link above) and place it in the `loras` folder.
+5. **Run the Gradio app:**
+    ```bash
+    python app.py
+    ```
+6. Open the local URL provided (e.g., `http://127.0.0.1:7860`) in your browser. *(Note: Execution will be very slow without a suitable GPU.)*
+
+## Deployment to Hugging Face Spaces
+
+This app is designed for deployment on Hugging Face Spaces, ideally with GPU hardware.
+
+1. Ensure all code (`*.py`), `requirements.txt`, `.gitignore`, and the `loras` folder (containing the `.safetensors` files) are committed and pushed to this GitHub repository.
+2. Create a new Space on Hugging Face ([huggingface.co/new-space](https://huggingface.co/new-space)).
+3. Choose an owner, a Space name, and select "Gradio" as the Space SDK.
+4. Select the desired hardware (e.g., "T4 small" under GPU options). Note that compute costs may apply.
+5. Choose "Use existing GitHub repository".
+6. Enter the URL of this GitHub repository.
+7. Click "Create Space". The Space will build the environment from `requirements.txt` and run `app.py`. Monitor the build and runtime logs for any issues.
+
+## Limitations
+
+* **Speed:** Generation takes significant time (minutes), especially on shared/free GPU hardware, due to the multi-stage process and dynamic model loading between stages. CPU execution is impractically slow.
+* **VRAM:** Although optimized with dynamic pipeline unloading, the process still requires considerable GPU VRAM (>10GB peak). Out-of-memory errors are possible on lower-VRAM GPUs.
+* **Fixed Style:** The artistic style (prompts, LoRAs, parameters) is fixed in the code to replicate the notebook's specific output and cannot be changed via the UI.
+
+## License
+
+MIT License
app.py
ADDED
@@ -0,0 +1,272 @@
+"""
+Main application script for the Gradio interface.
+
+This script initializes the application, loads prerequisite models via model_loader,
+defines the user interface using Gradio Blocks, and orchestrates the multi-stage
+image generation process by calling functions from the pipelines module.
+"""
+
+import gradio as gr
+import gradio.themes as gr_themes
+import time
+import os
+import random
+
+# --- Imports from our custom modules ---
+try:
+    from image_utils import prepare_image
+    from model_loader import load_models, are_models_loaded
+    from pipelines import run_pose_detection, run_low_res_generation, run_hires_tiling, cleanup_memory
+    print("Helper modules imported successfully.")
+except ImportError as e:
+    print(f"ERROR: Failed to import required local modules: {e}")
+    print("Please ensure prompts.py, image_utils.py, model_loader.py, and pipelines.py are in the same directory.")
+    raise SystemExit(f"Module import failed: {e}")
+
+# --- Constants & UI Configuration ---
+DEFAULT_SEED = 1024  # Set to -1 to use a fresh random seed on every run
+DEFAULT_STEPS_LOWRES = 30
+DEFAULT_GUIDANCE_LOWRES = 8.0
+DEFAULT_STRENGTH_LOWRES = 0.05
+DEFAULT_CN_SCALE_LOWRES = 1.0
+
+DEFAULT_STEPS_HIRES = 20
+DEFAULT_GUIDANCE_HIRES = 8.0
+DEFAULT_STRENGTH_HIRES = 0.75
+DEFAULT_CN_SCALE_HIRES = 1.0
+
+# OUTPUT_DIR = "outputs"
+# os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+# --- Load Prerequisite Models at Startup ---
+if not are_models_loaded():
+    print("Initial model loading required...")
+    load_successful = load_models()
+    if not load_successful:
+        print("FATAL: Failed to load prerequisite models. The application may not work correctly.")
+else:
+    print("Models were already loaded.")
+
+
+# --- Main Processing Function ---
+def generate_full_pipeline(
+    input_image_path,
+    progress=gr.Progress(track_tqdm=True)
+):
+    """
+    Orchestrates the entire image generation workflow.
+
+    This function is called when the user clicks the 'Generate' button in the UI.
+    It takes the uploaded image, calls the necessary processing steps in sequence
+    (prepare, detect pose, low-res gen, hi-res gen), updates the progress bar,
+    and returns the final generated image. Generation parameters (seed, steps,
+    guidance, strength, ControlNet scales) are fixed via the module-level
+    DEFAULT_* constants above.
+
+    Args:
+        input_image_path (str): Path to the uploaded input image file.
+        progress (gr.Progress): Gradio progress tracking object.
+
+    Returns:
+        PIL.Image.Image | None: The final generated high-resolution image,
+                                or the low-resolution image as a fallback if
+                                tiling fails, or None if critical errors occur early.
+
+    Raises:
+        gr.Error: If critical steps like image preparation or pose detection fail.
+        (A gr.Warning is emitted, not raised, if hi-res tiling fails but the
+        low-res stage succeeded; the low-res image is returned instead.)
+    """
+    print(f"\n--- Starting New Generation Run ---")
+    run_start_time = time.time()
+
+    current_seed = DEFAULT_SEED
+    if current_seed == -1:
+        current_seed = random.randint(0, 9999999)
+        print(f"Using Random Seed: {current_seed}")
+    else:
+        print(f"Using Fixed Seed: {current_seed}")
+
+    low_res_image = None
+    final_image = None
+
+    try:
+        progress(0.05, desc="Preparing Input Image...")
+        resized_input_image = prepare_image(input_image_path, target_size=512)
+        if resized_input_image is None:
+            raise gr.Error("Failed to load or prepare the input image. Check format/corruption.")
+
+        progress(0.15, desc="Detecting Pose...")
+        pose_map = run_pose_detection(resized_input_image)
+        if pose_map is None:
+            raise gr.Error("Failed to detect pose from the input image.")
+        # try: pose_map.save(os.path.join(OUTPUT_DIR, f"pose_map_{current_seed}.png"))
+        # except Exception as save_e: print(f"Warning: Could not save pose map: {save_e}")
+
+        progress(0.25, desc="Starting Low-Res Generation...")
+        low_res_image = run_low_res_generation(
+            resized_input_image=resized_input_image,
+            pose_map=pose_map,
+            seed=int(current_seed),
+            steps=int(DEFAULT_STEPS_LOWRES),
+            guidance_scale=float(DEFAULT_GUIDANCE_LOWRES),
+            strength=float(DEFAULT_STRENGTH_LOWRES),
+            controlnet_scale=float(DEFAULT_CN_SCALE_LOWRES),
+            progress=progress
+        )
+        print("Low-res generation stage completed successfully.")
+        # try: low_res_image.save(os.path.join(OUTPUT_DIR, f"lowres_output_{current_seed}.png"))
+        # except Exception as save_e: print(f"Warning: Could not save low-res image: {save_e}")
+        progress(0.45, desc="Low-Res Generation Complete.")
+
+        progress(0.50, desc="Starting Hi-Res Tiling...")
+        final_image = run_hires_tiling(
+            low_res_image=low_res_image,
+            seed=int(current_seed),
+            steps=int(DEFAULT_STEPS_HIRES),
+            guidance_scale=float(DEFAULT_GUIDANCE_HIRES),
+            strength=float(DEFAULT_STRENGTH_HIRES),
+            controlnet_scale=float(DEFAULT_CN_SCALE_HIRES),
+            upscale_factor=2,
+            tile_size=1024,
+            tile_stride=1024,
+            progress=progress
+        )
+        print("Hi-res tiling stage completed successfully.")
+        # try: final_image.save(os.path.join(OUTPUT_DIR, f"hires_output_{current_seed}.png"))
+        # except Exception as save_e: print(f"Warning: Could not save final image: {save_e}")
+
+        progress(1.0, desc="Complete!")
+
+    except gr.Error as e:
+        print(f"Gradio Error occurred: {e}")
+        if final_image is None and low_res_image is not None and ("tiling" in str(e).lower() or "hi-res" in str(e).lower()):
+            gr.Warning(f"High-resolution upscaling failed ({e}). Returning low-resolution image.")
+            final_image = low_res_image
+        else:
+            raise e
+    except Exception as e:
+        print(f"An unexpected error occurred in generate_full_pipeline: {e}")
+        import traceback
+        traceback.print_exc()
+        raise gr.Error(f"An unexpected error occurred: {e}")
+    finally:
+        print("Running final cleanup check...")
+        cleanup_memory()
+        run_end_time = time.time()
+        print(f"--- Full Pipeline Run Finished in {run_end_time - run_start_time:.2f} seconds ---")
+
+    return final_image
+
+
+# --- Gradio Interface Definition ---
+
+theme = gr_themes.Soft(primary_hue=gr_themes.colors.blue, secondary_hue=gr_themes.colors.sky)
+
+# HTML description shown at the top of the page
+DESCRIPTION = f"""
+<div style="text-align: center;">
+    <h1 style="font-family: Impact, Charcoal, sans-serif; font-size: 280%; font-weight: 900; margin-bottom: 16px;">
+        Pose-Preserving Comicfier
+    </h1>
+    <p style="margin-bottom: 12px; font-size: 94%">
+        Transform your photos into the gritty style of a 1940s Western comic! This app uses Stable Diffusion + ControlNet
+        to apply the artistic look while keeping the original pose intact. Just upload your image and click Generate!
+    </p>
+    <p style="font-size: 85%;"><em>(Generation can take several minutes on shared hardware. Prompts & parameters are fixed.)</em></p>
+    <p style="font-size: 80%; color: grey;">
+        <a href="https://github.com/mehran-khani" target="_blank">[View Project on GitHub]</a> |
+        <a href="https://huggingface.co/spaces/.../discussions" target="_blank">[Report an Issue]</a>
+    </p>
+    <!-- Remember to replace placeholders above with your actual links -->
+</div>
+"""
+
+EXAMPLE_IMAGES_DIR = "examples"
+EXAMPLE_IMAGES = [
+    os.path.join(EXAMPLE_IMAGES_DIR, "example1.jpg"),
+    os.path.join(EXAMPLE_IMAGES_DIR, "example2.jpg"),
+    os.path.join(EXAMPLE_IMAGES_DIR, "example3.jpg"),
+    os.path.join(EXAMPLE_IMAGES_DIR, "example4.jpg"),
+    os.path.join(EXAMPLE_IMAGES_DIR, "example5.jpg"),
+    os.path.join(EXAMPLE_IMAGES_DIR, "example6.jpg"),
+]
+EXAMPLE_IMAGES = [img for img in EXAMPLE_IMAGES if os.path.exists(img)]
+
+CUSTOM_CSS = """
+/* Target the container div Gradio uses for the Image component */
+.gradio-image {
+    width: 100%;       /* Ensure the container fills the column width */
+    height: 100%;      /* Ensure the container fills the height set by the component */
+    overflow: hidden;  /* Hide any potential overflow before object-fit applies */
+}
+
+/* Target the actual <img> tag inside the container */
+.gradio-image img {
+    display: block;     /* Remove potential bottom spacing */
+    width: 100%;        /* Force image width to match container */
+    height: 100%;       /* Force image height to match container */
+    object-fit: cover;  /* Scale/crop image to cover this forced W/H */
+}
+
+footer { visibility: hidden }
+"""
+
+with gr.Blocks(theme=theme, css=CUSTOM_CSS, title="Pose-Preserving Comicfier") as demo:
+    gr.HTML(DESCRIPTION)
+
+    with gr.Row():
+        # Input Column
+        with gr.Column(scale=1, min_width=350):
+            input_image = gr.Image(
+                type="filepath",
+                label="Upload Your Image Here"
+            )
+            generate_button = gr.Button("Generate Comic Image", variant="primary")
+
+        # Output Column
+        with gr.Column(scale=1, min_width=350):
+            output_image = gr.Image(
+                type="pil",
+                label="Generated Comic Image",
+                interactive=False
+            )
+
+    # Examples Section
+    if EXAMPLE_IMAGES:
+        gr.Examples(
+            examples=EXAMPLE_IMAGES,
+            inputs=[input_image],
+            outputs=[output_image],
+            fn=generate_full_pipeline,
+            cache_examples=False
+        )
+
+    generate_button.click(
+        fn=generate_full_pipeline,
+        inputs=[input_image],
+        outputs=[output_image],
+        api_name="generate"
+    )
+
+
+# --- Launch the Gradio App ---
+if __name__ == "__main__":
+    if not are_models_loaded():
+        print("Attempting to load models before launch...")
+        if not load_models():
+            print("FATAL: Model loading failed on launch. App may not function.")
+
+    print("Attempting to launch Gradio demo...")
+    demo.queue().launch(debug=False, share=False)
+    print("Gradio app launched. Access it at the URL provided above.")
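Because the click handler registers `api_name="generate"`, a deployed Space also exposes a programmatic endpoint. A minimal client-side sketch, assuming the `gradio_client` package and a running Space (the Space ID and input file name below are illustrative):

```python
# Hedged sketch: calling the Space's /generate endpoint remotely
from gradio_client import Client, handle_file

client = Client("Mer-o/Pose-Preserving-Comicfier")  # Space ID assumed from the README links
result = client.predict(
    handle_file("my_photo.jpg"),  # illustrative local input image
    api_name="/generate",
)
print(f"Generated image written to: {result}")
```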
examples/example1.jpg
ADDED
(binary image file, tracked with Git LFS)
examples/example2.jpg
ADDED
(binary image file, tracked with Git LFS)
examples/example3.jpg
ADDED
(binary image file, tracked with Git LFS)
examples/example4.jpg
ADDED
(binary image file, tracked with Git LFS)
examples/example5.jpg
ADDED
(binary image file, tracked with Git LFS)
examples/example6.jpg
ADDED
(binary image file, tracked with Git LFS)
image_utils.py
ADDED
@@ -0,0 +1,134 @@
+"""
+Contains utility functions for image loading, preparation, and manipulation.
+Includes HEIC image format support via the optional 'pillow-heif' library.
+"""
+
+from PIL import Image, ImageOps, ImageDraw
+import os
+
+try:
+    from pillow_heif import register_heif_opener
+    register_heif_opener()
+    print("HEIC opener registered successfully using pillow-heif.")
+    _heic_support = True
+except ImportError:
+    print("Warning: pillow-heif not installed. HEIC/HEIF support will be disabled.")
+    _heic_support = False
+
+
+print("Loading Image Utils...")
+
+def prepare_image(image_filepath, target_size=512):
+    """
+    Prepares an input image file for the diffusion pipeline.
+
+    Loads an image from the given filepath (supports standard formats like
+    JPG, PNG, WEBP, and HEIC/HEIF), ensures it is in RGB format, handles
+    EXIF orientation, and performs a forced resize to a square target_size,
+    ignoring the original aspect ratio.
+
+    Args:
+        image_filepath (str): The path to the image file.
+        target_size (int): The target dimension for both width and height.
+
+    Returns:
+        PIL.Image.Image | None: The prepared image as a PIL Image object in RGB format,
+                                or None if loading or processing fails.
+    """
+    if image_filepath is None:
+        print("Warning: prepare_image received None filepath.")
+        return None
+
+    if not isinstance(image_filepath, str) or not os.path.exists(image_filepath):
+        print(f"Error: Invalid filepath provided to prepare_image: {image_filepath}")
+        if isinstance(image_filepath, Image.Image):
+            print("Warning: Received PIL Image instead of filepath, proceeding...")
+            image = image_filepath
+        else:
+            return None
+    else:
+        # --- Load Image from Filepath ---
+        print(f"Loading image from path: {image_filepath}")
+        try:
+            image = Image.open(image_filepath)
+        except ImportError as e:
+            print(f"ImportError during Image.open: {e}. Is pillow-heif installed?")
+            print("Cannot process image format.")
+            return None
+        except Exception as e:
+            print(f"Error opening image file {image_filepath} with PIL: {e}")
+            return None
+
+    # --- Process PIL Image ---
+    try:
+        image = ImageOps.exif_transpose(image)
+        image = image.convert("RGB")
+
+        original_width, original_height = image.size
+
+        final_width = target_size
+        final_height = target_size
+
+        resized_image = image.resize((final_width, final_height), Image.LANCZOS)
+
+        print(f"Original size: ({original_width}, {original_height}), FORCED Resized to: ({final_width}, {final_height})")
+        return resized_image
+    except Exception as e:
+        print(f"Error during PIL image processing steps: {e}")
+        return None
+
+def create_blend_mask(tile_size=1024, overlap=256):
+    """
+    Creates a feathered blending mask (alpha mask) for smooth tile stitching.
+
+    Generates a square mask where the edges have a linear gradient ramp within
+    the specified overlap zone, and the central area is fully opaque.
+    Assumes overlap occurs equally on all four sides.
+
+    Args:
+        tile_size (int): The dimension (width and height) of the tiles being processed.
+        overlap (int): The number of pixels that overlap between adjacent tiles.
+
+    Returns:
+        PIL.Image.Image: The blending mask as a PIL Image object in 'L' (grayscale) mode.
+                         White (255) areas are fully opaque, black (0) are transparent,
+                         gray values provide blending.
+    """
+    if overlap >= tile_size // 2:
+        print("Warning: Overlap is large relative to tile size, mask generation might be suboptimal.")
+        overlap = tile_size // 2 - 1
+
+    mask = Image.new("L", (tile_size, tile_size), 0)
+    draw = ImageDraw.Draw(mask)
+
+    if overlap > 0:
+        # Each pass draws full-length lines, so the horizontal (top/bottom) ramps,
+        # drawn last, overwrite the vertical ramps where they cross in the corners.
+        for i in range(overlap):
+            alpha = int(255 * (i / float(overlap)))
+
+            # Left edge ramp
+            draw.line([(i, 0), (i, tile_size)], fill=alpha)
+            # Right edge ramp
+            draw.line([(tile_size - 1 - i, 0), (tile_size - 1 - i, tile_size)], fill=alpha)
+            # Top edge ramp
+            draw.line([(0, i), (tile_size, i)], fill=alpha)
+            # Bottom edge ramp
+            draw.line([(0, tile_size - 1 - i), (tile_size, tile_size - 1 - i)], fill=alpha)
+
+    center_start = overlap
+    center_end_x = tile_size - overlap
+    center_end_y = tile_size - overlap
+
+    if center_end_x > center_start and center_end_y > center_start:
+        draw.rectangle((center_start, center_start, center_end_x - 1, center_end_y - 1), fill=255)
+    else:
+        center_x, center_y = tile_size // 2, tile_size // 2
+        draw.point((center_x, center_y), fill=255)
+        if tile_size % 2 == 0:
+            draw.point((center_x - 1, center_y), fill=255)
+            draw.point((center_x, center_y - 1), fill=255)
+            draw.point((center_x - 1, center_y - 1), fill=255)
+
+    print(f"Blend mask created (Size: {tile_size}x{tile_size}, Overlap: {overlap})")
+    return mask
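For intuition, here is a tiny, self-contained sketch of how `create_blend_mask` is meant to be consumed. It mirrors the `Image.composite` call in `pipelines.py`, with toy sizes and solid colors standing in for real tiles:

```python
# Toy demonstration: blend a new tile over existing canvas content
from PIL import Image
from image_utils import create_blend_mask

tile_size, overlap = 256, 64
mask = create_blend_mask(tile_size, overlap)  # grayscale: 0 at edges -> 255 in center

canvas_region = Image.new("RGB", (tile_size, tile_size), "blue")  # already-stitched content
new_tile = Image.new("RGB", (tile_size, tile_size), "red")        # freshly generated tile

# Where the mask is white the new tile wins; where black, the canvas shows through
blended = Image.composite(new_tile, canvas_region, mask)
blended.save("blended_tile_demo.png")
```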
loras/add_detail.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:47aaaf0d2945ca937151d61304946dd229b3f072140b85484bc93e38f2a6e2f7
+size 37861176
loras/night_comic_V06.safetensors
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:caf9280080bf4183a9064547a9374bc22575248867a9bcc54a883305b57a8ebb
+size 14153788
model_loader.py
ADDED
@@ -0,0 +1,133 @@
+"""
+Handles the loading and management of necessary AI models from Hugging Face Hub.
+
+Provides functions to load models once at startup and access them throughout
+the application, managing device placement (CPU/GPU) and data types.
+Optimized for typical Hugging Face Space GPU environments.
+"""
+
+import torch
+from diffusers import ControlNetModel
+from controlnet_aux import OpenposeDetector
+import gc
+
+# --- Configuration ---
+# Automatically detect CUDA availability and set appropriate device/dtype
+if torch.cuda.is_available():
+    DEVICE = "cuda"
+    DTYPE = torch.float16
+    print(f"CUDA available. Using Device: {DEVICE}, Dtype: {DTYPE}")
+    try:
+        print(f"GPU Name: {torch.cuda.get_device_name(0)}")
+    except Exception as e:
+        print(f"Couldn't get GPU name: {e}")
+else:
+    DEVICE = "cpu"
+    DTYPE = torch.float32
+    print(f"CUDA not available. Using Device: {DEVICE}, Dtype: {DTYPE}")
+
+
+# Model IDs from Hugging Face Hub
+# BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"  # Base SD model ID needed by pipelines
+OPENPOSE_DETECTOR_ID = 'lllyasviel/ControlNet'  # Preprocessor model repo
+CONTROLNET_POSE_MODEL_ID = "lllyasviel/sd-controlnet-openpose"  # OpenPose ControlNet weights
+CONTROLNET_TILE_MODEL_ID = "lllyasviel/control_v11f1e_sd15_tile"  # Tile ControlNet weights
+
+_openpose_detector = None
+_controlnet_pose = None
+_controlnet_tile = None
+_models_loaded = False
+
+# --- Loading Function ---
+
+def load_models(force_reload=False):
+    """
+    Loads the OpenPose detector (to CPU) and ControlNet models (to configured DEVICE).
+
+    This function should typically be called once when the application starts.
+    It checks if models are already loaded to prevent redundant loading unless
+    `force_reload` is True.
+
+    Args:
+        force_reload (bool): If True, forces reloading even if models are already loaded.
+
+    Returns:
+        bool: True if all models were loaded successfully (or already were), False otherwise.
+    """
+    global _openpose_detector, _controlnet_pose, _controlnet_tile, _models_loaded
+
+    if _models_loaded and not force_reload:
+        print("Models already loaded.")
+        return True
+
+    print("--- Loading Models ---")
+    if DEVICE == "cuda":
+        print("Performing initial CUDA cache clear...")
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    # 1. OpenPose Detector
+    try:
+        print(f"Loading OpenPose Detector from {OPENPOSE_DETECTOR_ID} to CPU...")
+        _openpose_detector = OpenposeDetector.from_pretrained(OPENPOSE_DETECTOR_ID)
+        print("OpenPose detector loaded successfully (on CPU).")
+    except Exception as e:
+        print(f"ERROR: Failed to load OpenPose Detector: {e}")
+        _models_loaded = False
+        return False
+
+    # 2. ControlNet Models
+    try:
+        print(f"Loading ControlNet Pose Model from {CONTROLNET_POSE_MODEL_ID} to {DEVICE} ({DTYPE})...")
+        _controlnet_pose = ControlNetModel.from_pretrained(
+            CONTROLNET_POSE_MODEL_ID, torch_dtype=DTYPE
+        )
+        _controlnet_pose.to(DEVICE)
+        print("ControlNet Pose model loaded successfully.")
+    except Exception as e:
+        print(f"ERROR: Failed to load ControlNet Pose Model: {e}")
+        _models_loaded = False
+        return False
+
+    try:
+        print(f"Loading ControlNet Tile Model from {CONTROLNET_TILE_MODEL_ID} to {DEVICE} ({DTYPE})...")
+        _controlnet_tile = ControlNetModel.from_pretrained(
+            CONTROLNET_TILE_MODEL_ID, torch_dtype=DTYPE
+        )
+        _controlnet_tile.to(DEVICE)
+        print("ControlNet Tile model loaded successfully.")
+    except Exception as e:
+        print(f"ERROR: Failed to load ControlNet Tile Model: {e}")
+        _models_loaded = False
+        return False
+
+    _models_loaded = True
+    print("--- All prerequisite models loaded successfully. ---")
+    if DEVICE == "cuda":
+        print("Performing post-load CUDA cache clear...")
+        gc.collect()
+        torch.cuda.empty_cache()
+    return True
+
+# --- Getter Functions ---
+
+def get_openpose_detector():
+    if not _models_loaded: load_models()
+    return _openpose_detector
+
+def get_controlnet_pose():
+    if not _models_loaded: load_models()
+    return _controlnet_pose
+
+def get_controlnet_tile():
+    if not _models_loaded: load_models()
+    return _controlnet_tile
+
+def get_device():
+    return DEVICE
+
+def get_dtype():
+    return DTYPE
+
+def are_models_loaded():
+    return _models_loaded
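`model_loader` behaves as a simple module-level singleton: `load_models()` is idempotent, and the getters hand out shared instances. A short sketch of the intended call pattern:

```python
# Sketch: typical startup usage of model_loader's singleton accessors
import model_loader

if model_loader.load_models():                       # idempotent; safe to call again
    pose_cn = model_loader.get_controlnet_pose()     # ControlNetModel on DEVICE
    tile_cn = model_loader.get_controlnet_tile()
    detector = model_loader.get_openpose_detector()  # OpenposeDetector kept on CPU
    print(model_loader.get_device(), model_loader.get_dtype())
else:
    raise RuntimeError("Prerequisite models failed to load")
```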
pipelines.py
ADDED
@@ -0,0 +1,433 @@
+"""
+Contains functions to execute the main image generation stages:
+1. OpenPose Detection: Extracts pose information.
+2. Low-Resolution Generation: Creates the initial image using Pose ControlNet.
+3. High-Resolution Tiling: Upscales the low-res image using Tile ControlNet.
+
+Manages dynamic loading/unloading of diffusion pipelines to conserve VRAM.
+"""
+
+import torch
+import gc
+import time
+import os
+from PIL import Image
+from tqdm.auto import tqdm
+import gradio as gr
+from diffusers import (
+    StableDiffusionControlNetImg2ImgPipeline,
+    UniPCMultistepScheduler,
+)
+from model_loader import (
+    get_openpose_detector,
+    get_controlnet_pose,
+    get_controlnet_tile,
+    get_device,
+    get_dtype,
+    are_models_loaded,
+)
+from image_utils import create_blend_mask
+from prompts import get_prompts_for_run
+
+
+# --- Configuration ---
+BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
+LORA_DIR = "loras"
+LORA_FILES = {
+    "style": os.path.join(LORA_DIR, "night_comic_V06.safetensors"),
+    "detail": os.path.join(LORA_DIR, "add_detail.safetensors"),
+}
+LORA_WEIGHTS_LOWRES = [1, 1]
+LORA_WEIGHTS_HIRES = [1, 2]
+ACTIVE_ADAPTERS = ["style", "detail"]
+
+def cleanup_memory():
+    """Forces garbage collection and clears the CUDA cache."""
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+# --- Stage 1: OpenPose Detection ---
+def run_pose_detection(resized_input_image):
+    """
+    Detects human pose (body, hands, face) from the input image using OpenPose.
+
+    Temporarily moves the OpenPose detector model to the active GPU (if available)
+    for processing and then moves it back to the CPU to conserve VRAM.
+
+    Args:
+        resized_input_image (PIL.Image.Image): The input image, already resized
+                                               and in RGB format.
+
+    Returns:
+        PIL.Image.Image | None: A PIL Image representing the detected pose map,
+                                or None if detection fails or models aren't loaded.
+    """
+    if not are_models_loaded():
+        print("Error: Cannot run pose detection, models not loaded.")
+        return None
+
+    detector = get_openpose_detector()
+    device = get_device()
+    control_image_openpose = None
+
+    if detector is None:
+        print("Error: OpenPose detector is None.")
+        return None
+
+    try:
+        detector.to(device)
+        cleanup_memory()
+
+        control_image_openpose = detector(
+            resized_input_image, include_face=True, include_hand=True
+        )
+    except Exception as e:
+        print(f"ERROR during OpenPose detection: {e}")
+        control_image_openpose = None
+    finally:
+        detector.to("cpu")
+        cleanup_memory()
+
+    return control_image_openpose
+
+# --- Stage 2: Low-Resolution Generation ---
+def run_low_res_generation(
+    resized_input_image,
+    pose_map,
+    seed,
+    steps,
+    guidance_scale,
+    strength,
+    controlnet_scale=0.8,
+    progress=gr.Progress(track_tqdm=True)
+):
+    """
+    Generates the initial low-resolution image using Img2Img with Pose ControlNet.
+
+    Dynamically loads the StableDiffusionControlNetImg2ImgPipeline, applies LoRAs,
+    runs inference, and then unloads the pipeline to free VRAM before returning.
+
+    Args:
+        resized_input_image (PIL.Image.Image): The resized input image.
+        pose_map (PIL.Image.Image): The pose map generated by run_pose_detection.
+        seed (int): The random seed for generation.
+        steps (int): Number of diffusion inference steps.
+        guidance_scale (float): Classifier-free guidance scale.
+        strength (float): Img2Img strength (0.0 to 1.0); how much noise to add.
+        controlnet_scale (float): Conditioning scale for the Pose ControlNet.
+        progress (gr.Progress): Gradio progress object for UI updates.
+
+    Returns:
+        PIL.Image.Image | None: The generated low-resolution PIL Image, or None if an error occurs.
+
+    Raises:
+        gr.Error: Raises a Gradio error if generation fails catastrophically.
+    """
+    if not are_models_loaded() or pose_map is None:
+        error_msg = "Cannot run low-res generation: "
+        if not are_models_loaded(): error_msg += "Models not loaded. "
+        if pose_map is None: error_msg += "Pose map is missing."
+        print(f"Error: {error_msg}")
+        return None
+
+    device = get_device()
+    dtype = get_dtype()
+    controlnet_pose = get_controlnet_pose()
+    output_image_low_res = None
+    pipe_lowres = None
+
+    positive_prompt, negative_prompt, _, _ = get_prompts_for_run()
+    generator = torch.Generator(device=device).manual_seed(int(seed))
+
+    progress(0, desc="Loading Low-Res Pipeline...")
+    try:
+        # 1. Load Pipeline
+        pipe_lowres = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+            BASE_MODEL_ID,
+            controlnet=controlnet_pose,
+            torch_dtype=dtype,
+            safety_checker=None
+        )
+        pipe_lowres.scheduler = UniPCMultistepScheduler.from_config(pipe_lowres.scheduler.config)
+        pipe_lowres.to(device)
+
+        cleanup_memory()
+
+        # 2. Load LoRAs
+        if os.path.exists(LORA_FILES["style"]) and os.path.exists(LORA_FILES["detail"]):
+            pipe_lowres.load_lora_weights(LORA_FILES["style"], adapter_name="style")
+            pipe_lowres.load_lora_weights(LORA_FILES["detail"], adapter_name="detail")
+            pipe_lowres.set_adapters(ACTIVE_ADAPTERS, adapter_weights=LORA_WEIGHTS_LOWRES)
+            print(f"Activated LoRAs: {ACTIVE_ADAPTERS} with weights {LORA_WEIGHTS_LOWRES}")
+        else:
+            print("Warning: One or both LoRA files not found. Skipping LoRA loading.")
+            raise gr.Error("Required LoRA files not found in loras/ directory.")
+
+        # 3. Run Inference
+        progress(0.3, desc="Generating Low-Res Image...")
+        output_image_low_res = pipe_lowres(
+            prompt=positive_prompt,
+            negative_prompt=negative_prompt,
+            image=resized_input_image,
+            control_image=pose_map,
+            num_inference_steps=int(steps),
+            strength=strength,
+            guidance_scale=guidance_scale,
+            controlnet_conditioning_scale=float(controlnet_scale),
+            generator=generator,
+        ).images[0]
+        progress(0.9, desc="Low-Res Complete")
+
+    except Exception as e:
+        print(f"ERROR during Low-Res Generation Pipeline: {e}")
+        import traceback
+        traceback.print_exc()
+        output_image_low_res = None
+        raise gr.Error(f"Failed during low-res generation: {e}")
+    finally:
+        # 4. Cleanup Pipeline
+        print("Cleaning up Low-Res pipeline...")
+        if pipe_lowres is not None:
+            try:
+                if hasattr(pipe_lowres, 'get_active_adapters') and pipe_lowres.get_active_adapters():
+                    print("Unloading LoRAs...")
+                    pipe_lowres.unload_lora_weights()
+            except Exception as unload_e:
+                print(f"Note: Error unloading LoRAs: {unload_e}")
+
+            print("Moving Low-Res pipe components to CPU before deleting...")
+            try: pipe_lowres.to('cpu')
+            except Exception as cpu_e: print(f"Note: Error moving pipe to CPU: {cpu_e}")
+
+            print("Deleting Low-Res pipeline object...")
+            del pipe_lowres
+            pipe_lowres = None
+
+        print("Running garbage collection and emptying CUDA cache after Low-Res...")
+        cleanup_memory()
+        # time.sleep(1)
+
+    print("--- Low-Res Generation Stage Finished ---")
+    return output_image_low_res
+
+# --- Stage 3: High-Resolution Tiling Upscaling ---
+def run_hires_tiling(
+    low_res_image,
+    seed,
+    steps,
+    guidance_scale,
+    strength,
+    controlnet_scale=1.0,
+    upscale_factor=2,
+    tile_size=1024,
+    tile_stride=1024,
+    progress=gr.Progress(track_tqdm=True)
+):
+    """
+    Upscales the low-resolution image using tiling with the Tile ControlNet.
+
+    Dynamically loads the StableDiffusionControlNetImg2ImgPipeline for tiling,
+    applies LoRAs, processes the image in (optionally overlapping) tiles, blends
+    the results, and unloads the pipeline to free VRAM.
+
+    Args:
+        low_res_image (PIL.Image.Image): The low-resolution image from the previous stage.
+        seed (int): The random seed (should ideally match the low-res stage seed).
+        steps (int): Number of diffusion inference steps per tile.
+        guidance_scale (float): Classifier-free guidance scale for tiles.
+        strength (float): Img2Img strength for tiling (controls detail vs. original).
+        controlnet_scale (float): Conditioning scale for the Tile ControlNet.
+        upscale_factor (int): Factor by which to increase the image resolution.
+        tile_size (int): Size of the square tiles to process.
+        tile_stride (int): Step size between tiles. Overlap = tile_size - tile_stride.
+        progress (gr.Progress): Gradio progress object for UI updates.
+
+    Returns:
+        PIL.Image.Image | None: The generated high-resolution PIL Image, or None if an error occurs.
+
+    Raises:
+        gr.Error: Raises a Gradio error if tiling fails catastrophically.
+    """
+    if not are_models_loaded() or low_res_image is None:
+        error_msg = "Cannot run hi-res tiling: "
+        if not are_models_loaded(): error_msg += "Models not loaded. "
+        if low_res_image is None: error_msg += "Low-res image is missing."
+        print(f"Error: {error_msg}")
+        return None
+
+    device = get_device()
+    dtype = get_dtype()
+    controlnet_tile = get_controlnet_tile()
+    high_res_output_image = None
+    pipe_hires = None
+
+    _, _, positive_prompt_tile, negative_prompt_tile = get_prompts_for_run()
+
+    generator_tile = torch.Generator(device=device).manual_seed(int(seed))
+
+    print("\n--- Starting Hi-Res Tiling Stage ---")
+    progress(0, desc="Preparing for Tiling...")
+
+    try:
+        # --- Setup Tiling Parameters ---
+        target_width = low_res_image.width * upscale_factor
+        target_height = low_res_image.height * upscale_factor
+        if tile_size > min(target_width, target_height):
+            print(f"Warning: Tile size ({tile_size}) > target dimension ({target_width}x{target_height}). Clamping tile size.")
+            tile_size = min(target_width, target_height)
+            tile_stride = tile_size
+
+        overlap = tile_size - tile_stride
+        if overlap < 0:
+            print("Warning: Tile stride is larger than tile size. Setting stride = tile size.")
+            tile_stride = tile_size
+            overlap = 0
+
+        print(f"Target Res: {target_width}x{target_height}, Tile Size: {tile_size}, Stride: {tile_stride}, Overlap: {overlap}")
+
+        # 1. Load Pipeline
+        print(f"Loading Hi-Res Pipeline ({BASE_MODEL_ID} + Tile ControlNet)...")
+        progress(0.05, desc="Loading Hi-Res Pipeline...")
+        pipe_hires = StableDiffusionControlNetImg2ImgPipeline.from_pretrained(
+            BASE_MODEL_ID,
+            controlnet=controlnet_tile,
+            torch_dtype=dtype,
+            safety_checker=None,
+        )
+        pipe_hires.scheduler = UniPCMultistepScheduler.from_config(pipe_hires.scheduler.config)
+        pipe_hires.to(device)
+        # pipe_hires.enable_model_cpu_offload()
+        # pipe_hires.enable_xformers_memory_efficient_attention()
+        print("Hi-Res Pipeline loaded to GPU.")
+        cleanup_memory()
+
+        # 2. Load LoRAs
+        print("Loading LoRAs for Hi-Res pipe...")
+        if os.path.exists(LORA_FILES["style"]) and os.path.exists(LORA_FILES["detail"]):
+            pipe_hires.load_lora_weights(LORA_FILES["style"], adapter_name="style")
+            pipe_hires.load_lora_weights(LORA_FILES["detail"], adapter_name="detail")
+            pipe_hires.set_adapters(ACTIVE_ADAPTERS, adapter_weights=LORA_WEIGHTS_HIRES)
+            print(f"Activated LoRAs: {ACTIVE_ADAPTERS} with weights {LORA_WEIGHTS_HIRES}")
+        else:
+            print("Warning: One or both LoRA files not found. Skipping LoRA loading.")
+            raise gr.Error("Required LoRA files not found in loras/ directory.")
+
+        # --- Prepare for Tiling Loop ---
+        print(f"Creating blurry base image ({target_width}x{target_height})...")
+        progress(0.15, desc="Preparing Base Image...")
+        blurry_high_res = low_res_image.resize((target_width, target_height), Image.LANCZOS)
+
+        final_image = Image.new("RGB", (target_width, target_height))
+        blend_mask = create_blend_mask(tile_size, overlap)
+
+        num_tiles_x = (target_width + tile_stride - 1) // tile_stride
+        num_tiles_y = (target_height + tile_stride - 1) // tile_stride
+        total_tiles = num_tiles_x * num_tiles_y
+        print(f"Processing {num_tiles_x}x{num_tiles_y} = {total_tiles} tiles...")
+
+        # --- Tiling Loop ---
+        progress(0.2, desc=f"Processing Tiles (0/{total_tiles})")
+        processed_tile_count = 0
+        with tqdm(total=total_tiles, desc="Tiling Upscale") as pbar:
+            for y in range(num_tiles_y):
+                for x in range(num_tiles_x):
+                    tile_start_time = time.time()
+                    pbar.set_description(f"Tiling Upscale (Tile {processed_tile_count+1}/{total_tiles})")
+
+                    x_start = x * tile_stride
+                    y_start = y * tile_stride
+                    x_end = min(x_start + tile_size, target_width)
+                    y_end = min(y_start + tile_size, target_height)
+                    crop_box = (x_start, y_start, x_end, y_end)
+
+                    tile_image_blurry = blurry_high_res.crop(crop_box)
+                    current_tile_width, current_tile_height = tile_image_blurry.size
+
+                    if current_tile_width < tile_size or current_tile_height < tile_size:
+                        try: edge_color = tile_image_blurry.getpixel((0, 0))
+                        except IndexError: edge_color = (127, 127, 127)
+                        padded_tile = Image.new("RGB", (tile_size, tile_size), edge_color)
+                        padded_tile.paste(tile_image_blurry, (0, 0))
+                        tile_image_blurry = padded_tile
+                        print(f"Padded edge tile at ({x},{y})")
+
+                    # 3. Run Inference on the Tile
+                    with torch.inference_mode():
+                        output_tile = pipe_hires(
+                            prompt=positive_prompt_tile,
+                            negative_prompt=negative_prompt_tile,
+                            image=tile_image_blurry,
+                            control_image=tile_image_blurry,
+                            num_inference_steps=int(steps),
+                            strength=strength,
+                            guidance_scale=guidance_scale,
+                            controlnet_conditioning_scale=float(controlnet_scale),
+                            generator=generator_tile,
+                            output_type="pil"
+                        ).images[0]
+
+                    # --- Stitch Tile Back ---
+                    paste_x = x_start
+                    paste_y = y_start
+                    crop_w = x_end - x_start
+                    crop_h = y_end - y_start
+
+                    output_tile_region = output_tile.crop((0, 0, crop_w, crop_h))
+
+                    if overlap > 0:
+                        blend_mask_region = blend_mask.crop((0, 0, crop_w, crop_h))
+                        current_content_region = final_image.crop((paste_x, paste_y, paste_x + crop_w, paste_y + crop_h))
+                        blended_tile_region = Image.composite(output_tile_region, current_content_region, blend_mask_region)
+                        final_image.paste(blended_tile_region, (paste_x, paste_y))
+                    else:
+                        final_image.paste(output_tile_region, (paste_x, paste_y))
+
+                    processed_tile_count += 1
+                    pbar.update(1)
+
+                    # Update Gradio progress
+                    gradio_progress = 0.2 + 0.75 * (processed_tile_count / total_tiles)
+                    progress(gradio_progress, desc=f"Processing Tile {processed_tile_count}/{total_tiles}")
+
+                    tile_end_time = time.time()
+                    print(f"Tile ({x},{y}) processed in {tile_end_time - tile_start_time:.2f}s")
+                    # cleanup_memory()
+
+        print("Tile processing complete.")
+        high_res_output_image = final_image
+        progress(0.95, desc="Tiling Complete")
+
+    except Exception as e:
+        print(f"ERROR during Hi-Res Tiling Pipeline: {e}")
+        import traceback
+        traceback.print_exc()
+        high_res_output_image = None
+        raise gr.Error(f"Failed during hi-res tiling: {e}")
+    finally:
+        # 4. Cleanup Pipeline
+        print("Cleaning up Hi-Res pipeline...")
+        if pipe_hires is not None:
+            try:
+                if hasattr(pipe_hires, 'get_active_adapters') and pipe_hires.get_active_adapters():
+                    print("Unloading LoRAs...")
+                    pipe_hires.unload_lora_weights()
+            except Exception as unload_e:
+                print(f"Note: Error unloading LoRAs: {unload_e}")
+
+            print("Moving Hi-Res pipe components to CPU before deleting...")
+            try: pipe_hires.to('cpu')
+            except Exception as cpu_e: print(f"Note: Error moving pipe to CPU: {cpu_e}")
+
+            print("Deleting Hi-Res pipeline object...")
+            del pipe_hires
+            pipe_hires = None
+
+        print("Running garbage collection and emptying CUDA cache after Hi-Res...")
+        cleanup_memory()
+
+    print("--- Hi-Res Tiling Stage Finished ---")
+    return high_res_output_image
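The tile bookkeeping above is plain integer arithmetic, so it can be sanity-checked without loading any models. A standalone sketch of the grid computation using the app's default values:

```python
# Standalone check of the tile-grid arithmetic used in run_hires_tiling
target_w, target_h = 1024, 1024
tile_size, tile_stride = 1024, 1024   # stride == size -> overlap of 0

num_x = (target_w + tile_stride - 1) // tile_stride   # ceiling division
num_y = (target_h + tile_stride - 1) // tile_stride

for y in range(num_y):
    for x in range(num_x):
        x0, y0 = x * tile_stride, y * tile_stride
        x1, y1 = min(x0 + tile_size, target_w), min(y0 + tile_size, target_h)
        print(f"tile ({x},{y}) -> crop box {(x0, y0, x1, y1)}")
# With these defaults the 2x upscale is covered by exactly one 1024x1024 tile.
```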
prompts.py
ADDED
@@ -0,0 +1,53 @@
+"""
+Defines fixed prompts and provides a function to generate
+randomized prompts for each run, mirroring the original notebook behavior.
+Used by the main pipeline functions.
+"""
+
+import random
+
+BASE_PROMPT = "detailed face portrait, accurate facial features, natural features, clear eyes, keep the gender same as the input image"
+
+STYLE_PROMPT = r"((LIMITED PALETTE)), ((RETRO COMIC)), ((1940S \(STYLE\))), ((WESTERN COMICS \(STYLE\))), ((NIGHT COMIC)), detailed illustration, sharp lines, sfw"
+
+BASE_NEGATIVE_PROMPT = (
+    "generic face, distorted features, unrealistic face, bad anatomy, extra limbs, fused fingers, poorly drawn hands, poorly drawn face, "
+    "text, signature, watermark, letters, words, username, artist name, speech bubble, multiple panels, "
+    "ugly, disfigured, deformed, low quality, worst quality, blurry, jpeg artifacts, noisy, "
+    "weapon, gun, knife, violence, gore, blood, injury, mutilated, horrific, nsfw, nude, naked, explicit, sexual, lingerie, bikini, suggestive, provocative, disturbing, scary, offensive, illegal, unlawful"
+)
+
+# --- Background Generation Elements ---
+BG_SETTINGS = [
+    "on a futuristic city street at night", "in a retro sci-fi control room", "in a dusty western saloon",
+    "in front of an abstract energy field", "in a neon-lit alleyway", "in a stark cyberpunk cityscape",
+    "with speed lines background", "in a manga panel frame", "in a dimly lit laboratory",
+    "against a dramatic explosive background", "in a cluttered artist studio", "in a dynamic action scene"
+]
+
+BG_DETAILS = [
+    "detailed background", "cinematic lighting", "dramatic shadows",
+    "high contrast", "low angle shot", "dynamic composition", "atmospheric perspective", "intricate details"
+]
+
+def get_prompts_for_run():
+    """
+    Generates the prompts needed for one generation cycle,
+    including a newly randomized background for the low-res stage.
+    Returns prompts suitable for the low-res and hi-res stages.
+    """
+    # --- Low-Res Prompt Generation ---
+    chosen_bg_setting = random.choice(BG_SETTINGS)
+    chosen_bg_detail = random.choice(BG_DETAILS)
+    background_prompt = f"{chosen_bg_setting}, {chosen_bg_detail}"
+    positive_prompt_lowres = f"{BASE_PROMPT}, {STYLE_PROMPT}, {background_prompt}"
+
+    # --- Tile Prompt Generation ---
+    positive_prompt_tile = f"{BASE_PROMPT}, {STYLE_PROMPT}"
+
+    negative_prompt_tile = (
+        BASE_NEGATIVE_PROMPT +
+        ", blurry face, distorted face, mangled face, bad face, low quality, blurry"
+    )
+
+    return positive_prompt_lowres, BASE_NEGATIVE_PROMPT, positive_prompt_tile, negative_prompt_tile
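Each stage calls `get_prompts_for_run()` and unpacks a different half of the four-tuple, as in this short sketch (the printed example output is illustrative):

```python
# Sketch: how pipelines.py consumes the four-tuple from get_prompts_for_run()
from prompts import get_prompts_for_run

# Low-res stage: randomized background, base negative prompt
pos_lowres, neg_lowres, _, _ = get_prompts_for_run()

# Hi-res tiling stage: no background, stricter face-quality negatives
_, _, pos_tile, neg_tile = get_prompts_for_run()

print(pos_lowres)  # e.g. "... ((NIGHT COMIC)) ... in a dusty western saloon, dramatic shadows"
```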
requirements.txt
ADDED
@@ -0,0 +1,22 @@
+# Base ML
+torch==2.7.0
+torchvision==0.22.0
+torchaudio==2.7.0
+accelerate==1.6.0
+
+# Diffusers & Transformers
+diffusers==0.33.1
+transformers==4.51.3
+peft==0.15.2
+
+# ControlNet & Auxiliaries
+controlnet_aux==0.0.9
+mediapipe
+matplotlib
+opencv-python-headless==4.11.0.86
+Pillow==11.2.1
+pillow-heif==0.22.0
+numpy==2.2.5
+
+# Web UI
+gradio==5.29.0