"""Gradio demo: cardiac ultrasound dehazing with semantic diffusion.

The diffusion model is loaded (and warmed up) after the UI is first rendered,
via ``demo.load``; the Run button stays disabled until loading succeeds.
"""

import html
import os
import subprocess
import traceback

# KERAS_BACKEND must be set before keras (or any module that imports it) is
# imported, so the third-party imports below intentionally follow this line.
os.environ["KERAS_BACKEND"] = "jax"

import gradio as gr
import jax
import keras
import numpy as np
import spaces
from PIL import Image
from zea import init_device

from main import Config, init, run
from utils import load_image

import torch

CONFIG_PATH = "configs/semantic_dps.yaml"
SLIDER_CONFIG_PATH = "configs/slider_params.yaml"
ASSETS_DIR = "assets"
DEVICE = None

# Inline CSS for the status banner; the LOAD variant (extra top padding) is
# used for the initial "Loading model..." message.
STATUS_STYLE_LOAD = "display:flex;align-items:center;justify-content:center;padding:40px 10px 18px 10px;border-radius:8px;font-weight:bold;font-size:1.15em;line-height:1.5;align-items:center;"
STATUS_STYLE = "display:flex;align-items:center;justify-content:center;padding:18px 18px 18px 10px;border-radius:8px;font-weight:bold;font-size:1.15em;line-height:1.1;align-items:center;"

description = """
# Cardiac Ultrasound Dehazing with Semantic Diffusion

Select an example image below to see the dehazing algorithm in action.
The algorithm was tuned for the DehazingEcho2025 challenge dataset, so be wary of using it on other datasets.

Tip: Adjust "Omega Ventricle" and "Eta (haze prior)" to control the dehazing effect.
"""

# Model and config are loaded lazily (see ``load_model_event``) so the UI can
# render before model initialisation finishes.
config, diffusion_model = None, None
model_loaded = False


def _status_html(message, style=STATUS_STYLE):
    """Return *message* wrapped in a styled ``<div>`` for the status Markdown.

    *message* is inserted verbatim; escape any untrusted text before passing it.
    """
    return f'<div style="{style}">{message}</div>'


def _status_update(message, style=STATUS_STYLE):
    """Return a ``gr.update`` that replaces the status Markdown with *message*."""
    return gr.update(value=_status_html(message, style))


def _build_guidance_kwargs(params, omega, omega_vent, omega_sept, eta):
    """Assemble the ``guidance_kwargs`` dict passed to ``run``.

    ``smooth_l1_beta`` is not user-tunable and is always taken from the
    config's ``params["guidance_kwargs"]``.
    """
    return {
        "omega": omega,
        "omega_vent": omega_vent,
        "omega_sept": omega_sept,
        "eta": eta,
        "smooth_l1_beta": params["guidance_kwargs"]["smooth_l1_beta"],
    }


def _warm_up(cfg, model):
    """Run a single-step dummy inference (e.g. to trigger JIT compilation).

    Uses an all-zero image of the model's input height/width and the default
    guidance weights from ``cfg.params``.
    """
    h, w = model.input_shape[:2]
    dummy_img = np.zeros((1, h, w), dtype=np.float32)
    params = cfg.params
    defaults = params["guidance_kwargs"]
    guidance_kwargs = _build_guidance_kwargs(
        params,
        omega=defaults["omega"],
        omega_vent=defaults.get("omega_vent", 1.0),
        omega_sept=defaults.get("omega_sept", 1.0),
        eta=defaults.get("eta", 1.0),
    )
    run(
        hazy_images=dummy_img,
        diffusion_model=model,
        seed=jax.random.PRNGKey(cfg.seed),
        guidance_kwargs=guidance_kwargs,
        mask_params=params["mask_params"],
        fixed_mask_params=params["fixed_mask_params"],
        skeleton_params=params["skeleton_params"],
        batch_size=1,
        diffusion_steps=1,
        verbose=False,
    )


def initialize_model():
    """Load the config and diffusion model once, warm it up, and mark it loaded.

    Returns:
        Tuple ``(config, diffusion_model)`` (also stored in module globals).

    Raises:
        Whatever ``Config.from_yaml``, ``init`` or the warm-up ``run`` raise.
    """
    global config, diffusion_model, model_loaded
    if config is None or diffusion_model is None:
        new_config = Config.from_yaml(CONFIG_PATH)
        new_model = init(new_config)
        _warm_up(new_config, new_model)
        # Publish the globals only once the warm-up has succeeded: if anything
        # above raises they stay unset, so the next call retries the full load
        # instead of marking a half-initialised model as loaded.
        config, diffusion_model = new_config, new_model
    model_loaded = True
    return config, diffusion_model


