# Hugging Face Spaces status: Running on Zero
import os | |
os.environ["KERAS_BACKEND"] = "jax" | |
import gradio as gr | |
import jax | |
import keras | |
import numpy as np | |
import spaces | |
from PIL import Image | |
from zea import init_device | |
from main import Config, init, run | |
from utils import load_image | |
import torch | |
import subprocess | |
# Paths to the model/guidance config, the slider-range config and the folder
# holding the example images shown in the UI.
CONFIG_PATH = "configs/semantic_dps.yaml"
SLIDER_CONFIG_PATH = "configs/slider_params.yaml"
ASSETS_DIR = "assets"

# Device handle returned by zea.init_device(); set lazily by the page-load
# callback so nothing touches the accelerator at import time.
DEVICE = None

# Inline CSS for the status banner. Both strings end in ";" because the
# callers append a "color:...;" declaration directly after them.
# STATUS_STYLE_LOAD has extra top padding for the initial "Loading model..."
# banner; STATUS_STYLE is used for every later status message.
STATUS_STYLE_LOAD = (
    "display:flex;align-items:center;justify-content:center;"
    "padding:40px 10px 18px 10px;border-radius:8px;font-weight:bold;"
    "font-size:1.15em;line-height:1.5;"
)
STATUS_STYLE = (
    "display:flex;align-items:center;justify-content:center;"
    "padding:18px 18px 18px 10px;border-radius:8px;font-weight:bold;"
    "font-size:1.15em;line-height:1.1;"
)

# Markdown shown at the top of the page. Blank lines separate paragraphs so
# the tip is not merged into the preceding sentence; the slider names quoted
# in the tip must match the slider labels used in the UI.
description = """
# Cardiac Ultrasound Dehazing with Semantic Diffusion

Select an example image below to see the dehazing algorithm in action. The algorithm was tuned for the DehazingEcho2025 challenge dataset, so be wary of using it on other datasets.

Tip: Adjust "Omega Ventricle" and "Eta (haze prior)" to control the dehazing effect.
"""

# Model and config are loaded after the UI is rendered (see the page-load
# callback); until then these stay None and model_loaded stays False.
config, diffusion_model = None, None
model_loaded = False
def initialize_model():
    """Load the config and diffusion model once, warm them up, and cache them.

    On the first successful call this reads ``CONFIG_PATH`` with
    ``Config.from_yaml``, builds the model with ``init(config)``, and runs a
    single one-step inference on an all-zero image so one-time setup (weight
    initialisation, JIT compilation, etc.) happens before the first user
    request. Later calls return the cached objects without doing any work.

    The module-level ``config`` / ``diffusion_model`` globals and the
    ``model_loaded`` flag are only updated after the warm-up run returns. A
    failure part-way through therefore leaves the module in its "not loaded"
    state, and the next call retries from scratch instead of reporting a
    half-initialised model as ready.

    Returns:
        tuple: ``(config, diffusion_model)``.

    Raises:
        Whatever ``Config.from_yaml``, ``init`` or ``run`` raise; the caller
        is responsible for reporting it.
    """
    global config, diffusion_model, model_loaded
    if model_loaded and config is not None and diffusion_model is not None:
        return config, diffusion_model

    new_config = Config.from_yaml(CONFIG_PATH)
    new_model = init(new_config)

    # Warm-up: run a dummy inference to initialize weights, JIT, etc.
    # The dummy batch has one all-zero image at the model's input resolution.
    h, w = new_model.input_shape[:2]
    dummy_img = np.zeros((1, h, w), dtype=np.float32)
    params = new_config.params
    config_guidance = params["guidance_kwargs"]
    # Same guidance keys that process_image sends; optional weights fall
    # back to 1.0 when the config does not define them.
    guidance_kwargs = {
        "omega": config_guidance["omega"],
        "omega_vent": config_guidance.get("omega_vent", 1.0),
        "omega_sept": config_guidance.get("omega_sept", 1.0),
        "eta": config_guidance.get("eta", 1.0),
        "smooth_l1_beta": config_guidance["smooth_l1_beta"],
    }
    run(
        hazy_images=dummy_img,
        diffusion_model=new_model,
        seed=jax.random.PRNGKey(new_config.seed),
        guidance_kwargs=guidance_kwargs,
        mask_params=params["mask_params"],
        fixed_mask_params=params["fixed_mask_params"],
        skeleton_params=params["skeleton_params"],
        batch_size=1,
        diffusion_steps=1,
        verbose=False,
    )

    # Publish only now that loading and warm-up have both succeeded.
    config, diffusion_model = new_config, new_model
    model_loaded = True
    return config, diffusion_model
