Commit
·
3646605
1
Parent(s):
b37af25
improving app
Browse files
app.py
CHANGED
|
@@ -1,7 +1,4 @@
|
|
| 1 |
import os
|
| 2 |
-
import time
|
| 3 |
-
|
| 4 |
-
os.environ["KERAS_BACKEND"] = "jax"
|
| 5 |
|
| 6 |
import gradio as gr
|
| 7 |
import jax
|
|
@@ -24,35 +21,68 @@ Two parameters that are interesting to control and adjust the amount of dehazing
|
|
| 24 |
"""
|
| 25 |
|
| 26 |
|
| 27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
|
| 29 |
if input_img is None:
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
)
|
|
|
|
| 33 |
|
| 34 |
def _prepare_image(image):
|
| 35 |
resized = False
|
| 36 |
-
|
| 37 |
if image.mode != "L":
|
| 38 |
image = image.convert("L")
|
| 39 |
-
|
| 40 |
orig_shape = image.size[::-1]
|
| 41 |
h, w = diffusion_model.input_shape[:2]
|
| 42 |
if image.size != (w, h):
|
| 43 |
image = image.resize((w, h), Image.BILINEAR)
|
| 44 |
resized = True
|
| 45 |
-
|
| 46 |
image = np.array(image)
|
| 47 |
-
|
| 48 |
image = image.astype(np.float32)
|
| 49 |
image = image[None, ...]
|
| 50 |
return image, resized, orig_shape
|
| 51 |
|
| 52 |
try:
|
| 53 |
image, resized, orig_shape = _prepare_image(input_img)
|
| 54 |
-
except Exception:
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
guidance_kwargs = {
|
| 58 |
"omega": omega,
|
|
@@ -65,6 +95,12 @@ def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta
|
|
| 65 |
seed = jax.random.PRNGKey(config.seed)
|
| 66 |
|
| 67 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
_, pred_tissue_images, *_ = run(
|
| 69 |
hazy_images=image,
|
| 70 |
diffusion_model=diffusion_model,
|
|
@@ -75,23 +111,32 @@ def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta
|
|
| 75 |
skeleton_params=params["skeleton_params"],
|
| 76 |
batch_size=1,
|
| 77 |
diffusion_steps=diffusion_steps,
|
| 78 |
-
initial_diffusion_step=params.get("initial_diffusion_step", 0),
|
| 79 |
threshold_output_quantile=params.get("threshold_output_quantile", None),
|
| 80 |
preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
|
| 81 |
bottom_transition_width=params.get("bottom_transition_width", 10.0),
|
| 82 |
verbose=False,
|
| 83 |
)
|
| 84 |
-
except Exception:
|
| 85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
|
| 87 |
out_img = np.squeeze(pred_tissue_images[0])
|
| 88 |
out_img = np.clip(out_img, 0, 255).astype(np.uint8)
|
| 89 |
out_pil = Image.fromarray(out_img)
|
| 90 |
-
# Resize back to original input size if needed
|
| 91 |
if resized and out_pil.size != (orig_shape[1], orig_shape[0]):
|
| 92 |
out_pil = out_pil.resize((orig_shape[1], orig_shape[0]), Image.BILINEAR)
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
|
| 97 |
slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)
|
|
@@ -133,7 +178,7 @@ examples = [[img] for img in example_images]
|
|
| 133 |
|
| 134 |
with gr.Blocks() as demo:
|
| 135 |
gr.Markdown(description)
|
| 136 |
-
status = gr.Markdown("
|
| 137 |
with gr.Row():
|
| 138 |
img1 = gr.Image(label="Input Image", type="pil", webcam_options=False)
|
| 139 |
img2 = gr.ImageSlider(label="Dehazed Image", type="pil")
|
|
@@ -176,16 +221,6 @@ with gr.Blocks() as demo:
|
|
| 176 |
)
|
| 177 |
run_btn = gr.Button("Run")
|
| 178 |
|
| 179 |
-
def initialize_model():
|
| 180 |
-
time.sleep(0.5) # Let UI update
|
| 181 |
-
config = Config.from_yaml(CONFIG_PATH)
|
| 182 |
-
diffusion_model = init(config)
|
| 183 |
-
params = config.params
|
| 184 |
-
return config, diffusion_model, params
|
| 185 |
-
|
| 186 |
-
config, diffusion_model, params = initialize_model()
|
| 187 |
-
status.visible = False
|
| 188 |
-
|
| 189 |
run_btn.click(
|
| 190 |
process_image,
|
| 191 |
inputs=[
|
|
@@ -196,8 +231,9 @@ with gr.Blocks() as demo:
|
|
| 196 |
omega_sept_slider,
|
| 197 |
eta_slider,
|
| 198 |
],
|
| 199 |
-
outputs=[img2],
|
|
|
|
| 200 |
)
|
| 201 |
|
| 202 |
if __name__ == "__main__":
|
| 203 |
-
demo.launch(
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import gradio as gr
|
| 4 |
import jax
|
|
|
|
| 21 |
"""
|
| 22 |
|
| 23 |
|
| 24 |
+
def initialize_model():
|
| 25 |
+
config = Config.from_yaml(CONFIG_PATH)
|
| 26 |
+
diffusion_model = init(config)
|
| 27 |
+
return config, diffusion_model
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@spaces.GPU(duration=30)
|
| 31 |
+
|
| 32 |
+
# Generator function for status updates
|
| 33 |
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
|
| 34 |
if input_img is None:
|
| 35 |
+
yield (
|
| 36 |
+
gr.update(
|
| 37 |
+
value='<div style="background:#ffeeba;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#856404;">⚠️ No input image was provided. Please select or upload an image before running.</div>'
|
| 38 |
+
),
|
| 39 |
+
None,
|
| 40 |
+
)
|
| 41 |
+
return
|
| 42 |
+
# Show loading message
|
| 43 |
+
yield (
|
| 44 |
+
gr.update(
|
| 45 |
+
value='<div style="background:#ffeeba;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#856404;">⏳ Loading model...</div>'
|
| 46 |
+
),
|
| 47 |
+
None,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
try:
|
| 51 |
+
config, diffusion_model = initialize_model()
|
| 52 |
+
params = config.params
|
| 53 |
+
except Exception as e:
|
| 54 |
+
yield (
|
| 55 |
+
gr.update(
|
| 56 |
+
value=f'<div style="background:#f8d7da;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#721c24;">❌ Error initializing model: {e}</div>'
|
| 57 |
+
),
|
| 58 |
+
None,
|
| 59 |
)
|
| 60 |
+
return
|
| 61 |
|
| 62 |
def _prepare_image(image):
|
| 63 |
resized = False
|
|
|
|
| 64 |
if image.mode != "L":
|
| 65 |
image = image.convert("L")
|
|
|
|
| 66 |
orig_shape = image.size[::-1]
|
| 67 |
h, w = diffusion_model.input_shape[:2]
|
| 68 |
if image.size != (w, h):
|
| 69 |
image = image.resize((w, h), Image.BILINEAR)
|
| 70 |
resized = True
|
|
|
|
| 71 |
image = np.array(image)
|
|
|
|
| 72 |
image = image.astype(np.float32)
|
| 73 |
image = image[None, ...]
|
| 74 |
return image, resized, orig_shape
|
| 75 |
|
| 76 |
try:
|
| 77 |
image, resized, orig_shape = _prepare_image(input_img)
|
| 78 |
+
except Exception as e:
|
| 79 |
+
yield (
|
| 80 |
+
gr.update(
|
| 81 |
+
value=f'<div style="background:#f8d7da;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#721c24;">❌ Error preparing input image: {e}</div>'
|
| 82 |
+
),
|
| 83 |
+
None,
|
| 84 |
+
)
|
| 85 |
+
return
|
| 86 |
|
| 87 |
guidance_kwargs = {
|
| 88 |
"omega": omega,
|
|
|
|
| 95 |
seed = jax.random.PRNGKey(config.seed)
|
| 96 |
|
| 97 |
try:
|
| 98 |
+
yield (
|
| 99 |
+
gr.update(
|
| 100 |
+
value='<div style="background:#ffeeba;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#856404;">🌀 Running dehazing algorithm... (First time takes longer...) <span style="font-weight:normal;font-size:0.95em;">(first time takes longer)</span></div>'
|
| 101 |
+
),
|
| 102 |
+
None,
|
| 103 |
+
)
|
| 104 |
_, pred_tissue_images, *_ = run(
|
| 105 |
hazy_images=image,
|
| 106 |
diffusion_model=diffusion_model,
|
|
|
|
| 111 |
skeleton_params=params["skeleton_params"],
|
| 112 |
batch_size=1,
|
| 113 |
diffusion_steps=diffusion_steps,
|
|
|
|
| 114 |
threshold_output_quantile=params.get("threshold_output_quantile", None),
|
| 115 |
preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
|
| 116 |
bottom_transition_width=params.get("bottom_transition_width", 10.0),
|
| 117 |
verbose=False,
|
| 118 |
)
|
| 119 |
+
except Exception as e:
|
| 120 |
+
yield (
|
| 121 |
+
gr.update(
|
| 122 |
+
value=f'<div style="background:#f8d7da;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#721c24;">❌ The algorithm failed to process the image: {e}</div>'
|
| 123 |
+
),
|
| 124 |
+
None,
|
| 125 |
+
)
|
| 126 |
+
return
|
| 127 |
|
| 128 |
out_img = np.squeeze(pred_tissue_images[0])
|
| 129 |
out_img = np.clip(out_img, 0, 255).astype(np.uint8)
|
| 130 |
out_pil = Image.fromarray(out_img)
|
|
|
|
| 131 |
if resized and out_pil.size != (orig_shape[1], orig_shape[0]):
|
| 132 |
out_pil = out_pil.resize((orig_shape[1], orig_shape[0]), Image.BILINEAR)
|
| 133 |
+
yield gr.update(value="Done!"), (input_img, out_pil)
|
| 134 |
+
yield (
|
| 135 |
+
gr.update(
|
| 136 |
+
value='<div style="background:#d4edda;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#155724;">✅ Done!</div>'
|
| 137 |
+
),
|
| 138 |
+
(input_img, out_pil),
|
| 139 |
+
)
|
| 140 |
|
| 141 |
|
| 142 |
slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)
|
|
|
|
| 178 |
|
| 179 |
with gr.Blocks() as demo:
|
| 180 |
gr.Markdown(description)
|
| 181 |
+
status = gr.Markdown("", visible=True)
|
| 182 |
with gr.Row():
|
| 183 |
img1 = gr.Image(label="Input Image", type="pil", webcam_options=False)
|
| 184 |
img2 = gr.ImageSlider(label="Dehazed Image", type="pil")
|
|
|
|
| 221 |
)
|
| 222 |
run_btn = gr.Button("Run")
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
run_btn.click(
|
| 225 |
process_image,
|
| 226 |
inputs=[
|
|
|
|
| 231 |
omega_sept_slider,
|
| 232 |
eta_slider,
|
| 233 |
],
|
| 234 |
+
outputs=[status, img2],
|
| 235 |
+
queue=True,
|
| 236 |
)
|
| 237 |
|
| 238 |
if __name__ == "__main__":
|
| 239 |
+
demo.launch()
|
main.py
CHANGED
|
@@ -1,9 +1,6 @@
|
|
| 1 |
import copy
|
| 2 |
-
import os
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
-
os.environ["KERAS_BACKEND"] = "jax"
|
| 6 |
-
|
| 7 |
import jax
|
| 8 |
import keras
|
| 9 |
import matplotlib.pyplot as plt
|
|
@@ -277,7 +274,6 @@ def run(
|
|
| 277 |
skeleton_params: dict,
|
| 278 |
batch_size: int = 4,
|
| 279 |
diffusion_steps: int = 100,
|
| 280 |
-
initial_diffusion_step: int = 0,
|
| 281 |
threshold_output_quantile: float = None,
|
| 282 |
preserve_bottom_percent: float = 30.0,
|
| 283 |
bottom_transition_width: float = 10.0,
|
|
@@ -306,7 +302,6 @@ def run(
|
|
| 306 |
batch,
|
| 307 |
n_samples=1,
|
| 308 |
n_steps=diffusion_steps,
|
| 309 |
-
initial_step=initial_diffusion_step,
|
| 310 |
seed=seed,
|
| 311 |
verbose=True,
|
| 312 |
per_pixel_omega=masks["per_pixel_omega"],
|
|
|
|
| 1 |
import copy
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
|
|
|
|
|
|
|
| 4 |
import jax
|
| 5 |
import keras
|
| 6 |
import matplotlib.pyplot as plt
|
|
|
|
| 274 |
skeleton_params: dict,
|
| 275 |
batch_size: int = 4,
|
| 276 |
diffusion_steps: int = 100,
|
|
|
|
| 277 |
threshold_output_quantile: float = None,
|
| 278 |
preserve_bottom_percent: float = 30.0,
|
| 279 |
bottom_transition_width: float = 10.0,
|
|
|
|
| 302 |
batch,
|
| 303 |
n_samples=1,
|
| 304 |
n_steps=diffusion_steps,
|
|
|
|
| 305 |
seed=seed,
|
| 306 |
verbose=True,
|
| 307 |
per_pixel_omega=masks["per_pixel_omega"],
|