def _prepare_image(image, input_shape):
    """Convert a PIL image into the model's input layout.

    The image is converted to grayscale ("L"), resized bilinearly to the
    model's ``(h, w) = input_shape[:2]`` if needed, and given a leading batch
    axis.

    Returns:
        ``(array, resized, orig_shape)``: ``array`` is float32 of shape
        ``(1, h, w)``, ``resized`` tells whether a resize happened and
        ``orig_shape`` is the original ``(height, width)``.
    """
    if image.mode != "L":
        image = image.convert("L")
    orig_shape = image.size[::-1]  # PIL size is (width, height)
    h, w = input_shape[:2]
    resized = image.size != (w, h)
    if resized:
        image = image.resize((w, h), Image.BILINEAR)
    array = np.asarray(image, dtype=np.float32)[None, ...]
    return array, resized, orig_shape


@spaces.GPU(duration=10)
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
    """Dehaze *input_img* with the loaded diffusion model.

    Generator used as a Gradio event handler. Each yielded tuple is
    ``(status_update, slider_value)``, where ``slider_value`` is ``None`` or
    ``(input_img, dehazed_pil)`` for the ImageSlider. Errors are reported
    through the status banner rather than raised.
    """
    if not model_loaded:
        yield _status_update("⏳ Model is still loading. Please wait..."), None
        return
    if input_img is None:
        yield (
            _status_update(
                "⚠️ No input image was provided. Please select or upload an image before running."
            ),
            None,
        )
        return

    params = config.params
    try:
        image, resized, orig_shape = _prepare_image(
            input_img, diffusion_model.input_shape
        )
    except Exception as e:
        yield (
            _status_update(f"❌ Error preparing input image: {html.escape(str(e))}"),
            None,
        )
        return

    guidance_kwargs = _build_guidance_kwargs(params, omega, omega_vent, omega_sept, eta)
    seed = jax.random.PRNGKey(config.seed)
    try:
        yield _status_update("🌀 Running dehazing algorithm..."), None
        _, pred_tissue_images, *_ = run(
            hazy_images=image,
            diffusion_model=diffusion_model,
            seed=seed,
            guidance_kwargs=guidance_kwargs,
            mask_params=params["mask_params"],
            fixed_mask_params=params["fixed_mask_params"],
            skeleton_params=params["skeleton_params"],
            batch_size=1,
            # Slider values may arrive as floats; the step count is an integer.
            diffusion_steps=int(diffusion_steps),
            threshold_output_quantile=params.get("threshold_output_quantile", None),
            preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
            bottom_transition_width=params.get("bottom_transition_width", 10.0),
            verbose=False,
        )
    except Exception as e:
        yield (
            _status_update(
                f"❌ The algorithm failed to process the image: {html.escape(str(e))}"
            ),
            None,
        )
        return

    out_img = np.clip(np.squeeze(pred_tissue_images[0]), 0, 255).astype(np.uint8)
    out_pil = Image.fromarray(out_img)
    orig_size = (orig_shape[1], orig_shape[0])  # back to PIL (width, height)
    if resized and out_pil.size != orig_size:
        out_pil = out_pil.resize(orig_size, Image.BILINEAR)
    yield _status_update("✅ Done!"), (input_img, out_pil)


slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)
diffusion_steps_default = slider_params["diffusion_steps"]["default"]
diffusion_steps_min = slider_params["diffusion_steps"]["min"]
diffusion_steps_max = slider_params["diffusion_steps"]["max"]
diffusion_steps_step = slider_params["diffusion_steps"]["step"]
omega_default = slider_params["omega"]["default"]
omega_min = slider_params["omega"]["min"]
omega_max = slider_params["omega"]["max"]
omega_step = slider_params["omega"]["step"]
omega_vent_default = slider_params["omega_vent"]["default"]
omega_vent_min = slider_params["omega_vent"]["min"]
omega_vent_max = slider_params["omega_vent"]["max"]
omega_vent_step = slider_params["omega_vent"]["step"]
omega_sept_default = slider_params["omega_sept"]["default"]
omega_sept_min = slider_params["omega_sept"]["min"]
omega_sept_max = slider_params["omega_sept"]["max"]
omega_sept_step = slider_params["omega_sept"]["step"]
eta_default = slider_params["eta"]["default"]
eta_min = slider_params["eta"]["min"]
eta_max = slider_params["eta"]["max"]
eta_step = slider_params["eta"]["step"]