def process_image(input_img, diffusion_steps, omega, omega_vent, omega_sept, eta):
    """Dehaze ``input_img`` and stream status updates to the UI.

    Generator used as the Run-button callback. Every yielded value is a
    ``(status_update, slider_value)`` pair for the status banner and the
    before/after image slider. ``slider_value`` is ``None`` for intermediate
    and error yields, and ``(input_img, dehazed_pil)`` on success.

    Args:
        input_img: PIL image from the input component, or ``None`` when no
            image has been selected.
        diffusion_steps: Number of diffusion steps from the slider; rounded
            to an ``int`` before being passed to ``run``.
        omega: Guidance weight forwarded as ``guidance_kwargs["omega"]``.
        omega_vent: Guidance weight forwarded as ``guidance_kwargs["omega_vent"]``.
        omega_sept: Guidance weight forwarded as ``guidance_kwargs["omega_sept"]``.
        eta: Guidance weight forwarded as ``guidance_kwargs["eta"]``.

    The remaining ``run`` options (mask/skeleton params, ``smooth_l1_beta``,
    output quantile threshold, bottom-region preservation) come from
    ``config.params``.
    """
    import html

    def _status(background, color, message):
        # Status-banner update in the shared STATUS_STYLE.
        return gr.update(
            value=f'<div style="background:{background};{STATUS_STYLE}color:{color};">{message}</div>'
        )

    def _error(message, exc):
        # Exception text is HTML-escaped: messages such as
        # "<class 'numpy.ndarray'>" would otherwise be parsed as markup and
        # silently disappear from the banner.
        return _status("#f8d7da", "#721c24", f"❌ {message}: {html.escape(str(exc))}")

    if not model_loaded:
        yield _status("#ffeeba", "#856404", "⏳ Model is still loading. Please wait..."), None
        return
    if input_img is None:
        yield (
            _status(
                "#ffeeba",
                "#856404",
                "⚠️ No input image was provided. Please select or upload an image before running.",
            ),
            None,
        )
        return

    params = config.params

    def _prepare_image(image):
        """Return ``(batch, resized, orig_shape)`` for a PIL image.

        The image is converted to grayscale ("L") and, if needed, bilinearly
        resized to the model's ``input_shape[:2]``. ``batch`` is float32 with
        a leading batch axis of size 1. ``orig_shape`` is the pre-resize
        ``(height, width)``.
        """
        if image.mode != "L":
            image = image.convert("L")
        orig_shape = image.size[::-1]  # PIL sizes are (width, height)
        h, w = diffusion_model.input_shape[:2]
        resized = image.size != (w, h)
        if resized:
            image = image.resize((w, h), Image.BILINEAR)
        batch = np.array(image).astype(np.float32)[None, ...]
        return batch, resized, orig_shape

    try:
        image, resized, orig_shape = _prepare_image(input_img)
    except Exception as e:
        yield _error("Error preparing input image", e), None
        return

    # UI sliders drive the guidance weights; smooth_l1_beta stays at the
    # configured value.
    guidance_kwargs = {
        "omega": omega,
        "omega_vent": omega_vent,
        "omega_sept": omega_sept,
        "eta": eta,
        "smooth_l1_beta": params["guidance_kwargs"]["smooth_l1_beta"],
    }
    seed = jax.random.PRNGKey(config.seed)

    yield _status("#cce5ff", "#004085", "🌀 Running dehazing algorithm..."), None
    try:
        _, pred_tissue_images, *_ = run(
            hazy_images=image,
            diffusion_model=diffusion_model,
            seed=seed,
            guidance_kwargs=guidance_kwargs,
            mask_params=params["mask_params"],
            fixed_mask_params=params["fixed_mask_params"],
            skeleton_params=params["skeleton_params"],
            batch_size=1,
            # Slider values may arrive as floats; the step count must be an int.
            diffusion_steps=int(round(diffusion_steps)),
            threshold_output_quantile=params.get("threshold_output_quantile", None),
            preserve_bottom_percent=params.get("preserve_bottom_percent", 30.0),
            bottom_transition_width=params.get("bottom_transition_width", 10.0),
            verbose=False,
        )
    except Exception as e:
        yield _error("The algorithm failed to process the image", e), None
        return

    try:
        # Clip to the displayable 8-bit range before converting to PIL, then
        # undo the input resize so the slider compares same-sized images.
        out_img = np.clip(np.squeeze(pred_tissue_images[0]), 0, 255).astype(np.uint8)
        out_pil = Image.fromarray(out_img)
        orig_size = (orig_shape[1], orig_shape[0])
        if resized and out_pil.size != orig_size:
            out_pil = out_pil.resize(orig_size, Image.BILINEAR)
    except Exception as e:
        # Without this the banner would stay on "Running..." forever.
        yield _error("Error converting the output image", e), None
        return

    yield _status("#d4edda", "#155724", "✅ Done!"), (input_img, out_pil)
