"""Hyperparameter sweeping with Optuna for the dehazing pipeline.

NOTE: pip install optuna
"""
import dataclasses
import json
import shutil
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional

import jax
import numpy as np
import optuna
import tyro
import yaml
import zea
from keras import ops
from PIL import Image
from zea import init_device, log

from eval import evaluate
from main import init, run


def load_images_from_dir(input_folder):
    """Load images from a directory, similar to the main.py implementation."""
    paths = list(Path(input_folder).glob("*.png"))

    images = []
    for path in paths:
        image = zea.io_lib.load_image(path)
        images.append(image)

    if len(images) == 0:
        raise ValueError(f"No PNG images found in {input_folder}")

    images = ops.stack(images, axis=0)
    return images, paths


def save_images_to_temp_dir(images, image_paths, prefix=""):
    """Save numpy arrays as PNG images to a temporary directory."""
    temp_dir = tempfile.mkdtemp(prefix=prefix)
    temp_dir_path = Path(temp_dir)

    for img, path in zip(images, image_paths):
        filename = Path(path).name

        # Normalize to uint8: assume [0, 1] floats when the max is <= 1.0,
        # otherwise assume the values are already in the 0-255 range.
        if img.dtype != np.uint8:
            if img.max() <= 1.0:
                img = (img * 255).astype(np.uint8)
            else:
                img = img.astype(np.uint8)

        # Drop a trailing singleton channel so PIL treats the image as grayscale.
        if len(img.shape) == 3 and img.shape[-1] == 1:
            img = img.squeeze(-1)

        img_pil = Image.fromarray(img)
        save_path = temp_dir_path / filename
        img_pil.save(save_path)

    return str(temp_dir_path)


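# Usage sketch: the caller owns the returned directory and is responsible for
# removing it when done, mirroring the finally-block cleanup in
# OptunaObjective.__call__ below:
#
#   tmp = save_images_to_temp_dir(images, paths, prefix="pred_tissue_")
#   try:
#       ...  # evaluate against tmp
#   finally:
#       shutil.rmtree(tmp, ignore_errors=True)

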
@dataclasses.dataclass
class SweeperConfig:
    """Configuration for hyperparameter sweeping with Optuna."""

    # Data paths
    input_image_dir: str
    roi_folder: str
    reference_folder: str
    base_config_path: str = "configs/semantic_dps.yaml"

    # Method / sweep selection
    method: str = "semantic_dps"
    broad_sweep: bool = False

    # Optuna study settings
    study_name: str = "dehaze_optimization"
    storage: Optional[str] = None
    n_trials: int = 100

    # Objective
    objective_metric: str = "final_score"
    direction: str = "maximize"

    # Output
    output_dir: str = "sweep_results"

    # Evaluation
    skip_fid: bool = False

    # Device
    device: str = "auto:1"

    # Pruning
    enable_pruning: bool = True
    pruner_type: str = "median"


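# A minimal sketch (hypothetical paths) of driving the sweep from another
# script instead of the tyro CLI:
#
#   config = SweeperConfig(
#       input_image_dir="data/hazy",
#       roi_folder="data/rois",
#       reference_folder="data/reference",
#       n_trials=20,
#       storage="sqlite:///sweep.db",
#   )
#   run_optimization(config)

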
class OptunaObjective:
    """Optuna objective function for hyperparameter optimization."""

    def __init__(self, config: SweeperConfig):
        self.config = config
        self.base_config = self._load_base_config()
        self.hazy_images, self.image_paths = load_images_from_dir(
            config.input_image_dir
        )

        init_device(config.device)

        # Load the diffusion model once; it is reused across all trials.
        self.diffusion_model = init(self.base_config)

    def _load_base_config(self):
        """Load the base configuration from a YAML file."""
        with open(self.config.base_config_path, "r") as f:
            config_dict = yaml.safe_load(f)
        return zea.Config(**config_dict)

    def _create_trial_params(self, trial: optuna.Trial) -> Dict[str, Any]:
        """Create trial parameters by suggesting hyperparameters."""
        params = {
            "guidance_kwargs": {
                "omega": trial.suggest_float("omega", 0.5, 50.0, log=True),
                "omega_vent": trial.suggest_float("omega_vent", 0.0001, 50.0, log=True),
                "omega_sept": trial.suggest_float("omega_sept", 0.1, 50.0, log=True),
                "eta": trial.suggest_float("eta", 0.001, 1.0, log=True),
                "smooth_l1_beta": trial.suggest_float(
                    "smooth_l1_beta", 0.1, 10.0, log=True
                ),
            },
            "skeleton_params": {
                "sigma_pre": trial.suggest_float("skeleton_sigma_pre", 0.0, 10.0),
                "sigma_post": trial.suggest_float("skeleton_sigma_post", 0.0, 10.0),
                "threshold": trial.suggest_float("skeleton_threshold", 0.0, 1.0),
            },
            "mask_params": {
                "threshold": trial.suggest_float("mask_threshold", 0.0, 1.0),
                "sigma": trial.suggest_float("mask_sigma", 0.0, 10.0),
            },
        }

        # Carry over any base-config parameters that are not being swept.
        if hasattr(self.base_config, "params"):
            base_params = self.base_config.params
            for key, value in base_params.items():
                if key not in params:
                    params[key] = value

        return params

    def __call__(self, trial: optuna.Trial) -> float:
        """Optuna objective function."""
        params = self._create_trial_params(trial)

        # Derive a per-trial seed so trials are reproducible but not identical.
        seed = jax.random.PRNGKey(self.base_config.seed + trial.number)

        try:
            _hazy_images, pred_tissue_images, _pred_haze_images, _masks = run(
                hazy_images=self.hazy_images,
                diffusion_model=self.diffusion_model,
                seed=seed,
                **params,
            )
        except Exception as e:
            log.error(f"Error during model inference: {e}")
            # Treat a failed run as a score of 0.0 instead of crashing the study.
            return 0.0

        if hasattr(pred_tissue_images, "numpy"):
            pred_tissue_images = pred_tissue_images.numpy()

        pred_tissue_temp_dir = None
        try:
            # Write predictions to disk so evaluate() can read them as a folder.
            pred_tissue_temp_dir = save_images_to_temp_dir(
                pred_tissue_images, self.image_paths, prefix="pred_tissue_"
            )

            results = evaluate(
                folder=pred_tissue_temp_dir,
                noisy_folder=self.config.input_image_dir,
                roi_folder=self.config.roi_folder,
                reference_folder=self.config.reference_folder,
            )
            objective_value = results[self.config.objective_metric]

        except Exception as e:
            log.error(f"Error during evaluation: {e}")
            objective_value = 0.0

        finally:
            if pred_tissue_temp_dir and Path(pred_tissue_temp_dir).exists():
                try:
                    shutil.rmtree(pred_tissue_temp_dir)
                except Exception as e:
                    log.warning(
                        f"Failed to clean up temp directory {pred_tissue_temp_dir}: {e}"
                    )

        trial.report(objective_value, step=0)
        if trial.should_prune():
            raise optuna.TrialPruned()

        # Record the full (merged) parameter set as user attributes.
        for key, value in params.items():
            if isinstance(value, dict):
                for subkey, subvalue in value.items():
                    trial.set_user_attr(f"{key}_{subkey}", subvalue)
            else:
                trial.set_user_attr(key, value)

        log.info(
            f"Trial {trial.number}: {self.config.objective_metric} = {objective_value:.4f}"
        )

        return objective_value


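# Debugging sketch (hypothetical values): score a single hand-picked parameter
# set without creating a study, using Optuna's FixedTrial. The keys must match
# the names suggested in _create_trial_params.
#
#   objective = OptunaObjective(config)
#   score = objective(
#       optuna.trial.FixedTrial({
#           "omega": 5.0, "omega_vent": 1.0, "omega_sept": 1.0,
#           "eta": 0.1, "smooth_l1_beta": 1.0,
#           "skeleton_sigma_pre": 2.0, "skeleton_sigma_post": 2.0,
#           "skeleton_threshold": 0.5,
#           "mask_threshold": 0.5, "mask_sigma": 2.0,
#       })
#   )

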
def create_pruner(pruner_type: str) -> optuna.pruners.BasePruner:
    """Create an Optuna pruner based on the specified type."""
    if pruner_type == "median":
        return optuna.pruners.MedianPruner(
            n_startup_trials=5, n_warmup_steps=0, interval_steps=1
        )
    elif pruner_type == "hyperband":
        return optuna.pruners.HyperbandPruner(
            min_resource=1, max_resource=100, reduction_factor=3
        )
    else:
        return optuna.pruners.NopPruner()


def run_optimization(config: SweeperConfig):
    """Run hyperparameter optimization using Optuna."""
    pruner = create_pruner(config.pruner_type) if config.enable_pruning else None

    study = optuna.create_study(
        study_name=config.study_name,
        storage=config.storage,
        direction=config.direction,
        pruner=pruner,
        load_if_exists=True,
    )

    log.info(f"Starting optimization for method: {config.method}")
    log.info(f"Study name: {config.study_name}")
    log.info(f"Number of trials: {config.n_trials}")
    log.info(f"Objective metric: {config.objective_metric} ({config.direction})")

    objective = OptunaObjective(config)
    study.optimize(objective, n_trials=config.n_trials)

    output_dir = Path(config.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Save the best trial along with aggregate study statistics.
    best_trial = study.best_trial
    best_results = {
        "best_value": best_trial.value,
        "best_params": best_trial.params,
        "best_user_attrs": best_trial.user_attrs,
        "study_stats": {
            "n_trials": len(study.trials),
            "n_complete_trials": len(
                [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]
            ),
            "n_pruned_trials": len(
                [t for t in study.trials if t.state == optuna.trial.TrialState.PRUNED]
            ),
            "n_failed_trials": len(
                [t for t in study.trials if t.state == optuna.trial.TrialState.FAIL]
            ),
        },
    }

    with open(
        output_dir / f"best_results_{config.method}_{config.study_name}.json", "w"
    ) as f:
        json.dump(best_results, f, indent=2)

    # Save per-trial records for later analysis.
    trials_data = []
    for trial in study.trials:
        trial_data = {
            "number": trial.number,
            "value": trial.value,
            "params": trial.params,
            "user_attrs": trial.user_attrs,
            "state": trial.state.name,
            "datetime_start": trial.datetime_start.isoformat()
            if trial.datetime_start
            else None,
            "datetime_complete": trial.datetime_complete.isoformat()
            if trial.datetime_complete
            else None,
        }
        trials_data.append(trial_data)

    with open(
        output_dir / f"all_trials_{config.method}_{config.study_name}.json", "w"
    ) as f:
        json.dump(trials_data, f, indent=2)

    log.success("Optimization completed!")
    log.info(f"Best {config.objective_metric}: {best_trial.value:.4f}")
    log.info("Best parameters:")
    for key, value in best_trial.params.items():
        log.info(f"  {key}: {value}")

    stats = best_results["study_stats"]
    log.info("Study statistics:")
    log.info(f"  Total trials: {stats['n_trials']}")
    log.info(f"  Complete trials: {stats['n_complete_trials']}")
    log.info(f"  Pruned trials: {stats['n_pruned_trials']}")
    log.info(f"  Failed trials: {stats['n_failed_trials']}")

    return study


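# Sketch: with a persistent storage URL (e.g. the hypothetical
# "sqlite:///sweep.db" used in the example above), a finished or interrupted
# study can be reloaded later for further trials or analysis:
#
#   study = optuna.load_study(
#       study_name="dehaze_optimization",
#       storage="sqlite:///sweep.db",
#   )
#   print(study.best_trial.params)

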
def main():
    """Main function for running hyperparameter optimization."""
    config = tyro.cli(SweeperConfig)

    # Validate required input paths before doing any heavy setup.
    required_paths = [
        (config.input_image_dir, "Input image directory"),
        (config.roi_folder, "ROI folder"),
        (config.reference_folder, "Reference folder"),
    ]
    for path, description in required_paths:
        if not Path(path).exists():
            raise FileNotFoundError(f"{description} not found: {path}")

    zea.visualize.set_mpl_style()

    study = run_optimization(config)

    # Save standard Optuna plots if the matplotlib visualization backend is available.
    try:
        import matplotlib.pyplot as plt
        from optuna.visualization import matplotlib as optuna_vis

        output_dir = Path(config.output_dir)

        fig = optuna_vis.plot_optimization_history(study).figure
        fig.savefig(
            output_dir / f"optimization_history_{config.method}.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)

        fig = optuna_vis.plot_param_importances(study).figure
        fig.savefig(
            output_dir / f"param_importances_{config.method}.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)

        fig = optuna_vis.plot_parallel_coordinate(study).figure
        fig.savefig(
            output_dir / f"parallel_coordinate_{config.method}.png",
            dpi=300,
            bbox_inches="tight",
        )
        plt.close(fig)

        log.success(f"Optimization plots saved to {output_dir}")

    except ImportError:
        log.warning(
            "Optuna visualization not available. Install with: pip install optuna[visualization]"
        )


if __name__ == "__main__":
    main()