# Sorted so the example gallery (and the default input image) is stable
# across runs; os.listdir returns entries in arbitrary order.
example_image_paths = [
    os.path.join(ASSETS_DIR, f)
    for f in sorted(os.listdir(ASSETS_DIR))
    if f.lower().endswith(".png")
]
example_images = [load_image(p) for p in example_image_paths]
examples = [[img] for img in example_images]


def _print_environment_info():
    """Print Keras/JAX/PyTorch versions and CUDA/GPU details to the log."""
    print(f"KERAS version: {keras.__version__}")
    try:
        print(f"JAX version: {jax.__version__}")
        print(f"JAX devices: {jax.devices()}")
    except Exception as e:
        print(f"Could not get JAX info: {e}")
    try:
        print(f"PyTorch version: {torch.__version__}")
        print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
        print(f"PyTorch CUDA device count: {torch.cuda.device_count()}")
        print(f"PyTorch devices: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}")
        print(f"PyTorch CUDA version: {torch.version.cuda}")
        print(f"PyTorch cuDNN version: {torch.backends.cudnn.version()}")
    except Exception as e:
        print(f"Could not get PyTorch info: {e}")
    try:
        cuda_version = subprocess.getoutput("nvcc --version")
        print(f"nvcc version:\n{cuda_version}")
        nvidia_smi = subprocess.getoutput("nvidia-smi")
        print(f"nvidia-smi output:\n{nvidia_smi}")
    except Exception as e:
        print(f"Could not get CUDA/nvidia-smi info: {e}")


def load_model_event():
    """Initialise the device and the model once the UI has rendered.

    Returns:
        ``(status_update, run_button_update)``; the Run button is made
        interactive only when the model loaded successfully.
    """
    global DEVICE
    try:
        if DEVICE is None:
            try:
                DEVICE = init_device()
            except Exception as e:
                # Best effort: proceed without an explicitly initialised device.
                print(f"Could not initialize device using `zea.init_device()`: {e}")
        _print_environment_info()
        initialize_model()
        return (
            _status_update("✅ Model loaded! You can now press Run."),
            gr.update(interactive=True),
        )
    except Exception as e:
        traceback.print_exc()
        return (
            _status_update(f"❌ Error loading model: {html.escape(str(e))}"),
            gr.update(interactive=False),
        )


with gr.Blocks() as demo:
    gr.Markdown(description)
    status = gr.Markdown(
        _status_html("⏳ Loading model...", STATUS_STYLE_LOAD),
        visible=True,
    )
    with gr.Row():
        with gr.Column():
            img1 = gr.Image(
                label="Input Image",
                type="pil",
                webcam_options=False,
                value=example_images[0] if example_images else None,
            )
            if examples:
                gr.Examples(examples=examples, inputs=[img1])
        with gr.Column():
            img2 = gr.ImageSlider(label="Dehazed Image", type="pil")
    with gr.Row():
        diffusion_steps_slider = gr.Slider(
            minimum=diffusion_steps_min,
            maximum=diffusion_steps_max,
            step=diffusion_steps_step,
            value=diffusion_steps_default,
            label="Diffusion Steps",
        )
        omega_slider = gr.Slider(
            minimum=omega_min,
            maximum=omega_max,
            step=omega_step,
            value=omega_default,
            label="Omega (background)",
        )
        omega_vent_slider = gr.Slider(
            minimum=omega_vent_min,
            maximum=omega_vent_max,
            step=omega_vent_step,
            value=omega_vent_default,
            label="Omega Ventricle",
        )
        omega_sept_slider = gr.Slider(
            minimum=omega_sept_min,
            maximum=omega_sept_max,
            step=omega_sept_step,
            value=omega_sept_default,
            label="Omega Septum",
        )
        eta_slider = gr.Slider(
            minimum=eta_min,
            maximum=eta_max,
            step=eta_step,
            value=eta_default,
            label="Eta (haze prior)",
        )
    # Disabled until load_model_event reports a successful model load.
    run_btn = gr.Button("Run", interactive=False)
    run_btn.click(
        process_image,
        inputs=[
            img1,
            diffusion_steps_slider,
            omega_slider,
            omega_vent_slider,
            omega_sept_slider,
            eta_slider,
        ],
        outputs=[status, img2],
        queue=True,
    )
    demo.load(
        load_model_event,
        inputs=None,
        outputs=[status, run_btn],
    )

if __name__ == "__main__":
    demo.launch()