# Slider ranges for the UI, read from SLIDER_CONFIG_PATH. Every slider entry
# in that file is read for its "default", "min", "max" and "step" keys.
slider_params = Config.from_yaml(SLIDER_CONFIG_PATH)


def _slider_spec(name):
    """Return ``(default, min, max, step)`` for the slider entry *name*."""
    entry = slider_params[name]
    return entry["default"], entry["min"], entry["max"], entry["step"]


(
    diffusion_steps_default,
    diffusion_steps_min,
    diffusion_steps_max,
    diffusion_steps_step,
) = _slider_spec("diffusion_steps")
(omega_default, omega_min, omega_max, omega_step) = _slider_spec("omega")
(
    omega_vent_default,
    omega_vent_min,
    omega_vent_max,
    omega_vent_step,
) = _slider_spec("omega_vent")
(
    omega_sept_default,
    omega_sept_min,
    omega_sept_max,
    omega_sept_step,
) = _slider_spec("omega_sept")
(eta_default, eta_min, eta_max, eta_step) = _slider_spec("eta")
# Example inputs: every *.png (case-insensitive) in ASSETS_DIR. os.listdir
# returns entries in arbitrary, filesystem-dependent order, so the paths are
# sorted to keep the Examples gallery, and the default input image (the first
# example), stable across machines and restarts.
example_image_paths = sorted(
    os.path.join(ASSETS_DIR, filename)
    for filename in os.listdir(ASSETS_DIR)
    if filename.lower().endswith(".png")
)
example_images = [load_image(path) for path in example_image_paths]
# gr.Examples takes one row per example with one value per input component.
examples = [[img] for img in example_images]
with gr.Blocks() as demo:
    gr.Markdown(description)
    # Status banner: starts in the "loading" state and is replaced by the
    # page-load callback below and by process_image.
    status = gr.Markdown(
        f'<div style="background:#ffeeba;{STATUS_STYLE_LOAD}color:#856404;">⏳ Loading model...</div>',
        visible=True,
    )

    with gr.Row():
        with gr.Column():
            img1 = gr.Image(
                label="Input Image",
                type="pil",
                webcam_options=False,
                value=example_images[0] if example_images else None,
            )
            gr.Examples(examples=examples, inputs=[img1])
        with gr.Column():
            img2 = gr.ImageSlider(label="Dehazed Image", type="pil")

    with gr.Row():
        diffusion_steps_slider = gr.Slider(
            minimum=diffusion_steps_min,
            maximum=diffusion_steps_max,
            step=diffusion_steps_step,
            value=diffusion_steps_default,
            label="Diffusion Steps",
        )
        omega_slider = gr.Slider(
            minimum=omega_min,
            maximum=omega_max,
            step=omega_step,
            value=omega_default,
            label="Omega (background)",
        )
        omega_vent_slider = gr.Slider(
            minimum=omega_vent_min,
            maximum=omega_vent_max,
            step=omega_vent_step,
            value=omega_vent_default,
            label="Omega Ventricle",
        )
        omega_sept_slider = gr.Slider(
            minimum=omega_sept_min,
            maximum=omega_sept_max,
            step=omega_sept_step,
            value=omega_sept_default,
            label="Omega Septum",
        )
        eta_slider = gr.Slider(
            minimum=eta_min,
            maximum=eta_max,
            step=eta_step,
            value=eta_default,
            label="Eta (haze prior)",
        )

    # Disabled until load_model_event reports that the model is ready.
    run_btn = gr.Button("Run", interactive=False)
    run_btn.click(
        process_image,
        inputs=[
            img1,
            diffusion_steps_slider,
            omega_slider,
            omega_vent_slider,
            omega_sept_slider,
            eta_slider,
        ],
        outputs=[status, img2],
        queue=True,
    )

    def _log_environment():
        """Print Keras/JAX/PyTorch/CUDA version and device info to stdout.

        Every probe after the Keras version is best-effort: a failure is
        printed and the remaining probes still run.
        """
        print(f"KERAS version: {keras.__version__}")
        try:
            print(f"JAX version: {jax.__version__}")
            print(f"JAX devices: {jax.devices()}")
        except Exception as e:
            print(f"Could not get JAX info: {e}")
        try:
            print(f"PyTorch version: {torch.__version__}")
            print(f"PyTorch CUDA available: {torch.cuda.is_available()}")
            print(f"PyTorch CUDA device count: {torch.cuda.device_count()}")
            print(f"PyTorch devices: {[torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]}")
            print(f"PyTorch CUDA version: {torch.version.cuda}")
            print(f"PyTorch cuDNN version: {torch.backends.cudnn.version()}")
        except Exception as e:
            print(f"Could not get PyTorch info: {e}")
        try:
            # getoutput runs the command through the shell and returns its
            # output; a missing binary shows up as an error message in that
            # output rather than as an exception.
            cuda_version = subprocess.getoutput("nvcc --version")
            print(f"nvcc version:\n{cuda_version}")
            nvidia_smi = subprocess.getoutput("nvidia-smi")
            print(f"nvidia-smi output:\n{nvidia_smi}")
        except Exception as e:
            print(f"Could not get CUDA/nvidia-smi info: {e}")

    def load_model_event():
        """Page-load callback: set up the device, load the model, enable Run.

        Returns:
            tuple: updates for ``(status, run_btn)``. On success, a green
            "Model loaded!" banner and an enabled button. On failure, a red
            error banner and a disabled button.
        """
        import html
        import traceback

        global config, diffusion_model, model_loaded, DEVICE
        try:
            if DEVICE is None:
                try:
                    DEVICE = init_device()
                except Exception as e:
                    # Best effort: keep going and let the model load anyway
                    # (presumably on the backend's default device — confirm).
                    print(f"Could not initialize device using `zea.init_device()`: {e}")
                # The diagnostics shell out to nvcc/nvidia-smi, so run them only
                # while the device is still uninitialised rather than on every
                # visitor's page load.
                _log_environment()
            config, diffusion_model = initialize_model()
        except Exception as e:
            # Keep the full traceback in the server log; the banner only shows
            # the (escaped) message.
            traceback.print_exc()
            return gr.update(
                value=f'<div style="background:#f8d7da;{STATUS_STYLE}color:#721c24;">❌ Error loading model: {html.escape(str(e))}</div>'
            ), gr.update(interactive=False)
        ready_msg = gr.update(
            value=f'<div style="background:#d4edda;{STATUS_STYLE}color:#155724;">✅ Model loaded! You can now press Run.</div>'
        )
        return ready_msg, gr.update(interactive=True)

    demo.load(
        load_model_event,
        inputs=None,
        outputs=[status, run_btn],
    )
def _launch():
    """Start the Gradio server for ``demo``."""
    demo.launch()


if __name__ == "__main__":
    _launch()