# --------------------------------------------------------------- # This file includes code adapted from: # https://github.com/jbdel/RadEval/blob/null-hypothesis/utils.py # Original author: Justin Xu # --------------------------------------------------------------- import re import nltk import random from typing import List, Dict, Tuple, Callable, Optional from collections import defaultdict nltk.download("punkt_tab", quiet=True) def clean_numbered_list(text): """ Clean a report if it's a numbered list by: 1. Adding proper spacing between numbered items 2. Removing the numbered list markers 3. Adding spaces after periods between sentences """ # First, separate numbered items that are stuck together without spaces # Example: "textx.2. text2" -> "texty. 2. text2" text = re.sub(r'\.(\d+\.)', r'. \1', text) # Handle patterns where there's no period between numbered entries # Example: "1. item1 2. item2" -> "1. item1. 2. item2" text = re.sub(r'(\d+\.\s*[^.]+?)\s+(?=\d+\.)', r'\1. ', text) # Then remove the numbered list markers # But avoid removing decimal numbers in measurements like "3.5 cm" text = re.sub(r'(? "sentence1. sentence2" # But don't split decimal numbers like "3.5 cm" text = re.sub(r'\.([A-Za-z])', r'. \1', text) return nltk.sent_tokenize(text) class PairedTest: """ Paired significance testing for comparing radiology report generation systems. Supports paired approximate randomization (AR). """ def __init__(self, systems: Dict[str, List[str]], metrics: Dict[str, Callable], references: Optional[List[str]], n_samples: int = 10000, n_jobs: int = 1, seed: int = 12345): """ Args: systems: Dictionary mapping system names to their generated reports metrics: Dictionary mapping metric names to metric functions references: List of reference reports n_samples: Number of resampling trials (default: 10000) n_jobs: Number of parallel jobs (default: 1) seed: Random seed for reproducibility """ self.systems = systems self.metrics = metrics self.references = references self.n_samples = n_samples self.n_jobs = n_jobs self.seed = seed random.seed(seed) if not systems: raise ValueError("At least one system is required") system_lengths = [len(outputs) for outputs in systems.values()] if len(set(system_lengths)) > 1: raise ValueError("All systems must have the same number of outputs") if references and len(references) != system_lengths[0]: raise ValueError("References must have same length as system outputs") self.n_instances = system_lengths[0] def __call__(self) -> Tuple[Dict[str, str], Dict[str, Dict[str, float]]]: """ Run the paired significance test. Returns: Tuple of (signatures, scores) where: - signatures: Dict mapping metric names to signature strings - scores: Dict mapping system names to metric scores and p-values """ # Calculate baseline scores for all systems and metrics baseline_scores = self._calculate_baseline_scores() # Get baseline system (first system) baseline_name = list(self.systems.keys())[0] scores = {} signatures = {} # Calculate scores and p-values for each system for system_name in self.systems.keys(): scores[system_name] = {} for metric_name in self.metrics.keys(): score = baseline_scores[system_name][metric_name] scores[system_name][metric_name] = score if system_name != baseline_name: p_value = self._calculate_p_value( baseline_name, system_name, metric_name, baseline_scores ) scores[system_name][f'{metric_name}_pvalue'] = p_value for metric_name in self.metrics.keys(): signatures[metric_name] = f"{metric_name}|{'ar'}:{self.n_samples}|seed:{self.seed}" return signatures, scores def _calculate_baseline_scores(self) -> Dict[str, Dict[str, float]]: """Calculate baseline scores for all systems and metrics.""" scores = defaultdict(dict) for system_name, outputs in self.systems.items(): for metric_name, metric_func in self.metrics.items(): if self.references: score = metric_func(outputs, self.references) else: score = metric_func(outputs) if isinstance(score, dict): if 'score' in score: scores[system_name][metric_name] = score['score'] else: scores[system_name][metric_name] = list(score.values())[0] elif isinstance(score, (tuple, list)): scores[system_name][metric_name] = score[0] else: scores[system_name][metric_name] = score return scores def _calculate_p_value(self, baseline_name: str, system_name: str, metric_name: str, baseline_scores: Dict[str, Dict[str, float]]) -> float: """Calculate p-value using AR test""" baseline_outputs = self.systems[baseline_name] system_outputs = self.systems[system_name] metric_func = self.metrics[metric_name] baseline_score = baseline_scores[baseline_name][metric_name] system_score = baseline_scores[system_name][metric_name] original_delta = abs(system_score - baseline_score) return self._approximate_randomization_test( baseline_outputs, system_outputs, metric_func, original_delta ) def _approximate_randomization_test(self, baseline_outputs: List[str], system_outputs: List[str], metric_func: Callable, original_delta: float) -> float: """ Perform AR test. For each trial, randomly swap outputs between systems and calculate the score difference. P-value is the proportion of trials where the randomized delta >= original delta. """ count_greater = 0 for _ in range(self.n_samples): randomized_baseline = [] randomized_system = [] for i in range(self.n_instances): if random.random() < 0.5: # Don't swap randomized_baseline.append(baseline_outputs[i]) randomized_system.append(system_outputs[i]) else: # Swap randomized_baseline.append(system_outputs[i]) randomized_system.append(baseline_outputs[i]) if self.references: rand_baseline_score = metric_func(randomized_baseline, self.references) rand_system_score = metric_func(randomized_system, self.references) else: rand_baseline_score = metric_func(randomized_baseline) rand_system_score = metric_func(randomized_system) if isinstance(rand_baseline_score, dict): rand_baseline_score = rand_baseline_score.get('score', list(rand_baseline_score.values())[0]) elif isinstance(rand_baseline_score, (tuple, list)): rand_baseline_score = rand_baseline_score[0] if isinstance(rand_system_score, dict): rand_system_score = rand_system_score.get('score', list(rand_system_score.values())[0]) elif isinstance(rand_system_score, (tuple, list)): rand_system_score = rand_system_score[0] rand_delta = abs(rand_system_score - rand_baseline_score) if rand_delta >= original_delta: count_greater += 1 return count_greater / self.n_samples def print_significance_results(scores: Dict[str, Dict[str, float]], signatures: Dict[str, str], baseline_name: str, significance_level: float = 0.05): """ Args: scores: Dictionary of system scores and p-values signatures: Dictionary of metric signatures baseline_name: Name of the baseline system significance_level: Significance threshold (default: 0.05) """ assert baseline_name in scores, f"Baseline system '{baseline_name}' not found in scores." metric_names = [name for name in signatures.keys()] system_names = list(scores.keys()) print("=" * 80) print("PAIRED SIGNIFICANCE TEST RESULTS") print("=" * 80) header = f"{'System':<40}" for metric in metric_names: header += f"{metric:>15}" print(header) print("-" * len(header)) baseline_row = f"Baseline: {baseline_name:<32}" for metric in metric_names: score = scores[baseline_name][metric] baseline_row += f"{score:>12.4f} " print(baseline_row) print("-" * len(header)) for system_name in system_names: if system_name == baseline_name: continue system_row = f"{system_name:<40}" for metric in metric_names: score = scores[system_name].get(metric, 0.0) if isinstance(score, float): system_row += f"{score:>12.4f} " else: system_row += f"{str(score):>12} " print(system_row) # P-value row pvalue_row = " " * 40 for metric in metric_names: pvalue_key = f"{metric}_pvalue" if pvalue_key in scores[system_name]: p_val = scores[system_name][pvalue_key] significance_marker = "*" if p_val < significance_level else "" pvalue_row += f"(p={p_val:.4f}){significance_marker:<2}".rjust(15) else: pvalue_row += " " * 15 print(pvalue_row) print("-" * len(header)) # Footer print(f"- Significance level: {significance_level}") print("- '*' indicates significant difference (p < significance level)") print("- Null hypothesis: systems are essentially the same") print("- Significant results suggest systems are meaningfully different\n") print("METRIC SIGNATURES:") for metric, signature in signatures.items(): print(f"- {metric}: {signature}") def compare_systems(systems: Dict[str, List[str]], metrics: Dict[str, Callable], references: Optional[List[str]] = None, n_samples: int = 10000, significance_level: float = 0.05, seed: int = 12345, print_results: bool = True) -> Tuple[Dict[str, str], Dict[str, Dict[str, float]]]: """ Args: systems: Dictionary mapping system names to their generated reports metrics: Dictionary mapping metric names to metric functions references: Optional list of reference reports n_samples: Number of resampling trials significance_level: Significance threshold for printing results seed: Random seed for reproducibility print_results: Whether to print formatted results Returns: Tuple of (signatures, scores) Example: ```python systems = { 'baseline_model': baseline_reports, 'new_model': new_model_reports, 'other_model': other_model_reports } metrics = { 'bleu': lambda hyp, ref: bleu_score(hyp, ref), 'rouge': lambda hyp, ref: rouge_score(hyp, ref), 'bertscore': lambda hyp, ref: bert_score(hyp, ref) 'custom_metric': lambda hyp, ref: custom_metric(hyp, ref) } signatures, scores = compare_systems( systems, metrics, references, n_samples=10000 ) ``` """ paired_test = PairedTest( systems=systems, metrics=metrics, references=references, n_samples=n_samples, seed=seed ) signatures, scores = paired_test() if print_results: baseline_name = list(systems.keys())[0] print_significance_results(scores, signatures, baseline_name, significance_level) return signatures, scores