tristan-deep committed on
Commit 3646605 · 1 Parent(s): b37af25

improving app

Files changed (2)
  1. app.py +67 -31
  2. main.py +0 -5
app.py CHANGED
@@ -1,7 +1,4 @@
 import os
-import time
-
-os.environ["KERAS_BACKEND"] = "jax"
 
 import gradio as gr
 import jax
@@ -24,35 +21,68 @@ Two parameters that are interesting to control and adjust the amount of dehazing
 """
 
 
-@spaces.GPU
+def initialize_model():
+    config = Config.from_yaml(CONFIG_PATH)
+    diffusion_model = init(config)
+    return config, diffusion_model
+
+
+@spaces.GPU(duration=30)
+
+# Generator function for status updates
 def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
     if input_img is None:
-        raise gr.Error(
-            "No input image was provided. Please select or upload an image before running."
+        yield (
+            gr.update(
+                value='<div style="background:#ffeeba;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#856404;">⚠️ No input image was provided. Please select or upload an image before running.</div>'
+            ),
+            None,
+        )
+        return
+    # Show loading message
+    yield (
+        gr.update(
+            value='<div style="background:#ffeeba;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#856404;">⏳ Loading model...</div>'
+        ),
+        None,
+    )
+
+    try:
+        config, diffusion_model = initialize_model()
+        params = config.params
+    except Exception as e:
+        yield (
+            gr.update(
+                value=f'<div style="background:#f8d7da;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#721c24;">❌ Error initializing model: {e}</div>'
+            ),
+            None,
         )
+        return
 
     def _prepare_image(image):
         resized = False
-
         if image.mode != "L":
            image = image.convert("L")
-
         orig_shape = image.size[::-1]
         h, w = diffusion_model.input_shape[:2]
         if image.size != (w, h):
             image = image.resize((w, h), Image.BILINEAR)
             resized = True
-
         image = np.array(image)
-
         image = image.astype(np.float32)
         image = image[None, ...]
         return image, resized, orig_shape
 
     try:
         image, resized, orig_shape = _prepare_image(input_img)
-    except Exception:
-        raise gr.Error("Something went wrong with preparing the input image.")
+    except Exception as e:
+        yield (
+            gr.update(
+                value=f'<div style="background:#f8d7da;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#721c24;">❌ Error preparing input image: {e}</div>'
+            ),
+            None,
+        )
+        return
 
     guidance_kwargs = {
         "omega": omega,
@@ -65,6 +95,12 @@ def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta
     seed = jax.random.PRNGKey(config.seed)
 
     try:
+        yield (
+            gr.update(
+                value='<div style="background:#ffeeba;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#856404;">🌀 Running dehazing algorithm... (First time takes longer...) <span style="font-weight:normal;font-size:0.95em;">(first time takes longer)</span></div>'
+            ),
+            None,
+        )
         _, pred_tissue_images, *_ = run(
             hazy_images=image,
             diffusion_model=diffusion_model,
@@ -75,23 +111,32 @@ def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta
             skeleton_params=params["skeleton_params"],
             batch_size=1,
             diffusion_steps=diffusion_steps,
-            initial_diffusion_step=params.get("initial_diffusion_step", 0),
             threshold_output_quantile=params.get("threshold_output_quantile", None),
             preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
             bottom_transition_width=params.get("bottom_transition_width", 10.0),
             verbose=False,
         )
-    except Exception:
-        raise gr.Error("The algorithm failed to process the image.")
+    except Exception as e:
+        yield (
+            gr.update(
+                value=f'<div style="background:#f8d7da;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#721c24;">❌ The algorithm failed to process the image: {e}</div>'
+            ),
+            None,
+        )
+        return
 
     out_img = np.squeeze(pred_tissue_images[0])
     out_img = np.clip(out_img, 0, 255).astype(np.uint8)
     out_pil = Image.fromarray(out_img)
-    # Resize back to original input size if needed
     if resized and out_pil.size != (orig_shape[1], orig_shape[0]):
         out_pil = out_pil.resize((orig_shape[1], orig_shape[0]), Image.BILINEAR)
-    # Return tuple for ImageSlider: (input, output)
-    return (input_img, out_pil)
+    yield gr.update(value="Done!"), (input_img, out_pil)
+    yield (
+        gr.update(
+            value='<div style="background:#d4edda;padding:10px;border-radius:6px;font-weight:bold;font-size:1.1em;color:#155724;">✅ Done!</div>'
+        ),
+        (input_img, out_pil),
+    )
 
 
 slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)
@@ -133,7 +178,7 @@ examples = [[img] for img in example_images]
 
 with gr.Blocks() as demo:
     gr.Markdown(description)
-    status = gr.Markdown("Initializing model, please wait...", visible=True)
+    status = gr.Markdown("", visible=True)
     with gr.Row():
         img1 = gr.Image(label="Input Image", type="pil", webcam_options=False)
         img2 = gr.ImageSlider(label="Dehazed Image", type="pil")
@@ -176,16 +221,6 @@ with gr.Blocks() as demo:
     )
     run_btn = gr.Button("Run")
 
-    def initialize_model():
-        time.sleep(0.5)  # Let UI update
-        config = Config.from_yaml(CONFIG_PATH)
-        diffusion_model = init(config)
-        params = config.params
-        return config, diffusion_model, params
-
-    config, diffusion_model, params = initialize_model()
-    status.visible = False
-
     run_btn.click(
         process_image,
         inputs=[
@@ -196,8 +231,9 @@ with gr.Blocks() as demo:
            omega_sept_slider,
            eta_slider,
        ],
-        outputs=[img2],
+        outputs=[status, img2],
+        queue=True,
    )
 
 if __name__ == "__main__":
-    demo.launch(share=True)
+    demo.launch()
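For reference, the rewritten process_image is a Gradio generator handler: each yield sends one value per declared output component (here the status Markdown and the ImageSlider), so the status banner can change while the model loads and runs. Below is a minimal, self-contained sketch of that pattern, assuming only that Gradio is installed; the process function, its components, and the sleep stand-in are hypothetical and not part of this repository.

import time

import gradio as gr


def process(text):
    # A generator event handler may yield several times per click;
    # each yield provides one value per output component, in order.
    yield gr.update(value="⏳ Working..."), None
    time.sleep(1)  # stand-in for the real model call
    yield gr.update(value="✅ Done!"), text.upper()


with gr.Blocks() as demo:
    status = gr.Markdown("")
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(process, inputs=[inp], outputs=[status, out])

if __name__ == "__main__":
    demo.launch()

Wiring outputs=[status, out] to the generator is what lets a single button click drive both the progress banner and the final image, which is exactly how the new run_btn.click call above is configured.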
main.py CHANGED
@@ -1,9 +1,6 @@
 import copy
-import os
 from pathlib import Path
 
-os.environ["KERAS_BACKEND"] = "jax"
-
 import jax
 import keras
 import matplotlib.pyplot as plt
@@ -277,7 +274,6 @@ def run(
     skeleton_params: dict,
     batch_size: int = 4,
     diffusion_steps: int = 100,
-    initial_diffusion_step: int = 0,
     threshold_output_quantile: float = None,
     preserve_bottom_percent: float = 30.0,
     bottom_transition_width: float = 10.0,
@@ -306,7 +302,6 @@ def run(
         batch,
         n_samples=1,
         n_steps=diffusion_steps,
-        initial_step=initial_diffusion_step,
         seed=seed,
         verbose=True,
         per_pixel_omega=masks["per_pixel_omega"],
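One caveat on the removed os.environ["KERAS_BACKEND"] = "jax" lines: with Keras 3 the backend still has to be chosen before keras is first imported, so after this commit it presumably comes from the Space's environment rather than from these modules. A hedged sketch of where such a guard could live; the setdefault call and its placement are an assumption, not something shown in this commit.

import os

# Assumption: the backend is selected before any `import keras` runs anywhere.
os.environ.setdefault("KERAS_BACKEND", "jax")

import keras  # noqa: E402  (imported after the env var on purpose)

print(keras.backend.backend())  # prints "jax" when the variable took effect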