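# Hugging Face Space file-viewer residue (not part of the program):
#   author avatar alt text: "tristan-deep's picture"
#   commit message: "get more info HF space"
#   commit: b97cb4f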
import os
os.environ["KERAS_BACKEND"] = "jax"
import gradio as gr
import jax
import keras
import numpy as np
import spaces
from PIL import Image
from zea import init_device
from main import Config, init, run
from utils import load_image
import torch
import subprocess
# Locations of the algorithm config, the UI slider config and the example images.
CONFIG_PATH = "configs/semantic_dps.yaml"
SLIDER_CONFIG_PATH = "configs/slider_params.yaml"
ASSETS_DIR = "assets"

# Device handle; remains None until the page-load callback assigns it.
DEVICE = None

# Inline CSS for the status banners. The "load" variant differs from the
# regular one only in its padding and line-height.
STATUS_STYLE_LOAD = (
    "display:flex;align-items:center;justify-content:center;"
    "padding:40px 10px 18px 10px;border-radius:8px;"
    "font-weight:bold;font-size:1.15em;line-height:1.5;align-items:center;"
)
STATUS_STYLE = (
    "display:flex;align-items:center;justify-content:center;"
    "padding:18px 18px 18px 10px;border-radius:8px;"
    "font-weight:bold;font-size:1.15em;line-height:1.1;align-items:center;"
)

# Markdown shown at the top of the page.
description = """
# Cardiac Ultrasound Dehazing with Semantic Diffusion
Select an example image below to see the dehazing algorithm in action. The algorithm was tuned for the DehazingEcho2025 challenge dataset, so be wary of using it on other datasets.
Tip: Adjust "Omega (Ventricle)" and "Eta (haze prior)" to control the dehazing effect.
"""

# Model and config are populated lazily, after the UI has been rendered.
config = None
diffusion_model = None
model_loaded = False
def initialize_model():
    """Load the config and diffusion model once, then warm the model up.

    When the module-level ``config`` / ``diffusion_model`` are not both set,
    this reads ``CONFIG_PATH``, builds the model with ``init`` and runs a
    single-step dummy inference on an all-zero image.

    The module globals (``config``, ``diffusion_model``, ``model_loaded``) are
    only written after the warm-up has completed. If loading or the warm-up
    raises, the globals are left untouched, so the next call retries the whole
    initialization instead of skipping it on a half-initialized state.

    Returns:
        tuple: The ``(config, diffusion_model)`` pair held in the module
        globals.

    Raises:
        Any exception raised by ``Config.from_yaml``, ``init`` or ``run`` is
        propagated to the caller.
    """
    global config, diffusion_model, model_loaded

    # Already initialized: nothing to load or warm up.
    if config is not None and diffusion_model is not None:
        model_loaded = True
        return config, diffusion_model

    # Build into locals first so a failure below cannot leave the globals
    # partially populated.
    loaded_config = Config.from_yaml(CONFIG_PATH)
    loaded_model = init(loaded_config)

    # Warm-up: run a dummy inference to initialize weights, JIT, etc.
    h, w = loaded_model.input_shape[:2]
    dummy_img = np.zeros((1, h, w), dtype=np.float32)
    params = loaded_config.params
    guidance_cfg = params["guidance_kwargs"]
    guidance_kwargs = {
        "omega": guidance_cfg["omega"],
        "omega_vent": guidance_cfg.get("omega_vent", 1.0),
        "omega_sept": guidance_cfg.get("omega_sept", 1.0),
        "eta": guidance_cfg.get("eta", 1.0),
        "smooth_l1_beta": guidance_cfg["smooth_l1_beta"],
    }
    seed = jax.random.PRNGKey(loaded_config.seed)
    run(
        hazy_images=dummy_img,
        diffusion_model=loaded_model,
        seed=seed,
        guidance_kwargs=guidance_kwargs,
        mask_params=params["mask_params"],
        fixed_mask_params=params["fixed_mask_params"],
        skeleton_params=params["skeleton_params"],
        batch_size=1,
        diffusion_steps=1,
        verbose=False,
    )

    # Commit only after the warm-up succeeded.
    config, diffusion_model = loaded_config, loaded_model
    model_loaded = True
    return config, diffusion_model
def _status_update(background, color, message):
    """Return a ``gr.update`` whose value is a styled HTML status banner.

    Args:
        background: CSS background colour of the banner.
        color: CSS text colour of the banner.
        message: Banner text. It is inserted into the HTML as-is, so callers
            must escape anything that is not trusted markup.
    """
    return gr.update(
        value=f'<div style="background:{background};{STATUS_STYLE}color:{color};">{message}</div>'
    )


@spaces.GPU(duration=10)
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
    """Dehaze one image and stream status updates to the UI.

    This is a generator: every yielded item is a ``(status_update, result)``
    tuple. ``result`` is ``None`` for progress and error updates, and
    ``(input_img, dehazed_pil_image)`` for the final successful update.

    Args:
        input_img: PIL image to process, or ``None`` when nothing was selected.
        diffusion_steps: Forwarded to ``run`` as ``diffusion_steps``.
        omega: Guidance value forwarded as ``guidance_kwargs["omega"]``.
        omega_vent: Guidance value forwarded as ``guidance_kwargs["omega_vent"]``.
        omega_sept: Guidance value forwarded as ``guidance_kwargs["omega_sept"]``.
        eta: Guidance value forwarded as ``guidance_kwargs["eta"]``.

    Yields:
        tuple: ``(gr.update(...), None)`` or ``(gr.update(...), (input, output))``.
    """
    global config, diffusion_model, model_loaded
    import html  # local import: only used to escape exception text in banners

    if not model_loaded:
        yield (
            _status_update(
                "#ffeeba", "#856404", "⏳ Model is still loading. Please wait..."
            ),
            None,
        )
        return

    if input_img is None:
        yield (
            _status_update(
                "#ffeeba",
                "#856404",
                "⚠️ No input image was provided. Please select or upload an image before running.",
            ),
            None,
        )
        return

    params = config.params

    def _prepare_image(image):
        """Convert to greyscale, resize to the model input size and add a batch axis.

        Returns the float32 array, whether a resize happened, and the original
        ``(height, width)``.
        """
        resized = False
        if image.mode != "L":
            image = image.convert("L")
        # PIL reports (width, height); keep (height, width) for later.
        orig_shape = image.size[::-1]
        h, w = diffusion_model.input_shape[:2]
        if image.size != (w, h):
            image = image.resize((w, h), Image.BILINEAR)
            resized = True
        image = np.array(image)
        image = image.astype(np.float32)
        image = image[None, ...]
        return image, resized, orig_shape

    try:
        image, resized, orig_shape = _prepare_image(input_img)
    except Exception as e:
        # Escape the message: exception text often contains "<...>" reprs that
        # would otherwise be interpreted as HTML tags and disappear.
        yield (
            _status_update(
                "#f8d7da",
                "#721c24",
                f"❌ Error preparing input image: {html.escape(str(e))}",
            ),
            None,
        )
        return

    guidance_kwargs = {
        "omega": omega,
        "omega_vent": omega_vent,
        "omega_sept": omega_sept,
        "eta": eta,
        "smooth_l1_beta": params["guidance_kwargs"]["smooth_l1_beta"],
    }
    seed = jax.random.PRNGKey(config.seed)

    try:
        yield (
            _status_update("#cce5ff", "#004085", "🌀 Running dehazing algorithm..."),
            None,
        )
        _, pred_tissue_images, *_ = run(
            hazy_images=image,
            diffusion_model=diffusion_model,
            seed=seed,
            guidance_kwargs=guidance_kwargs,
            mask_params=params["mask_params"],
            fixed_mask_params=params["fixed_mask_params"],
            skeleton_params=params["skeleton_params"],
            batch_size=1,
            diffusion_steps=diffusion_steps,
            threshold_output_quantile=params.get("threshold_output_quantile", None),
            preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
            bottom_transition_width=params.get("bottom_transition_width", 10.0),
            verbose=False,
        )
    except Exception as e:
        yield (
            _status_update(
                "#f8d7da",
                "#721c24",
                f"❌ The algorithm failed to process the image: {html.escape(str(e))}",
            ),
            None,
        )
        return

    # Convert the first prediction to an 8-bit image.
    out_img = np.squeeze(pred_tissue_images[0])
    out_img = np.clip(out_img, 0, 255).astype(np.uint8)
    out_pil = Image.fromarray(out_img)

    # Undo the input resize so the before/after slider compares equal sizes.
    if resized and out_pil.size != (orig_shape[1], orig_shape[0]):
        out_pil = out_pil.resize((orig_shape[1], orig_shape[0]), Image.BILINEAR)

    yield (
        _status_update("#d4edda", "#155724", "✅ Done!"),
        (input_img, out_pil),
    )
# Slider defaults and ranges are read from their own YAML file.
slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)

# Every slider entry is read through the same four keys, in this order.
_SLIDER_FIELDS = ("default", "min", "max", "step")

(
    diffusion_steps_default,
    diffusion_steps_min,
    diffusion_steps_max,
    diffusion_steps_step,
) = (slider_params["diffusion_steps"][field] for field in _SLIDER_FIELDS)

(
    omega_default,
    omega_min,
    omega_max,
    omega_step,
) = (slider_params["omega"][field] for field in _SLIDER_FIELDS)

(
    omega_vent_default,
    omega_vent_min,
    omega_vent_max,
    omega_vent_step,
) = (slider_params["omega_vent"][field] for field in _SLIDER_FIELDS)

(
    omega_sept_default,
    omega_sept_min,
    omega_sept_max,
    omega_sept_step,
) = (slider_params["omega_sept"][field] for field in _SLIDER_FIELDS)

(
    eta_default,
    eta_min,
    eta_max,
    eta_step,
) = (slider_params["eta"][field] for field in _SLIDER_FIELDS)
# Collect the PNG files in ASSETS_DIR. ``os.listdir`` returns entries in
# arbitrary order, so sort them: this keeps the example gallery and the
# default input image (the first example) stable across runs and filesystems.
example_image_paths = [
    os.path.join(ASSETS_DIR, name)
    for name in sorted(os.listdir(ASSETS_DIR))
    if name.lower().endswith(".png")
]
example_images = [load_image(path) for path in example_image_paths]
# One single-element row per example image.
examples = [[img] for img in example_images]
# NOTE(review): the nesting of the layout containers below and the placement
# of the diagnostic prints inside ``load_model_event`` were reconstructed from
# statement order; confirm against the deployed UI.
with gr.Blocks() as demo:
    gr.Markdown(description)
    # Status banner; starts in the "loading" state and is updated by callbacks.
    status = gr.Markdown(
        f'<div style="background:#ffeeba;{STATUS_STYLE_LOAD}color:#856404;">⏳ Loading model...</div>',
        visible=True,
    )

    # Input image (with examples) next to the before/after output slider.
    with gr.Row():
        with gr.Column():
            img1 = gr.Image(
                label="Input Image",
                type="pil",
                webcam_options=False,
                value=example_images[0] if example_images else None,
            )
            gr.Examples(examples=examples, inputs=[img1])
        with gr.Column():
            img2 = gr.ImageSlider(label="Dehazed Image", type="pil")

    # Algorithm parameters; ranges come from the slider config.
    with gr.Row():
        diffusion_steps_slider = gr.Slider(
            minimum=diffusion_steps_min,
            maximum=diffusion_steps_max,
            step=diffusion_steps_step,
            value=diffusion_steps_default,
            label="Diffusion Steps",
        )
        omega_slider = gr.Slider(
            minimum=omega_min,
            maximum=omega_max,
            step=omega_step,
            value=omega_default,
            label="Omega (background)",
        )
        omega_vent_slider = gr.Slider(
            minimum=omega_vent_min,
            maximum=omega_vent_max,
            step=omega_vent_step,
            value=omega_vent_default,
            label="Omega Ventricle",
        )
        omega_sept_slider = gr.Slider(
            minimum=omega_sept_min,
            maximum=omega_sept_max,
            step=omega_sept_step,
            value=omega_sept_default,
            label="Omega Septum",
        )
        eta_slider = gr.Slider(
            minimum=eta_min,
            maximum=eta_max,
            step=eta_step,
            value=eta_default,
            label="Eta (haze prior)",
        )

    # Disabled until the model has been loaded by ``load_model_event``.
    run_btn = gr.Button("Run", interactive=False)
    run_btn.click(
        process_image,
        inputs=[
            img1,
            diffusion_steps_slider,
            omega_slider,
            omega_vent_slider,
            omega_sept_slider,
            eta_slider,
        ],
        outputs=[status, img2],
        queue=True,
    )

    def load_model_event():
        """Page-load callback: set up the device, log diagnostics, load the model.

        Returns:
            tuple: ``(status_update, run_button_update)``. On success the
            banner reports readiness and the Run button is enabled; on failure
            the banner shows the error and the button stays disabled.
        """
        global config, diffusion_model, model_loaded, DEVICE
        import html  # local import: only used to escape exception text

        try:
            if DEVICE is None:
                try:
                    DEVICE = init_device()
                except Exception as e:
                    # Best effort: continue without a zea device handle, but
                    # keep the reason in the logs. ``Exception`` (not a bare
                    # ``except``) lets KeyboardInterrupt/SystemExit propagate.
                    print(
                        f"Could not initialize device using `zea.init_device()`: {e}"
                    )

                # Environment diagnostics for the Space logs.
                print(f"KERAS version: {keras.__version__}")
                try:
                    print(f"JAX version: {jax.__version__}")
                    print(f"JAX devices: {jax.devices()}")
                except Exception as e:
                    print(f"Could not get JAX info: {e}")
                try:
                    print(f"PyTorch version: {torch.__version__}")
                    print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
                    print(f"PyTorch CUDA device count: {torch.cuda.device_count()}")
                    print(
                        f"PyTorch devices: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}"
                    )
                    print(f"PyTorch CUDA version: {torch.version.cuda}")
                    print(f"PyTorch cuDNN version: {torch.backends.cudnn.version()}")
                except Exception as e:
                    print(f"Could not get PyTorch info: {e}")
                try:
                    cuda_version = subprocess.getoutput("nvcc --version")
                    print(f"nvcc version:\n{cuda_version}")
                    nvidia_smi = subprocess.getoutput("nvidia-smi")
                    print(f"nvidia-smi output:\n{nvidia_smi}")
                except Exception as e:
                    print(f"Could not get CUDA/nvidia-smi info: {e}")

            config, diffusion_model = initialize_model()
            ready_msg = gr.update(
                value=f'<div style="background:#d4edda;{STATUS_STYLE}color:#155724;">✅ Model loaded! You can now press Run.</div>'
            )
            return ready_msg, gr.update(interactive=True)
        except Exception as e:
            # Escape the message so "<...>" in exception text is shown rather
            # than interpreted as HTML.
            return gr.update(
                value=f'<div style="background:#f8d7da;{STATUS_STYLE}color:#721c24;">❌ Error loading model: {html.escape(str(e))}</div>'
            ), gr.update(interactive=False)

    # Load the model once the page has rendered.
    demo.load(
        load_model_event,
        inputs=None,
        outputs=[status, run_btn],
    )
# Start the Gradio server only when this file is executed as a script.
if "__main__" == __name__:
    demo.launch()