initial
Browse files- .gitattributes +1 -0
- RadEval.py +424 -0
- RadEval_banner.png +3 -0
- __init__.py +2 -0
- factual/RaTEScore/__init__.py +2 -0
- factual/RaTEScore/score.py +83 -0
- factual/RaTEScore/scorer.py +146 -0
- factual/RaTEScore/utils.py +143 -0
- factual/RadCliQv1/radcliq.py +213 -0
- factual/RadCliQv1/radcliq_bertscore.py +10 -0
- factual/RadCliQv1/radcliq_radgraph.py +80 -0
- factual/RadCliQv1/semb_score.py +74 -0
- factual/SRRBert/leaves_mapping.json +58 -0
- factual/SRRBert/leaves_with_statuses_mapping.json +165 -0
- factual/SRRBert/srr_bert.py +160 -0
- factual/SRRBert/upper_mapping.json +28 -0
- factual/SRRBert/upper_with_statuses_mapping.json +76 -0
- factual/__init__.py +0 -0
- factual/f1chexbert.py +254 -0
- factual/f1temporal.py +167 -0
- factual/green_score/__init__.py +1 -0
- factual/green_score/green.py +465 -0
- factual/green_score/utils.py +200 -0
- nlg/__init__.py +0 -0
- nlg/bertscore/__init__.py +1 -0
- nlg/bertscore/bertscore.py +50 -0
- nlg/bleu/__init__.py +1 -0
- nlg/bleu/bleu.py +49 -0
- nlg/bleu/bleu_scorer.py +268 -0
- nlg/radevalbertscore.py +53 -0
- nlg/rouge/rouge.py +37 -0
- utils.py +341 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
RadEval_banner.png filter=lfs diff=lfs merge=lfs -text
|
RadEval.py
ADDED
@@ -0,0 +1,424 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
import stanza
|
3 |
+
import warnings
|
4 |
+
import logging
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
from nlg.rouge.rouge import Rouge
|
8 |
+
from nlg.bleu.bleu import Bleu
|
9 |
+
from nlg.bertscore.bertscore import BertScore
|
10 |
+
from radgraph import F1RadGraph
|
11 |
+
from factual.green_score import GREEN
|
12 |
+
from factual.RaTEScore import RaTEScore
|
13 |
+
from factual.f1temporal import F1Temporal
|
14 |
+
from torch import nn
|
15 |
+
import pandas as pd
|
16 |
+
import numpy as np
|
17 |
+
from sklearn.metrics import classification_report
|
18 |
+
from sklearn.exceptions import UndefinedMetricWarning
|
19 |
+
import json
|
20 |
+
from factual.f1chexbert import F1CheXbert
|
21 |
+
import nltk
|
22 |
+
from utils import clean_numbered_list
|
23 |
+
from factual.RadCliQv1.radcliq import CompositeMetric
|
24 |
+
from factual.SRRBert.srr_bert import SRRBert, srr_bert_parse_sentences
|
25 |
+
from nlg.radevalbertscore import RadEvalBERTScorer
|
26 |
+
# Suppress Warning
|
27 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
28 |
+
warnings.filterwarnings('ignore')
|
29 |
+
logging.basicConfig(level=logging.ERROR)
|
30 |
+
|
31 |
+
|
32 |
+
|
33 |
+
|
34 |
+
class RadEval():
|
35 |
+
def __init__(self,
|
36 |
+
do_radgraph=False,
|
37 |
+
do_green=False,
|
38 |
+
do_bleu=False,
|
39 |
+
do_rouge=False,
|
40 |
+
do_bertscore=False,
|
41 |
+
do_srr_bert=False,
|
42 |
+
do_chexbert=False,
|
43 |
+
do_ratescore=False,
|
44 |
+
do_radcliq=False,
|
45 |
+
do_radeval_bertsore=False,
|
46 |
+
do_temporal=False,
|
47 |
+
do_details=False,
|
48 |
+
):
|
49 |
+
super(RadEval, self).__init__()
|
50 |
+
|
51 |
+
self.do_radgraph = do_radgraph
|
52 |
+
self.do_green = do_green
|
53 |
+
self.do_bleu = do_bleu
|
54 |
+
self.do_rouge = do_rouge
|
55 |
+
self.do_bertscore = do_bertscore
|
56 |
+
self.do_srr_bert = do_srr_bert
|
57 |
+
self.do_chexbert = do_chexbert
|
58 |
+
self.do_ratescore = do_ratescore
|
59 |
+
self.do_radcliq = do_radcliq
|
60 |
+
self.do_temporal = do_temporal
|
61 |
+
self.do_radeval_bertsore = do_radeval_bertsore
|
62 |
+
self.do_details = do_details
|
63 |
+
|
64 |
+
# Initialize scorers only once
|
65 |
+
if self.do_radgraph:
|
66 |
+
self.radgraph_scorer = F1RadGraph(reward_level="all", model_type="radgraph-xl")
|
67 |
+
if self.do_bleu:
|
68 |
+
self.bleu_scorer = Bleu()
|
69 |
+
self.bleu_scorer_1 = Bleu(n=1)
|
70 |
+
self.bleu_scorer_2 = Bleu(n=2)
|
71 |
+
self.bleu_scorer_3 = Bleu(n=3)
|
72 |
+
if self.do_bertscore:
|
73 |
+
self.bertscore_scorer = BertScore(model_type='distilbert-base-uncased',
|
74 |
+
num_layers=5)
|
75 |
+
if self.do_green:
|
76 |
+
# Initialize green scorer here if needed
|
77 |
+
self.green_scorer = GREEN("StanfordAIMI/GREEN-radllama2-7b",
|
78 |
+
output_dir=".")
|
79 |
+
|
80 |
+
if self.do_rouge:
|
81 |
+
self.rouge_scorers = {
|
82 |
+
"rouge1": Rouge(rouges=["rouge1"]),
|
83 |
+
"rouge2": Rouge(rouges=["rouge2"]),
|
84 |
+
"rougeL": Rouge(rouges=["rougeL"])
|
85 |
+
}
|
86 |
+
|
87 |
+
if self.do_srr_bert:
|
88 |
+
nltk.download('punkt_tab', quiet=True)
|
89 |
+
self.srr_bert_scorer = SRRBert(model_type="leaves_with_statuses")
|
90 |
+
|
91 |
+
|
92 |
+
if self.do_chexbert:
|
93 |
+
self.chexbert_scorer = F1CheXbert()
|
94 |
+
|
95 |
+
if self.do_ratescore:
|
96 |
+
self.ratescore_scorer = RaTEScore()
|
97 |
+
|
98 |
+
if self.do_radcliq:
|
99 |
+
self.radcliq_scorer = CompositeMetric()
|
100 |
+
|
101 |
+
if self.do_temporal:
|
102 |
+
stanza.download('en', package='radiology', processors={'ner': 'radiology'})
|
103 |
+
self.F1Temporal = F1Temporal
|
104 |
+
|
105 |
+
if self.do_radeval_bertsore:
|
106 |
+
self.radeval_bertsore = RadEvalBERTScorer(
|
107 |
+
model_type="IAMJB/RadEvalModernBERT",
|
108 |
+
num_layers=22,
|
109 |
+
use_fast_tokenizer=True,
|
110 |
+
rescale_with_baseline=False)
|
111 |
+
# Store the metric keys
|
112 |
+
self.metric_keys = []
|
113 |
+
if self.do_radgraph:
|
114 |
+
self.metric_keys.extend(["radgraph_simple", "radgraph_partial", "radgraph_complete"])
|
115 |
+
if self.do_bleu:
|
116 |
+
self.metric_keys.append("bleu")
|
117 |
+
if self.do_green:
|
118 |
+
self.metric_keys.append("green")
|
119 |
+
if self.do_bertscore:
|
120 |
+
self.metric_keys.append("bertscore")
|
121 |
+
if self.do_rouge:
|
122 |
+
self.metric_keys.extend(self.rouge_scorers.keys())
|
123 |
+
if self.do_srr_bert:
|
124 |
+
self.metric_keys.extend(["samples_avg_precision", "samples_avg_recall", "samples_avg_f1-score"])
|
125 |
+
|
126 |
+
if self.do_chexbert:
|
127 |
+
self.metric_keys.extend([
|
128 |
+
"chexbert-5_micro avg_f1-score",
|
129 |
+
"chexbert-all_micro avg_f1-score",
|
130 |
+
"chexbert-5_macro avg_f1-score",
|
131 |
+
"chexbert-all_macro avg_f1-score"
|
132 |
+
])
|
133 |
+
|
134 |
+
if self.do_ratescore:
|
135 |
+
self.metric_keys.append("ratescore")
|
136 |
+
if self.do_radcliq:
|
137 |
+
self.metric_keys.append("radcliqv1")
|
138 |
+
if self.do_temporal:
|
139 |
+
self.metric_keys.append("temporal_f1")
|
140 |
+
if self.do_radeval_bertsore:
|
141 |
+
self.metric_keys.append("radeval_bertsore")
|
142 |
+
|
143 |
+
def __call__(self, refs, hyps):
|
144 |
+
if not (isinstance(hyps, list) and isinstance(refs, list)):
|
145 |
+
raise TypeError("hyps and refs must be of type list")
|
146 |
+
if len(hyps) != len(refs):
|
147 |
+
raise ValueError("hyps and refs lists don't have the same size")
|
148 |
+
if len(refs) == 0:
|
149 |
+
return {}
|
150 |
+
|
151 |
+
scores = self.compute_scores(refs=refs, hyps=hyps)
|
152 |
+
return scores
|
153 |
+
|
154 |
+
def compute_scores(self, refs, hyps):
|
155 |
+
if not (isinstance(hyps, list) and isinstance(refs, list)):
|
156 |
+
raise TypeError("hyps and refs must be of type list")
|
157 |
+
if len(hyps) != len(refs):
|
158 |
+
raise ValueError("hyps and refs lists don't have the same size")
|
159 |
+
|
160 |
+
scores = {}
|
161 |
+
if self.do_radgraph:
|
162 |
+
radgraph_scores = self.radgraph_scorer(refs=refs, hyps=hyps)
|
163 |
+
|
164 |
+
if self.do_details:
|
165 |
+
f1_scores = radgraph_scores[0]
|
166 |
+
individual_scores = radgraph_scores[1]
|
167 |
+
hyps_entities = radgraph_scores[2]
|
168 |
+
refs_entities = radgraph_scores[3]
|
169 |
+
|
170 |
+
scores["radgraph"] = {
|
171 |
+
"radgraph_simple": f1_scores[0],
|
172 |
+
"radgraph_partial": f1_scores[1],
|
173 |
+
"radgraph_complete": f1_scores[2],
|
174 |
+
"reward_list": individual_scores,
|
175 |
+
"hypothesis_annotation_lists": hyps_entities,
|
176 |
+
"reference_annotation_lists": refs_entities
|
177 |
+
}
|
178 |
+
|
179 |
+
else:
|
180 |
+
radgraph_scores = radgraph_scores[0]
|
181 |
+
scores["radgraph_simple"] = radgraph_scores[0]
|
182 |
+
scores["radgraph_partial"] = radgraph_scores[1]
|
183 |
+
scores["radgraph_complete"] = radgraph_scores[2]
|
184 |
+
|
185 |
+
if self.do_bleu:
|
186 |
+
if self.do_details:
|
187 |
+
bleu_1_score = self.bleu_scorer_1(refs, hyps)[0]
|
188 |
+
bleu_2_score = self.bleu_scorer_2(refs, hyps)[0]
|
189 |
+
bleu_3_score = self.bleu_scorer_3(refs, hyps)[0]
|
190 |
+
bleu_4_score = self.bleu_scorer(refs, hyps)[0]
|
191 |
+
|
192 |
+
scores["bleu"] = {
|
193 |
+
"bleu_1": bleu_1_score,
|
194 |
+
"bleu_2": bleu_2_score,
|
195 |
+
"bleu_3": bleu_3_score,
|
196 |
+
"bleu_4": bleu_4_score
|
197 |
+
}
|
198 |
+
else:
|
199 |
+
scores["bleu"] = self.bleu_scorer(refs, hyps)[0]
|
200 |
+
|
201 |
+
if self.do_bertscore:
|
202 |
+
if self.do_details:
|
203 |
+
bertscore_scores, sample_scores = self.bertscore_scorer(refs, hyps)
|
204 |
+
scores["bertscore"] = {
|
205 |
+
"mean_score": bertscore_scores,
|
206 |
+
"sample_scores": sample_scores
|
207 |
+
}
|
208 |
+
else:
|
209 |
+
scores["bertscore"] = self.bertscore_scorer(refs, hyps)[0]
|
210 |
+
|
211 |
+
if self.do_green:
|
212 |
+
# Use the initialized green scorer
|
213 |
+
mean, std, sample_scores, summary, _ = self.green_scorer(refs, hyps)
|
214 |
+
if self.do_details:
|
215 |
+
scores["green"] = {
|
216 |
+
"mean": mean,
|
217 |
+
"std": std,
|
218 |
+
"sample_scores": sample_scores,
|
219 |
+
"summary": summary
|
220 |
+
}
|
221 |
+
else:
|
222 |
+
scores["green"] = mean
|
223 |
+
|
224 |
+
if self.do_rouge:
|
225 |
+
if self.do_details:
|
226 |
+
rouge_scores = {}
|
227 |
+
for key, scorer in self.rouge_scorers.items():
|
228 |
+
mean, sample_scores = scorer(refs, hyps)
|
229 |
+
rouge_scores[key] = {
|
230 |
+
"mean_score": mean,
|
231 |
+
"sample_scores": sample_scores
|
232 |
+
}
|
233 |
+
|
234 |
+
scores["rouge"] = rouge_scores
|
235 |
+
else:
|
236 |
+
for key, scorer in self.rouge_scorers.items():
|
237 |
+
scores[key] = scorer(refs, hyps)[0]
|
238 |
+
|
239 |
+
if self.do_srr_bert:
|
240 |
+
# Clean reports before tokenization
|
241 |
+
parsed_refs = [srr_bert_parse_sentences(ref) for ref in refs]
|
242 |
+
parsed_hyps = [srr_bert_parse_sentences(hyp) for hyp in hyps]
|
243 |
+
|
244 |
+
|
245 |
+
section_level_hyps_pred = []
|
246 |
+
section_level_refs_pred = []
|
247 |
+
for parsed_hyp, parsed_ref in zip(parsed_hyps, parsed_refs):
|
248 |
+
outputs, _ = self.srr_bert_scorer(sentences=parsed_ref + parsed_hyp)
|
249 |
+
|
250 |
+
refs_preds = outputs[:len(parsed_ref)]
|
251 |
+
hyps_preds = outputs[len(parsed_ref):]
|
252 |
+
|
253 |
+
merged_refs_preds = np.any(refs_preds, axis=0).astype(int)
|
254 |
+
merged_hyps_preds = np.any(hyps_preds, axis=0).astype(int)
|
255 |
+
|
256 |
+
section_level_hyps_pred.append(merged_hyps_preds)
|
257 |
+
section_level_refs_pred.append(merged_refs_preds)
|
258 |
+
|
259 |
+
label_names = [label for label, idx in sorted(self.srr_bert_scorer.mapping.items(), key=lambda x: x[1])]
|
260 |
+
classification_dict = classification_report(section_level_refs_pred,
|
261 |
+
section_level_hyps_pred,
|
262 |
+
target_names=label_names,
|
263 |
+
output_dict=True,
|
264 |
+
zero_division=0)
|
265 |
+
|
266 |
+
if self.do_details:
|
267 |
+
label_scores = {}
|
268 |
+
for label in label_names:
|
269 |
+
if label in classification_dict:
|
270 |
+
f1 = classification_dict[label]["f1-score"]
|
271 |
+
support = classification_dict[label]["support"]
|
272 |
+
if f1 > 0 or support > 0:
|
273 |
+
label_scores[label] = {
|
274 |
+
"f1-score": f1,
|
275 |
+
"precision": classification_dict[label]["precision"],
|
276 |
+
"recall": classification_dict[label]["recall"],
|
277 |
+
"support": support
|
278 |
+
}
|
279 |
+
|
280 |
+
scores["srr_bert"] = {
|
281 |
+
"srr_bert_weighted_f1": classification_dict["weighted avg"]["f1-score"],
|
282 |
+
"srr_bert_weighted_precision": classification_dict["weighted avg"]["precision"],
|
283 |
+
"srr_bert_weighted_recall": classification_dict["weighted avg"]["recall"],
|
284 |
+
"label_scores": label_scores
|
285 |
+
}
|
286 |
+
else:
|
287 |
+
scores["srr_bert_weighted_f1"] = classification_dict["weighted avg"]["f1-score"]
|
288 |
+
scores["srr_bert_weighted_precision"] = classification_dict["weighted avg"]["precision"]
|
289 |
+
scores["srr_bert_weighted_recall"] = classification_dict["weighted avg"]["recall"]
|
290 |
+
|
291 |
+
|
292 |
+
|
293 |
+
if self.do_chexbert:
|
294 |
+
accuracy, accuracy_per_sample, chexbert_all, chexbert_5 = self.chexbert_scorer(hyps, refs)
|
295 |
+
if self.do_details:
|
296 |
+
chexbert_5_labels = {
|
297 |
+
k: v["f1-score"]
|
298 |
+
for k, v in list(chexbert_5.items())[:-4]
|
299 |
+
}
|
300 |
+
|
301 |
+
chexbert_all_labels = {
|
302 |
+
k: v["f1-score"]
|
303 |
+
for k, v in list(chexbert_all.items())[:-4]
|
304 |
+
}
|
305 |
+
|
306 |
+
scores["chexbert"] = {
|
307 |
+
"chexbert-5_micro avg_f1-score": chexbert_5["micro avg"]["f1-score"],
|
308 |
+
"chexbert-all_micro avg_f1-score": chexbert_all["micro avg"]["f1-score"],
|
309 |
+
"chexbert-5_macro avg_f1-score": chexbert_5["macro avg"]["f1-score"],
|
310 |
+
"chexbert-all_macro avg_f1-score": chexbert_all["macro avg"]["f1-score"],
|
311 |
+
"chexbert-5_weighted_f1": chexbert_5["weighted avg"]["f1-score"],
|
312 |
+
"chexbert-all_weighted_f1": chexbert_all["weighted avg"]["f1-score"],
|
313 |
+
"label_scores_f1-score": {
|
314 |
+
"chexbert-5": chexbert_5_labels,
|
315 |
+
"chexbert_all": chexbert_all_labels
|
316 |
+
}
|
317 |
+
}
|
318 |
+
else:
|
319 |
+
scores["chexbert-5_micro avg_f1-score"] = chexbert_5["micro avg"]["f1-score"]
|
320 |
+
scores["chexbert-all_micro avg_f1-score"] = chexbert_all["micro avg"]["f1-score"]
|
321 |
+
scores["chexbert-5_macro avg_f1-score"] = chexbert_5["macro avg"]["f1-score"]
|
322 |
+
scores["chexbert-all_macro avg_f1-score"] = chexbert_all["macro avg"]["f1-score"]
|
323 |
+
scores["chexbert-5_weighted_f1"] = chexbert_5["weighted avg"]["f1-score"]
|
324 |
+
scores["chexbert-all_weighted_f1"] = chexbert_all["weighted avg"]["f1-score"]
|
325 |
+
|
326 |
+
if self.do_ratescore:
|
327 |
+
rate_score, pred_pairs_raw ,gt_pairs_raw = self.ratescore_scorer.compute_score(candidate_list=hyps, reference_list=refs)
|
328 |
+
f1_ratescore = float(np.mean(rate_score))
|
329 |
+
if self.do_details:
|
330 |
+
pred_pairs = [
|
331 |
+
{ent: label for ent, label in sample}
|
332 |
+
for sample in pred_pairs_raw
|
333 |
+
]
|
334 |
+
gt_pairs = [
|
335 |
+
{ent: label for ent, label in sample}
|
336 |
+
for sample in gt_pairs_raw
|
337 |
+
]
|
338 |
+
scores["ratescore"] = {
|
339 |
+
"f1-score": f1_ratescore,
|
340 |
+
"hyps_pairs": pred_pairs,
|
341 |
+
"refs_pairs": gt_pairs
|
342 |
+
}
|
343 |
+
else:
|
344 |
+
scores["ratescore"] = f1_ratescore
|
345 |
+
|
346 |
+
if self.do_radcliq:
|
347 |
+
mean_scores, detail_scores = self.radcliq_scorer.predict(refs, hyps)
|
348 |
+
if self.do_details:
|
349 |
+
scores["radcliq-v1"] = {
|
350 |
+
"mean_score": mean_scores,
|
351 |
+
"sample_scores": detail_scores.tolist()
|
352 |
+
}
|
353 |
+
else:
|
354 |
+
scores["radcliq-v1"] = mean_scores
|
355 |
+
|
356 |
+
if self.do_temporal:
|
357 |
+
temporal_scores = self.F1Temporal(predictions=hyps, references=refs)
|
358 |
+
if self.do_details:
|
359 |
+
hyp_entities = [
|
360 |
+
sorted(list(group)) if group else []
|
361 |
+
for group in temporal_scores.get("prediction_entities", [])
|
362 |
+
]
|
363 |
+
ref_entities = [
|
364 |
+
sorted(list(group)) if group else []
|
365 |
+
for group in temporal_scores.get("reference_entities", [])
|
366 |
+
]
|
367 |
+
scores["temporal_f1"] = {
|
368 |
+
"f1-score": temporal_scores["f1"],
|
369 |
+
"hyps_entities": hyp_entities,
|
370 |
+
"refs_entities": ref_entities
|
371 |
+
}
|
372 |
+
else:
|
373 |
+
scores["temporal_f1"] = temporal_scores["f1"]
|
374 |
+
|
375 |
+
if self.do_radeval_bertsore:
|
376 |
+
radeval_bertsores = self.radeval_bertsore.score(refs=refs, hyps=hyps)
|
377 |
+
if self.do_details:
|
378 |
+
scores["radeval_bertsore"] = {
|
379 |
+
"f1-score": radeval_bertsores[0],
|
380 |
+
"sample_scores": radeval_bertsores[1].tolist()
|
381 |
+
}
|
382 |
+
else:
|
383 |
+
scores["radeval_bertsore"] = radeval_bertsores[0]
|
384 |
+
|
385 |
+
return scores
|
386 |
+
|
387 |
+
|
388 |
+
def main():
|
389 |
+
refs = [
|
390 |
+
"No acute cardiopulmonary process.",
|
391 |
+
"No radiographic findings to suggest pneumonia.",
|
392 |
+
"1.Status post median sternotomy for CABG with stable cardiac enlargement and calcification of the aorta consistent with atherosclerosis.Relatively lower lung volumes with no focal airspace consolidation appreciated.Crowding of the pulmonary vasculature with possible minimal perihilar edema, but no overt pulmonary edema.No pleural effusions or pneumothoraces.",
|
393 |
+
"1. Left PICC tip appears to terminate in the distal left brachiocephalic vein.2. Mild pulmonary vascular congestion.3. Interval improvement in aeration of the lung bases with residual streaky opacity likely reflective of atelectasis.Interval resolution of the left pleural effusion.",
|
394 |
+
"No definite acute cardiopulmonary process.Enlarged cardiac silhouette could be accentuated by patient's positioning.",
|
395 |
+
"Increased mild pulmonary edema and left basal atelectasis.",
|
396 |
+
]
|
397 |
+
|
398 |
+
hyps = [
|
399 |
+
"No acute cardiopulmonary process.",
|
400 |
+
"No radiographic findings to suggest pneumonia.",
|
401 |
+
"Status post median sternotomy for CABG with stable cardiac enlargement and calcification of the aorta consistent with atherosclerosis.",
|
402 |
+
"Relatively lower lung volumes with no focal airspace consolidation appreciated.",
|
403 |
+
"Crowding of the pulmonary vasculature with possible minimal perihilar edema, but no overt pulmonary edema.",
|
404 |
+
"No pleural effusions or pneumothoraces.",
|
405 |
+
]
|
406 |
+
|
407 |
+
evaluator = RadEval(do_radgraph=True,
|
408 |
+
do_green=False,
|
409 |
+
do_bleu=True,
|
410 |
+
do_rouge=True,
|
411 |
+
do_bertscore=True,
|
412 |
+
do_srr_bert=True,
|
413 |
+
do_chexbert=True,
|
414 |
+
do_temporal=True,
|
415 |
+
do_ratescore=True,
|
416 |
+
do_radcliq=True,
|
417 |
+
do_radeval_bertsore=True)
|
418 |
+
|
419 |
+
results = evaluator(refs=refs, hyps=hyps)
|
420 |
+
print(json.dumps(results, indent=4))
|
421 |
+
|
422 |
+
|
423 |
+
if __name__ == '__main__':
|
424 |
+
main()
|
RadEval_banner.png
ADDED
![]() |
Git LFS Details
|
__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .RadEval import RadEval
|
2 |
+
from .utils import compare_systems
|
factual/RaTEScore/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .score import *
|
2 |
+
from .scorer import RaTEScore
|
factual/RaTEScore/score.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import medspacy
|
3 |
+
nlp = medspacy.load(medspacy_enable=["medspacy_pyrush", "medspacy_context"])
|
4 |
+
|
5 |
+
from .utils import sentence_split, post_process
|
6 |
+
|
7 |
+
def run_ner(texts, idx2label, tokenizer, model, device, batch_size):
|
8 |
+
|
9 |
+
clean_text_list, is_start_list = sentence_split(texts)
|
10 |
+
|
11 |
+
predicted_labels = []
|
12 |
+
|
13 |
+
for i in range(0, len(clean_text_list), batch_size):
|
14 |
+
batch_text = clean_text_list[i:i+batch_size]
|
15 |
+
|
16 |
+
inputs = tokenizer(batch_text,
|
17 |
+
max_length=512,
|
18 |
+
padding=True,
|
19 |
+
truncation=True,
|
20 |
+
return_tensors="pt").to(device)
|
21 |
+
|
22 |
+
with torch.no_grad():
|
23 |
+
outputs = model(**inputs)
|
24 |
+
|
25 |
+
predicted_labels.extend(torch.argmax(outputs.logits, dim=2).tolist())
|
26 |
+
|
27 |
+
inputs = tokenizer(clean_text_list,
|
28 |
+
max_length=512,
|
29 |
+
padding=True,
|
30 |
+
truncation=True,
|
31 |
+
return_tensors="pt")
|
32 |
+
|
33 |
+
save_pairs = []
|
34 |
+
|
35 |
+
pad_token_id = tokenizer.pad_token_id
|
36 |
+
|
37 |
+
for i, is_start in enumerate(is_start_list):
|
38 |
+
|
39 |
+
predicted_entities = [idx2label[label] for label in predicted_labels[i]]
|
40 |
+
|
41 |
+
non_pad_mask = inputs["input_ids"][i] != pad_token_id
|
42 |
+
non_pad_length = non_pad_mask.sum().item()
|
43 |
+
non_pad_input_ids = inputs["input_ids"][i][:non_pad_length]
|
44 |
+
|
45 |
+
tokenized_text = tokenizer.convert_ids_to_tokens(non_pad_input_ids)
|
46 |
+
|
47 |
+
if is_start:
|
48 |
+
save_pair = post_process(tokenized_text, predicted_entities, tokenizer)
|
49 |
+
else:
|
50 |
+
save_pair = post_process(tokenized_text, predicted_entities, tokenizer)
|
51 |
+
save_pairs[-1].extend(save_pair)
|
52 |
+
continue
|
53 |
+
|
54 |
+
save_pairs.append(save_pair)
|
55 |
+
|
56 |
+
return save_pairs
|
57 |
+
|
58 |
+
|
59 |
+
def process_embedding(pair, eval_tokenizer, eval_model, device):
|
60 |
+
entities = [pair[0] for pair in pair]
|
61 |
+
types = [pair[1] for pair in pair]
|
62 |
+
|
63 |
+
if len(entities) == 0:
|
64 |
+
embeds_word = torch.tensor([])
|
65 |
+
else:
|
66 |
+
embeds_word = torch.tensor([]).to(device)
|
67 |
+
|
68 |
+
with torch.no_grad():
|
69 |
+
# tokenize the queries
|
70 |
+
encoded = eval_tokenizer(
|
71 |
+
entities,
|
72 |
+
truncation=True,
|
73 |
+
padding=True,
|
74 |
+
return_tensors='pt',
|
75 |
+
max_length=30,
|
76 |
+
).to(device)
|
77 |
+
|
78 |
+
# encode the queries (use the [CLS] last hidden states as the representations)
|
79 |
+
embeds_word = torch.cat((embeds_word.to('cpu'),
|
80 |
+
eval_model(**encoded).last_hidden_state[:, 0, :].to('cpu')), dim=0)
|
81 |
+
|
82 |
+
return embeds_word, types
|
83 |
+
|
factual/RaTEScore/scorer.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForTokenClassification
|
5 |
+
import pandas as pd
|
6 |
+
import os
|
7 |
+
|
8 |
+
from .score import run_ner, process_embedding
|
9 |
+
from .utils import compute
|
10 |
+
|
11 |
+
|
12 |
+
DEFAULT_MATRIX_LONG = {"abnormality_abnormality": 0.4276119164393705, "abnormality_anatomy": 0.6240929990607657, "abnormality_disease": 0.0034478181112993847, "abnormality_non-abnormality": 0.5431049700217344, "abnormality_non-disease": 0.27005425386213877, "anatomy_abnormality": 0.7487824274337533, "anatomy_anatomy": 0.2856134859160784, "anatomy_disease": 0.4592143222158069, "anatomy_non-abnormality": 0.02097055139911715, "anatomy_non-disease": 0.00013736314126696204, "disease_abnormality": 0.8396510075734789, "disease_anatomy": 0.9950209388542061, "disease_disease": 0.8460555030578727, "disease_non-abnormality": 0.9820689020512646, "disease_non-disease": 0.3789136708096537, "non-abnormality_abnormality": 0.16546764653692908, "non-abnormality_anatomy": 0.018670610691852826, "non-abnormality_disease": 0.719397354576018, "non-abnormality_non-abnormality": 0.0009357166071730684, "non-abnormality_non-disease": 0.0927333564267591, "non-disease_abnormality": 0.7759420231214385, "non-disease_anatomy": 0.1839139293714062, "non-disease_disease": 0.10073046076318157, "non-disease_non-abnormality": 0.03860183811876373, "non-disease_non-disease": 0.34065681486566446, "neg_weight":0.8716553966489615}
|
13 |
+
DEFAULT_MATRIX_SHORT = {"abnormality_abnormality": 0.4070293318365468, "abnormality_anatomy": 0.6952639610605605, "abnormality_disease": 0.28342529466226446, "abnormality_non-abnormality": 0.9479148658006686, "abnormality_non-disease": 0.23875064111146294, "anatomy_abnormality": 0.5829759950441763, "anatomy_anatomy": 0.7709590751917746, "anatomy_disease": 0.0006059634829551632, "anatomy_non-abnormality": 0.794672584951181, "anatomy_non-disease": 0.27982942400798977, "disease_abnormality": 0.8840397619834857, "disease_anatomy": 0.9637659445696822, "disease_disease": 0.19018958438059513, "disease_non-abnormality": 0.6962283914800402, "disease_non-disease": 0.943727057946997, "non-abnormality_abnormality": 0.1712744286898638, "non-abnormality_anatomy": 0.4485149671497294, "non-abnormality_disease": 0.00045065329822896076, "non-abnormality_non-abnormality": 0.0007887930317199857, "non-abnormality_non-disease": 0.8555432840895761, "non-disease_abnormality": 0.9555801066212176, "non-disease_anatomy": 0.13122106162635216, "non-disease_disease": 0.6072996585919443, "non-disease_non-abnormality": 0.05650711141169969, "non-disease_non-disease": 0.3214769399791204, "neg_weight":0.3611577852354489}
|
14 |
+
|
15 |
+
|
16 |
+
class RaTEScore:
|
17 |
+
def __init__(self,
|
18 |
+
bert_model="Angelakeke/RaTE-NER-Deberta",
|
19 |
+
eval_model='FremyCompany/BioLORD-2023-C',
|
20 |
+
batch_size=1,
|
21 |
+
use_gpu=True,
|
22 |
+
visualization_path=None,
|
23 |
+
affinity_matrix="long",
|
24 |
+
):
|
25 |
+
""" RaTEScore is a novel, entity-aware metric to assess the quality of medical reports generated by AI models.
|
26 |
+
It emphasizes crucial medical entities such as diagnostic outcomes and anatomical details, and is robust
|
27 |
+
against complex medical synonyms and sensitive to negation expressions. The evaluations demonstrate that
|
28 |
+
RaTEScore aligns more closely with human preference than existing metrics.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
bert_model (str, optional): Medical entity recognition modul module. Defaults to "Angelakeke/RaTE-NER-Deberta".
|
32 |
+
eval_model (str, optional): Synonym disambuation encoding module. Defaults to 'FremyCompany/BioLORD-2023-C'.
|
33 |
+
batch_size (int, optional): Batch size to choose. Defaults to 1.
|
34 |
+
use_gpu (bool, optional): If to use gpu. Defaults to True.
|
35 |
+
visualization_path (str, optional): Output the visualized files, default to save as a json file. Defaults to None.
|
36 |
+
affinity_matrix (str, optional):pre-searched type weight and can be changed due to the human rating bias.
|
37 |
+
Defaults to 'long'.
|
38 |
+
|
39 |
+
"""
|
40 |
+
|
41 |
+
# if use_gpu
|
42 |
+
if use_gpu:
|
43 |
+
self.device = torch.device('cuda')
|
44 |
+
else:
|
45 |
+
self.device = torch.device('cpu')
|
46 |
+
|
47 |
+
# load the Medical entity recognition module
|
48 |
+
self.tokenizer = AutoTokenizer.from_pretrained(bert_model)
|
49 |
+
self.model = AutoModelForTokenClassification.from_pretrained(bert_model).eval().to(self.device)
|
50 |
+
|
51 |
+
# load the Synonym disambuation module
|
52 |
+
self.eval_tokenizer = AutoTokenizer.from_pretrained(eval_model)
|
53 |
+
self.eval_model = AutoModel.from_pretrained(eval_model).eval().to(self.device)
|
54 |
+
|
55 |
+
# load the weight matrix
|
56 |
+
if isinstance(affinity_matrix, str):
|
57 |
+
# Choose the appropriate matrix based on the argument
|
58 |
+
if affinity_matrix.lower() == "long":
|
59 |
+
self.matrix_path = DEFAULT_MATRIX_LONG
|
60 |
+
elif affinity_matrix.lower() == "short":
|
61 |
+
self.matrix_path = DEFAULT_MATRIX_SHORT
|
62 |
+
else:
|
63 |
+
# Assume it's a file path
|
64 |
+
try:
|
65 |
+
with open(affinity_matrix, 'r') as f:
|
66 |
+
self.matrix_path = json.load(f)
|
67 |
+
except Exception as e:
|
68 |
+
raise ValueError(f"Failed to load affinity matrix from {affinity_matrix}: {e}")
|
69 |
+
else:
|
70 |
+
raise ValueError("affinity_matrix must be a string")
|
71 |
+
|
72 |
+
self.affinity_matrix = {(k.split('_')[0].upper(), k.split('_')[1].upper()):v for k,v in self.matrix_path.items()}
|
73 |
+
|
74 |
+
# load the label file
|
75 |
+
self.config = AutoConfig.from_pretrained(bert_model)
|
76 |
+
self.label2idx = self.config.label2id
|
77 |
+
self.idx2label = self.config.id2label
|
78 |
+
|
79 |
+
# save the input
|
80 |
+
self.batch_size = batch_size
|
81 |
+
|
82 |
+
if visualization_path:
|
83 |
+
self.visualization_path = visualization_path
|
84 |
+
if not os.path.exists(os.path.dirname(visualization_path)):
|
85 |
+
os.makedirs(os.path.dirname(visualization_path))
|
86 |
+
else:
|
87 |
+
self.visualization_path = None
|
88 |
+
|
89 |
+
|
90 |
+
def compute_score(self, candidate_list, reference_list):
|
91 |
+
'''Compute the RaTEScore for the candidate and reference reports.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
candidate_list (list): list of candidate reports
|
95 |
+
reference_list (list): list of reference reports
|
96 |
+
'''
|
97 |
+
|
98 |
+
# check if candidate and reference are list
|
99 |
+
if not isinstance(candidate_list, list):
|
100 |
+
raise ValueError("candidate must be a list")
|
101 |
+
if not isinstance(reference_list, list):
|
102 |
+
raise ValueError("reference must be a list")
|
103 |
+
|
104 |
+
assert len(candidate_list) == len(reference_list), "candidate and reference must have the same length"
|
105 |
+
|
106 |
+
# check if candidate and reference are list of strings
|
107 |
+
if not all(isinstance(x, str) for x in candidate_list):
|
108 |
+
raise ValueError("candidate must be a list of strings")
|
109 |
+
|
110 |
+
gt_pairs = run_ner(reference_list, self.idx2label, self.tokenizer, self.model, self.device, self.batch_size)
|
111 |
+
pred_pairs = run_ner(candidate_list, self.idx2label, self.tokenizer, self.model, self.device, self.batch_size)
|
112 |
+
|
113 |
+
rate_score = []
|
114 |
+
|
115 |
+
for gt_pair, pred_pair in zip(gt_pairs, pred_pairs):
|
116 |
+
|
117 |
+
# process the embedding for gt
|
118 |
+
gt_embeds_word, gt_types = process_embedding(gt_pair, self.eval_tokenizer, self.eval_model, self.device)
|
119 |
+
|
120 |
+
# process the embedding for pred
|
121 |
+
pred_embeds_word, pred_types = process_embedding(pred_pair, self.eval_tokenizer, self.eval_model, self.device)
|
122 |
+
|
123 |
+
# compute the score, if the length of gt or pred is 0, the score is 0.5
|
124 |
+
if len(gt_embeds_word) == 0 or len(pred_embeds_word) == 0:
|
125 |
+
rate_score.append(0.5)
|
126 |
+
continue
|
127 |
+
|
128 |
+
precision_score = compute(gt_embeds_word, pred_embeds_word, gt_types, pred_types, self.affinity_matrix)
|
129 |
+
recall_score = compute(pred_embeds_word, gt_embeds_word, pred_types, gt_types, self.affinity_matrix)
|
130 |
+
|
131 |
+
if precision_score + recall_score == 0:
|
132 |
+
rate_score.append(0)
|
133 |
+
else:
|
134 |
+
rate_score.append(2*precision_score*recall_score/(precision_score+recall_score))
|
135 |
+
|
136 |
+
if self.visualization_path:
|
137 |
+
save_file = pd.DataFrame({
|
138 |
+
'candidate': candidate_list,
|
139 |
+
'reference': reference_list,
|
140 |
+
'candidate_entities': pred_pairs,
|
141 |
+
'reference_entities': gt_pairs,
|
142 |
+
'rate_score': rate_score
|
143 |
+
})
|
144 |
+
save_file.to_json(os.path.join(self.visualization_path, 'rate_score.json'), lines=True, orient='records')
|
145 |
+
|
146 |
+
return rate_score, pred_pairs ,gt_pairs
|
factual/RaTEScore/utils.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import medspacy
|
4 |
+
nlp = medspacy.load(medspacy_enable=["medspacy_pyrush", "medspacy_conte"])
|
5 |
+
|
6 |
+
def sentence_split(text_list):
|
7 |
+
"""
|
8 |
+
split sentences by medspacy
|
9 |
+
"""
|
10 |
+
clean_text_list = []
|
11 |
+
is_start_list = []
|
12 |
+
|
13 |
+
for text in text_list:
|
14 |
+
|
15 |
+
doc = nlp(text)
|
16 |
+
|
17 |
+
is_start = 1
|
18 |
+
|
19 |
+
for sent in doc.sents:
|
20 |
+
sent = str(sent).strip()
|
21 |
+
# # check if the sentence has no words
|
22 |
+
if len(sent.split()) == 0:
|
23 |
+
continue
|
24 |
+
if len(sent) < 3:
|
25 |
+
continue
|
26 |
+
is_start_list.append(is_start)
|
27 |
+
clean_text_list.append(sent)
|
28 |
+
is_start = 0
|
29 |
+
|
30 |
+
return clean_text_list, is_start_list
|
31 |
+
|
32 |
+
def post_process(tokenized_text, predicted_entities, tokenizer):
|
33 |
+
entity_spans = []
|
34 |
+
start = end = None
|
35 |
+
entity_type = None
|
36 |
+
|
37 |
+
for i, (token, label) in enumerate(zip(tokenized_text, predicted_entities[:len(tokenized_text)])):
|
38 |
+
if token in ["[CLS]", "[SEP]"]:
|
39 |
+
continue
|
40 |
+
if label != "O" and i < len(predicted_entities) - 1:
|
41 |
+
if label.startswith("B-") and predicted_entities[i+1].startswith("I-"):
|
42 |
+
start = i
|
43 |
+
entity_type = label[2:]
|
44 |
+
elif label.startswith("B-") and predicted_entities[i+1].startswith("B-"):
|
45 |
+
start = i
|
46 |
+
end = i
|
47 |
+
entity_spans.append((start, end, label[2:]))
|
48 |
+
start = i
|
49 |
+
entity_type = label[2:]
|
50 |
+
elif label.startswith("B-") and predicted_entities[i+1].startswith("O"):
|
51 |
+
start = i
|
52 |
+
end = i
|
53 |
+
entity_spans.append((start, end, label[2:]))
|
54 |
+
start = end = None
|
55 |
+
entity_type = None
|
56 |
+
elif label.startswith("I-") and predicted_entities[i+1].startswith("B-"):
|
57 |
+
end = i
|
58 |
+
if start is not None:
|
59 |
+
entity_spans.append((start, end, entity_type))
|
60 |
+
start = i
|
61 |
+
entity_type = label[2:]
|
62 |
+
elif label.startswith("I-") and predicted_entities[i+1].startswith("O"):
|
63 |
+
end = i
|
64 |
+
if start is not None:
|
65 |
+
entity_spans.append((start, end, entity_type))
|
66 |
+
start = end = None
|
67 |
+
entity_type = None
|
68 |
+
|
69 |
+
# 处理最后一个实体
|
70 |
+
if start is not None and end is None:
|
71 |
+
end = len(tokenized_text) - 2
|
72 |
+
entity_spans.append((start, end, entity_type))
|
73 |
+
|
74 |
+
# 输出结果
|
75 |
+
save_pair = []
|
76 |
+
for start, end, entity_type in entity_spans:
|
77 |
+
entity_str = tokenizer.convert_tokens_to_string(tokenized_text[start:end+1])
|
78 |
+
# print(f"实体: {entity_str}, 类型: {entity_type}")
|
79 |
+
save_pair.append((entity_str, entity_type))
|
80 |
+
|
81 |
+
return save_pair
|
82 |
+
|
83 |
+
|
84 |
+
def topk_similarity(embeddings1, embeddings2, k=1):
|
85 |
+
"""
|
86 |
+
Compute the top-k similarity between two sets of embeddings using PyTorch.
|
87 |
+
"""
|
88 |
+
|
89 |
+
### Normalize the embeddings to use cosine similarity
|
90 |
+
embeddings1 = F.normalize(embeddings1, p=2, dim=1)
|
91 |
+
embeddings2 = F.normalize(embeddings2, p=2, dim=1)
|
92 |
+
|
93 |
+
topk_values = []
|
94 |
+
topk_indices = []
|
95 |
+
|
96 |
+
### Iterate over each embedding in the first set
|
97 |
+
for emb1 in embeddings1:
|
98 |
+
|
99 |
+
### Calculate cosine similarity between this embedding and all embeddings in the second set
|
100 |
+
similarities = torch.matmul(embeddings2, emb1)
|
101 |
+
|
102 |
+
### Find the top-k highest similarity values
|
103 |
+
values, indices = torch.topk(similarities, k, largest=True)
|
104 |
+
|
105 |
+
topk_values.append(values[0])
|
106 |
+
topk_indices.append(indices[0])
|
107 |
+
|
108 |
+
return topk_indices, topk_values
|
109 |
+
|
110 |
+
def compute(gt_embeds_word, pred_embeds_word, gt_types, pred_types, weight_matrix):
|
111 |
+
neg_class = [('NON-DISEASE', 'DISEASE'),
|
112 |
+
('NON-ABNORMALITY', 'ABNORMALITY'),
|
113 |
+
('DISEASE', 'NON-DISEASE'),
|
114 |
+
('ABNORMALITY', 'NON-ABNORMALITY'),
|
115 |
+
('NON-DISEASE', 'ABNORMALITY'),
|
116 |
+
('NON-ABNORMALITY', 'DISEASE'),
|
117 |
+
('DISEASE', 'NON-ABNORMALITY'),
|
118 |
+
('ABNORMALITY', 'NON-DISEASE'),]
|
119 |
+
neg_weight = weight_matrix[("NEG", "WEIGHT")]
|
120 |
+
topk_indices, topk_values = topk_similarity(gt_embeds_word, pred_embeds_word, k=1)
|
121 |
+
|
122 |
+
|
123 |
+
for i in range(len(topk_indices)):
|
124 |
+
topk_indices[i] = topk_indices[i].cpu().numpy().tolist()
|
125 |
+
topk_values[i] = topk_values[i].cpu().numpy().tolist()
|
126 |
+
|
127 |
+
# map the indices to type
|
128 |
+
topk_map = [pred_types[i] for i in topk_indices]
|
129 |
+
|
130 |
+
weight_score = [weight_matrix[(gt_type, pred_type)] for gt_type, pred_type in zip(gt_types, topk_map)]
|
131 |
+
type_score = [neg_weight if (gt_type, pred_type) in neg_class else 1 for gt_type, pred_type in zip(gt_types, topk_map)]
|
132 |
+
|
133 |
+
weighted_avg_score = 0
|
134 |
+
weighted_sum = 0
|
135 |
+
for score, weight, type in zip(topk_values, weight_score, type_score):
|
136 |
+
weighted_avg_score += score*weight*type
|
137 |
+
weighted_sum += weight
|
138 |
+
if weighted_sum != 0:
|
139 |
+
RaTE = weighted_avg_score/weighted_sum
|
140 |
+
else:
|
141 |
+
RaTE = 0
|
142 |
+
|
143 |
+
return RaTE
|
factual/RadCliQv1/radcliq.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from nlg.bertscore.bertscore import BertScore
|
4 |
+
from radgraph import RadGraph
|
5 |
+
from factual.f1chexbert import F1CheXbert
|
6 |
+
from sklearn.preprocessing import StandardScaler
|
7 |
+
from nlg.bleu.bleu import Bleu
|
8 |
+
|
9 |
+
|
10 |
+
def radcliq_bertscore(refs, hyps, model_type='distilroberta-base'):
|
11 |
+
"""
|
12 |
+
Computes BERTScore for each pair of reference and hypothesis.
|
13 |
+
|
14 |
+
Returns:
|
15 |
+
np.ndarray of shape (N,) with the BERTScore F1 values per pair.
|
16 |
+
"""
|
17 |
+
# https://github.com/rajpurkarlab/CXR-Report-Metric/blob/9c9ecad39be6cb2be8e75be1d1c50ef8888a3e40/CXRMetric/run_eval.py#L103
|
18 |
+
scorer = BertScore(
|
19 |
+
model_type=model_type,
|
20 |
+
rescale_with_baseline=True,
|
21 |
+
idf=False,
|
22 |
+
num_layers=None
|
23 |
+
)
|
24 |
+
_, scores = scorer(refs, hyps)
|
25 |
+
# scores is a list of torch.Tensor, convert to numpy
|
26 |
+
return np.array([float(s) for s in scores])
|
27 |
+
|
28 |
+
|
29 |
+
def compute_f1(test_set, retrieved_set):
|
30 |
+
"""Helper to compute F1 between two sets of items."""
|
31 |
+
tp = len(test_set & retrieved_set)
|
32 |
+
fp = len(retrieved_set) - tp
|
33 |
+
fn = len(test_set) - tp
|
34 |
+
precision = tp / (tp + fp) if (tp + fp) else 0.0
|
35 |
+
recall = tp / (tp + fn) if (tp + fn) else 0.0
|
36 |
+
return 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
|
37 |
+
|
38 |
+
|
39 |
+
def extract_entities(output):
|
40 |
+
"""Extracts set of (tokens, label) tuples from RadGraph output."""
|
41 |
+
return {(tuple(ent["tokens"]), ent["label"]) for ent in output.get("entities", {}).values()}
|
42 |
+
|
43 |
+
|
44 |
+
def extract_relations(output):
|
45 |
+
"""Extracts set of (src, tgt, relation) tuples from RadGraph output."""
|
46 |
+
rels = set()
|
47 |
+
entities = output.get("entities", {})
|
48 |
+
for ent in entities.values():
|
49 |
+
src = (tuple(ent["tokens"]), ent["label"])
|
50 |
+
for rel_type, tgt_idx in ent.get("relations", []):
|
51 |
+
tgt_ent = entities.get(tgt_idx)
|
52 |
+
if tgt_ent:
|
53 |
+
tgt = (tuple(tgt_ent["tokens"]), tgt_ent["label"])
|
54 |
+
rels.add((src, tgt, rel_type))
|
55 |
+
return rels
|
56 |
+
|
57 |
+
|
58 |
+
def radcliq_radgraph_scores(refs, hyps, model_name='radgraph'):
|
59 |
+
"""
|
60 |
+
Computes entity and relation F1 via RadGraph for each report pair and returns their average.
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
np.ndarray of shape (N,) with (entity_f1 + relation_f1)/2 per pair.
|
64 |
+
"""
|
65 |
+
rad = RadGraph(model_type=model_name)
|
66 |
+
gt_outputs = rad(refs)
|
67 |
+
pred_outputs = rad(hyps)
|
68 |
+
scores = []
|
69 |
+
for i in range(len(refs)):
|
70 |
+
gt_out = gt_outputs.get(str(i), {})
|
71 |
+
pred_out = pred_outputs.get(str(i), {})
|
72 |
+
|
73 |
+
ents_gt = extract_entities(gt_out)
|
74 |
+
ents_pred = extract_entities(pred_out)
|
75 |
+
rels_gt = extract_relations(gt_out)
|
76 |
+
rels_pred = extract_relations(pred_out)
|
77 |
+
|
78 |
+
ent_f1 = compute_f1(ents_gt, ents_pred)
|
79 |
+
rel_f1 = compute_f1(rels_gt, rels_pred)
|
80 |
+
scores.append((ent_f1 + rel_f1) / 2)
|
81 |
+
return np.array(scores)
|
82 |
+
|
83 |
+
|
84 |
+
def semantic_embedding_scores(refs, hyps, device='cpu'):
|
85 |
+
"""
|
86 |
+
Computes per-pair cosine similarity between embeddings from CheXbert labeler.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
np.ndarray of shape (N,) with cosine similarities per pair.
|
90 |
+
"""
|
91 |
+
if len(refs) != len(hyps):
|
92 |
+
raise ValueError(f"refs ({len(refs)}) and hyps ({len(hyps)}) must be same length")
|
93 |
+
labeler = F1CheXbert(device=device)
|
94 |
+
gt_embs = np.vstack(labeler.get_embeddings(refs))
|
95 |
+
pred_embs = np.vstack(labeler.get_embeddings(hyps))
|
96 |
+
# https://github.com/rajpurkarlab/CXR-Report-Metric/blob/9c9ecad39be6cb2be8e75be1d1c50ef8888a3e40/CXRMetric/run_eval.py#L126
|
97 |
+
dot = np.einsum("nd,nd->n", gt_embs, pred_embs)
|
98 |
+
norms = np.linalg.norm(gt_embs, axis=1) * np.linalg.norm(pred_embs, axis=1)
|
99 |
+
with np.errstate(divide='ignore', invalid='ignore'):
|
100 |
+
sims = np.where(norms > 0, dot / norms, 0.0)
|
101 |
+
return sims
|
102 |
+
|
103 |
+
|
104 |
+
def radcliq_scores(refs, hyps,
|
105 |
+
bert_model='distilroberta-base',
|
106 |
+
radgraph_model='radgraph'):
|
107 |
+
"""
|
108 |
+
Computes BERTScore, RadGraph score, and semantic embedding similarity for each ref-hyp pair.
|
109 |
+
|
110 |
+
Args:
|
111 |
+
refs: List of reference report strings.
|
112 |
+
hyps: List of hypothesis report strings.
|
113 |
+
device: Device for embedding model ('cpu' or 'cuda').
|
114 |
+
bert_model: HuggingFace model name for BERTScore.
|
115 |
+
radgraph_model: Model name for RadGraph inference.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
Dict with keys 'bertscore', 'radgraph', 'semantic', each mapping to a numpy array of shape (N,).
|
119 |
+
"""
|
120 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
121 |
+
# BERTScore
|
122 |
+
bert_scores = radcliq_bertscore(refs, hyps, model_type=bert_model)
|
123 |
+
# RadGraph
|
124 |
+
rad_scores = radcliq_radgraph_scores(refs, hyps, model_name=radgraph_model)
|
125 |
+
# Semantic embeddings
|
126 |
+
sem_scores = semantic_embedding_scores(refs, hyps, device=device)
|
127 |
+
|
128 |
+
# BLEU
|
129 |
+
bleu_scorer = Bleu()
|
130 |
+
bleu_scores = bleu_scorer(refs, hyps)[1]
|
131 |
+
|
132 |
+
return {
|
133 |
+
'bertscore': bert_scores,
|
134 |
+
'radgraph': rad_scores,
|
135 |
+
'semb_score': sem_scores,
|
136 |
+
'bleu_score': bleu_scores
|
137 |
+
}
|
138 |
+
|
139 |
+
|
140 |
+
|
141 |
+
class CompositeMetric:
|
142 |
+
def __init__(self):
|
143 |
+
scaler = StandardScaler(with_mean=True, with_std=True)
|
144 |
+
# learnt parameters, infered from
|
145 |
+
# https://github.com/rajpurkarlab/CXR-Report-Metric/blob/main/CXRMetric/run_eval.py#L219
|
146 |
+
scaler.mean_ = np.array([0.53792312, 0.61757256, 0.76479421, 0.44738335])
|
147 |
+
scaler.scale_ = np.array([0.30282584, 0.22430938, 0.25394391, 0.29892717])
|
148 |
+
scaler.var_ = np.array([0.09170349, 0.05031470, 0.06448751, 0.08935745])
|
149 |
+
scaler.n_samples_seen_ = 160 # integer
|
150 |
+
scaler.n_features_in_ = 4 # integer
|
151 |
+
|
152 |
+
self.scaler = scaler
|
153 |
+
self.coefs = np.array([
|
154 |
+
-3.77083683e-01, # radgraph weight
|
155 |
+
-3.70300100e-01, # bertscore weight
|
156 |
+
-2.52616218e-01, # s-emb weight
|
157 |
+
4.31504841e-12, # bleu weight
|
158 |
+
2.46655256e-10 # intercept / bias
|
159 |
+
])
|
160 |
+
self.cols = ["radgraph", "bertscore", "semb_score", "bleu_score"]
|
161 |
+
|
162 |
+
def predict(self, X):
|
163 |
+
Xn = self.scaler.transform(X)
|
164 |
+
Xn = np.hstack([Xn, np.ones((Xn.shape[0], 1))])
|
165 |
+
return Xn @ self.coefs
|
166 |
+
|
167 |
+
def _build_matrix(self, metrics: dict[str, np.ndarray]) -> np.ndarray:
|
168 |
+
"""Stack features in the canonical column order."""
|
169 |
+
return np.column_stack([metrics[c] for c in self.cols])
|
170 |
+
|
171 |
+
def predict(self, refs, hyps) -> np.ndarray:
|
172 |
+
"""
|
173 |
+
Args
|
174 |
+
----
|
175 |
+
metrics : dict returned by `radcliq_scores`
|
176 |
+
|
177 |
+
Returns
|
178 |
+
-------
|
179 |
+
np.ndarray of shape (N,) – RadCliQ-v1 score for each ref/hyp pair.
|
180 |
+
"""
|
181 |
+
metrics = radcliq_scores(refs, hyps)
|
182 |
+
|
183 |
+
X = self._build_matrix(metrics)
|
184 |
+
|
185 |
+
Xn = self.scaler.transform(X)
|
186 |
+
|
187 |
+
# Append bias term
|
188 |
+
Xn = np.hstack([Xn, np.ones((Xn.shape[0], 1))])
|
189 |
+
scores = Xn @ self.coefs
|
190 |
+
|
191 |
+
return 1/scores.mean(), scores
|
192 |
+
|
193 |
+
if __name__ == "__main__":
|
194 |
+
refs = [
|
195 |
+
"No evidence of pneumothorax following chest tube removal.",
|
196 |
+
"There is a left pleural effusion.",
|
197 |
+
"There is a left pleural effusion."
|
198 |
+
]
|
199 |
+
hyps = [
|
200 |
+
"No pneumothorax detected.",
|
201 |
+
"Left pleural effusion is present.",
|
202 |
+
"No pneumothorax detected.",
|
203 |
+
]
|
204 |
+
|
205 |
+
# Step-1: compute the four individual metrics
|
206 |
+
|
207 |
+
# Step-2: get the RadCliQ-v1 composite
|
208 |
+
radcliq = CompositeMetric()
|
209 |
+
mean_scores, detail_scores = radcliq.predict(refs, hyps)
|
210 |
+
for i, s in enumerate(detail_scores, 1):
|
211 |
+
print(f"Pair {i}: RadCliQ-v1 = {s:.4f}")
|
212 |
+
|
213 |
+
print(f"RadCliQ-v1 score: {mean_scores:.4f}")
|
factual/RadCliQv1/radcliq_bertscore.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from nlg.bertscore.bertscore import BertScore
|
2 |
+
|
3 |
+
def radcliq_bertscore(refs, hyps):
|
4 |
+
bertscore_scorer = BertScore(model_type='distilroberta-base',
|
5 |
+
rescale_with_baseline=True,
|
6 |
+
idf=False,
|
7 |
+
num_layers=None)
|
8 |
+
print(bertscore_scorer)
|
9 |
+
avg, scores = bertscore_scorer(refs, hyps)
|
10 |
+
return scores
|
factual/RadCliQv1/radcliq_radgraph.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
from radgraph import RadGraph
|
3 |
+
|
4 |
+
|
5 |
+
def compute_f1(test, retrieved):
|
6 |
+
"""Computes F1 between test/retrieved report's entities or relations."""
|
7 |
+
tp = len(test & retrieved)
|
8 |
+
fp = len(retrieved) - tp
|
9 |
+
fn = len(test) - tp
|
10 |
+
precision = tp / (tp + fp) if (tp + fp) else 0
|
11 |
+
recall = tp / (tp + fn) if (tp + fn) else 0
|
12 |
+
return 2 * precision * recall / (precision + recall) if (precision + recall) else 0
|
13 |
+
|
14 |
+
|
15 |
+
def extract_entities(output):
|
16 |
+
"""Extracts set of (tokens, label) from a RadGraph output dict."""
|
17 |
+
return {(tuple(ent["tokens"]), ent["label"]) for ent in output.get("entities", {}).values()}
|
18 |
+
|
19 |
+
|
20 |
+
def extract_relations(output):
|
21 |
+
"""Extracts set of (src, tgt, relation) from a RadGraph output dict."""
|
22 |
+
rels = set()
|
23 |
+
entities = output.get("entities", {})
|
24 |
+
for ent in entities.values():
|
25 |
+
src = (tuple(ent["tokens"]), ent["label"])
|
26 |
+
for rel_type, tgt_idx in ent.get("relations", []):
|
27 |
+
tgt_ent = entities.get(tgt_idx)
|
28 |
+
if tgt_ent:
|
29 |
+
tgt = (tuple(tgt_ent["tokens"]), tgt_ent["label"])
|
30 |
+
rels.add((src, tgt, rel_type))
|
31 |
+
return rels
|
32 |
+
|
33 |
+
|
34 |
+
def compute_radgraph_scores(refs, hyps, model_name='radgraph'):
|
35 |
+
"""
|
36 |
+
Computes combined RadGraph F1 scores for each pair of reference and hypothesis reports.
|
37 |
+
Returns:
|
38 |
+
List of floats: (entity_f1 + relation_f1)/2 per report.
|
39 |
+
"""
|
40 |
+
# Initialize RadGraph model
|
41 |
+
rad = RadGraph(model_type=model_name)
|
42 |
+
|
43 |
+
# Perform inference
|
44 |
+
gt_outputs = rad(refs)
|
45 |
+
pred_outputs = rad(hyps)
|
46 |
+
|
47 |
+
scores = []
|
48 |
+
for i in range(len(gt_outputs)):
|
49 |
+
gt_out = gt_outputs[str(i)]
|
50 |
+
pred_out = pred_outputs[str(i)]
|
51 |
+
|
52 |
+
gt_ents = extract_entities(gt_out)
|
53 |
+
pred_ents = extract_entities(pred_out)
|
54 |
+
gt_rels = extract_relations(gt_out)
|
55 |
+
pred_rels = extract_relations(pred_out)
|
56 |
+
|
57 |
+
ent_f1 = compute_f1(gt_ents, pred_ents)
|
58 |
+
rel_f1 = compute_f1(gt_rels, pred_rels)
|
59 |
+
scores.append((ent_f1 + rel_f1) / 2)
|
60 |
+
|
61 |
+
return scores
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == '__main__':
|
65 |
+
# Example usage
|
66 |
+
refs = [
|
67 |
+
"No evidence of pneumothorax following chest tube removal.",
|
68 |
+
"There is a left pleural effusion."
|
69 |
+
]
|
70 |
+
hyps = [
|
71 |
+
"No pneumothorax detected.",
|
72 |
+
"Left pleural effusion is present."
|
73 |
+
]
|
74 |
+
|
75 |
+
combined_scores = compute_radgraph_scores(refs, hyps)
|
76 |
+
print(combined_scores) # e.g., [1.0, 1.0]
|
77 |
+
from radgraph import F1RadGraph
|
78 |
+
f1_radgraph = F1RadGraph(model_type="radgraph", reward_level="simple")
|
79 |
+
f1_scores = f1_radgraph(refs, hyps,)
|
80 |
+
print(f1_scores)
|
factual/RadCliQv1/semb_score.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Sequence, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from factual.f1chexbert import F1CheXbert
|
7 |
+
|
8 |
+
|
9 |
+
def semantic_embedding_scores(
|
10 |
+
refs: Sequence[str],
|
11 |
+
hyps: Sequence[str],
|
12 |
+
*,
|
13 |
+
device: Union[str, torch.device] = "cpu",
|
14 |
+
) -> np.ndarray:
|
15 |
+
"""Return per‑pair cosine similarities between `refs` and `hyps`.
|
16 |
+
|
17 |
+
All heavy math is vectorised; no Python loops.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
refs: Iterable of ground‑truth report strings.
|
21 |
+
hyps: Iterable of predicted report strings (must match `refs` length).
|
22 |
+
device: Computation device (e.g. "cpu", "cuda", "cuda:0").
|
23 |
+
|
24 |
+
Returns
|
25 |
+
-------
|
26 |
+
np.ndarray
|
27 |
+
Shape ``(N,)`` – cosine similarity for each pair, where
|
28 |
+
``N == len(refs) == len(hyps)``.
|
29 |
+
|
30 |
+
Raises
|
31 |
+
------
|
32 |
+
ValueError
|
33 |
+
If `refs` and `hyps` are of different lengths.
|
34 |
+
"""
|
35 |
+
|
36 |
+
if len(refs) != len(hyps):
|
37 |
+
raise ValueError(f"refs ({len(refs)}) and hyps ({len(hyps)}) differ in length")
|
38 |
+
|
39 |
+
labeler = F1CheXbert(device=device)
|
40 |
+
|
41 |
+
# Stack embeddings into (N, dim) matrices
|
42 |
+
gt_embeds = np.vstack(labeler.get_embeddings(refs)) # (N, dim)
|
43 |
+
pred_embeds = np.vstack(labeler.get_embeddings(hyps)) # (N, dim)
|
44 |
+
|
45 |
+
# Cosine similarity – fully vectorised
|
46 |
+
dot = np.einsum("nd,nd->n", gt_embeds, pred_embeds)
|
47 |
+
norms = np.linalg.norm(gt_embeds, axis=1) * np.linalg.norm(pred_embeds, axis=1)
|
48 |
+
with np.errstate(divide="ignore", invalid="ignore"):
|
49 |
+
sims = np.where(norms > 0, dot / norms, 0.0)
|
50 |
+
|
51 |
+
return sims
|
52 |
+
|
53 |
+
|
54 |
+
def mean_semantic_score(scores: np.ndarray) -> float:
|
55 |
+
"""Convenience helper: mean of an array of scores."""
|
56 |
+
return float(scores.mean())
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
_refs = [
|
61 |
+
"No evidence of pneumothorax following chest tube removal.",
|
62 |
+
"There is a left pleural effusion.",
|
63 |
+
"No evidence of pneumothorax following chest tube removal.",
|
64 |
+
|
65 |
+
]
|
66 |
+
_hyps = [
|
67 |
+
"No pneumothorax detected.",
|
68 |
+
"Left pleural effusion is present.",
|
69 |
+
"Left pleural effusion is present.",
|
70 |
+
]
|
71 |
+
|
72 |
+
_scores = semantic_embedding_scores(_refs, _hyps, device="cpu")
|
73 |
+
print("Per‑pair cosine:", _scores)
|
74 |
+
print("Mean:", mean_semantic_score(_scores))
|
factual/SRRBert/leaves_mapping.json
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"No Finding": 0,
|
3 |
+
"Lung Lesion": 1,
|
4 |
+
"Edema": 2,
|
5 |
+
"Pneumonia": 3,
|
6 |
+
"Atelectasis": 4,
|
7 |
+
"Aspiration": 5,
|
8 |
+
"Lung collapse": 6,
|
9 |
+
"Perihilar airspace opacity": 7,
|
10 |
+
"Air space opacity\u2013multifocal": 8,
|
11 |
+
"Mass/Solitary lung mass": 9,
|
12 |
+
"Nodule/Solitary lung nodule": 10,
|
13 |
+
"Cavitating mass with content": 11,
|
14 |
+
"Cavitating masses": 12,
|
15 |
+
"Emphysema": 13,
|
16 |
+
"Fibrosis": 14,
|
17 |
+
"Pulmonary congestion": 15,
|
18 |
+
"Hilar lymphadenopathy": 16,
|
19 |
+
"Bronchiectasis": 17,
|
20 |
+
"Simple pneumothorax": 18,
|
21 |
+
"Loculated pneumothorax": 19,
|
22 |
+
"Tension pneumothorax": 20,
|
23 |
+
"Simple pleural effusion": 21,
|
24 |
+
"Loculated pleural effusion": 22,
|
25 |
+
"Pleural scarring": 23,
|
26 |
+
"Hydropneumothorax": 24,
|
27 |
+
"Pleural Other": 25,
|
28 |
+
"Cardiomegaly": 26,
|
29 |
+
"Pericardial effusion": 27,
|
30 |
+
"Inferior mediastinal mass": 28,
|
31 |
+
"Superior mediastinal mass": 29,
|
32 |
+
"Tortuous Aorta": 30,
|
33 |
+
"Calcification of the Aorta": 31,
|
34 |
+
"Enlarged pulmonary artery": 32,
|
35 |
+
"Hernia": 33,
|
36 |
+
"Pneumomediastinum": 34,
|
37 |
+
"Tracheal deviation": 35,
|
38 |
+
"Acute humerus fracture": 36,
|
39 |
+
"Acute rib fracture": 37,
|
40 |
+
"Acute clavicle fracture": 38,
|
41 |
+
"Acute scapula fracture": 39,
|
42 |
+
"Compression fracture": 40,
|
43 |
+
"Shoulder dislocation": 41,
|
44 |
+
"Subcutaneous Emphysema": 42,
|
45 |
+
"Suboptimal central line": 43,
|
46 |
+
"Suboptimal endotracheal tube": 44,
|
47 |
+
"Suboptimal nasogastric tube": 45,
|
48 |
+
"Suboptimal pulmonary arterial catheter": 46,
|
49 |
+
"Pleural tube": 47,
|
50 |
+
"PICC line": 48,
|
51 |
+
"Port catheter": 49,
|
52 |
+
"Pacemaker": 50,
|
53 |
+
"Implantable defibrillator": 51,
|
54 |
+
"LVAD": 52,
|
55 |
+
"Intraaortic balloon pump": 53,
|
56 |
+
"Pneumoperitoneum": 54
|
57 |
+
}
|
58 |
+
|
factual/SRRBert/leaves_with_statuses_mapping.json
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Lung Lesion (Present)": 0,
|
3 |
+
"Edema (Present)": 1,
|
4 |
+
"Pneumonia (Present)": 2,
|
5 |
+
"Atelectasis (Present)": 3,
|
6 |
+
"Aspiration (Present)": 4,
|
7 |
+
"Lung collapse (Present)": 5,
|
8 |
+
"Perihilar airspace opacity (Present)": 6,
|
9 |
+
"Air space opacity\u2013multifocal (Present)": 7,
|
10 |
+
"Mass/Solitary lung mass (Present)": 8,
|
11 |
+
"Nodule/Solitary lung nodule (Present)": 9,
|
12 |
+
"Cavitating mass with content (Present)": 10,
|
13 |
+
"Cavitating masses (Present)": 11,
|
14 |
+
"Emphysema (Present)": 12,
|
15 |
+
"Fibrosis (Present)": 13,
|
16 |
+
"Pulmonary congestion (Present)": 14,
|
17 |
+
"Hilar lymphadenopathy (Present)": 15,
|
18 |
+
"Bronchiectasis (Present)": 16,
|
19 |
+
"Simple pneumothorax (Present)": 17,
|
20 |
+
"Loculated pneumothorax (Present)": 18,
|
21 |
+
"Tension pneumothorax (Present)": 19,
|
22 |
+
"Simple pleural effusion (Present)": 20,
|
23 |
+
"Loculated pleural effusion (Present)": 21,
|
24 |
+
"Pleural scarring (Present)": 22,
|
25 |
+
"Hydropneumothorax (Present)": 23,
|
26 |
+
"Pleural Other (Present)": 24,
|
27 |
+
"Cardiomegaly (Present)": 25,
|
28 |
+
"Pericardial effusion (Present)": 26,
|
29 |
+
"Inferior mediastinal mass (Present)": 27,
|
30 |
+
"Superior mediastinal mass (Present)": 28,
|
31 |
+
"Tortuous Aorta (Present)": 29,
|
32 |
+
"Calcification of the Aorta (Present)": 30,
|
33 |
+
"Enlarged pulmonary artery (Present)": 31,
|
34 |
+
"Hernia (Present)": 32,
|
35 |
+
"Pneumomediastinum (Present)": 33,
|
36 |
+
"Tracheal deviation (Present)": 34,
|
37 |
+
"Acute humerus fracture (Present)": 35,
|
38 |
+
"Acute rib fracture (Present)": 36,
|
39 |
+
"Acute clavicle fracture (Present)": 37,
|
40 |
+
"Acute scapula fracture (Present)": 38,
|
41 |
+
"Compression fracture (Present)": 39,
|
42 |
+
"Shoulder dislocation (Present)": 40,
|
43 |
+
"Subcutaneous Emphysema (Present)": 41,
|
44 |
+
"Suboptimal central line (Present)": 42,
|
45 |
+
"Suboptimal endotracheal tube (Present)": 43,
|
46 |
+
"Suboptimal nasogastric tube (Present)": 44,
|
47 |
+
"Suboptimal pulmonary arterial catheter (Present)": 45,
|
48 |
+
"Pleural tube (Present)": 46,
|
49 |
+
"PICC line (Present)": 47,
|
50 |
+
"Port catheter (Present)": 48,
|
51 |
+
"Pacemaker (Present)": 49,
|
52 |
+
"Implantable defibrillator (Present)": 50,
|
53 |
+
"LVAD (Present)": 51,
|
54 |
+
"Intraaortic balloon pump (Present)": 52,
|
55 |
+
"Pneumoperitoneum (Present)": 53,
|
56 |
+
"Lung Lesion (Uncertain)": 54,
|
57 |
+
"Edema (Uncertain)": 55,
|
58 |
+
"Pneumonia (Uncertain)": 56,
|
59 |
+
"Atelectasis (Uncertain)": 57,
|
60 |
+
"Aspiration (Uncertain)": 58,
|
61 |
+
"Lung collapse (Uncertain)": 59,
|
62 |
+
"Perihilar airspace opacity (Uncertain)": 60,
|
63 |
+
"Air space opacity\u2013multifocal (Uncertain)": 61,
|
64 |
+
"Mass/Solitary lung mass (Uncertain)": 62,
|
65 |
+
"Nodule/Solitary lung nodule (Uncertain)": 63,
|
66 |
+
"Cavitating mass with content (Uncertain)": 64,
|
67 |
+
"Cavitating masses (Uncertain)": 65,
|
68 |
+
"Emphysema (Uncertain)": 66,
|
69 |
+
"Fibrosis (Uncertain)": 67,
|
70 |
+
"Pulmonary congestion (Uncertain)": 68,
|
71 |
+
"Hilar lymphadenopathy (Uncertain)": 69,
|
72 |
+
"Bronchiectasis (Uncertain)": 70,
|
73 |
+
"Simple pneumothorax (Uncertain)": 71,
|
74 |
+
"Loculated pneumothorax (Uncertain)": 72,
|
75 |
+
"Tension pneumothorax (Uncertain)": 73,
|
76 |
+
"Simple pleural effusion (Uncertain)": 74,
|
77 |
+
"Loculated pleural effusion (Uncertain)": 75,
|
78 |
+
"Pleural scarring (Uncertain)": 76,
|
79 |
+
"Hydropneumothorax (Uncertain)": 77,
|
80 |
+
"Pleural Other (Uncertain)": 78,
|
81 |
+
"Cardiomegaly (Uncertain)": 79,
|
82 |
+
"Pericardial effusion (Uncertain)": 80,
|
83 |
+
"Inferior mediastinal mass (Uncertain)": 81,
|
84 |
+
"Superior mediastinal mass (Uncertain)": 82,
|
85 |
+
"Tortuous Aorta (Uncertain)": 83,
|
86 |
+
"Calcification of the Aorta (Uncertain)": 84,
|
87 |
+
"Enlarged pulmonary artery (Uncertain)": 85,
|
88 |
+
"Hernia (Uncertain)": 86,
|
89 |
+
"Pneumomediastinum (Uncertain)": 87,
|
90 |
+
"Tracheal deviation (Uncertain)": 88,
|
91 |
+
"Acute humerus fracture (Uncertain)": 89,
|
92 |
+
"Acute rib fracture (Uncertain)": 90,
|
93 |
+
"Acute clavicle fracture (Uncertain)": 91,
|
94 |
+
"Acute scapula fracture (Uncertain)": 92,
|
95 |
+
"Compression fracture (Uncertain)": 93,
|
96 |
+
"Shoulder dislocation (Uncertain)": 94,
|
97 |
+
"Subcutaneous Emphysema (Uncertain)": 95,
|
98 |
+
"Suboptimal central line (Uncertain)": 96,
|
99 |
+
"Suboptimal endotracheal tube (Uncertain)": 97,
|
100 |
+
"Suboptimal nasogastric tube (Uncertain)": 98,
|
101 |
+
"Suboptimal pulmonary arterial catheter (Uncertain)": 99,
|
102 |
+
"Pleural tube (Uncertain)": 100,
|
103 |
+
"PICC line (Uncertain)": 101,
|
104 |
+
"Port catheter (Uncertain)": 102,
|
105 |
+
"Pacemaker (Uncertain)": 103,
|
106 |
+
"Implantable defibrillator (Uncertain)": 104,
|
107 |
+
"LVAD (Uncertain)": 105,
|
108 |
+
"Intraaortic balloon pump (Uncertain)": 106,
|
109 |
+
"Pneumoperitoneum (Uncertain)": 107,
|
110 |
+
"Lung Lesion (Absent)": 108,
|
111 |
+
"Edema (Absent)": 109,
|
112 |
+
"Pneumonia (Absent)": 110,
|
113 |
+
"Atelectasis (Absent)": 111,
|
114 |
+
"Aspiration (Absent)": 112,
|
115 |
+
"Lung collapse (Absent)": 113,
|
116 |
+
"Perihilar airspace opacity (Absent)": 114,
|
117 |
+
"Air space opacity\u2013multifocal (Absent)": 115,
|
118 |
+
"Mass/Solitary lung mass (Absent)": 116,
|
119 |
+
"Nodule/Solitary lung nodule (Absent)": 117,
|
120 |
+
"Cavitating mass with content (Absent)": 118,
|
121 |
+
"Cavitating masses (Absent)": 119,
|
122 |
+
"Emphysema (Absent)": 120,
|
123 |
+
"Fibrosis (Absent)": 121,
|
124 |
+
"Pulmonary congestion (Absent)": 122,
|
125 |
+
"Hilar lymphadenopathy (Absent)": 123,
|
126 |
+
"Bronchiectasis (Absent)": 124,
|
127 |
+
"Simple pneumothorax (Absent)": 125,
|
128 |
+
"Loculated pneumothorax (Absent)": 126,
|
129 |
+
"Tension pneumothorax (Absent)": 127,
|
130 |
+
"Simple pleural effusion (Absent)": 128,
|
131 |
+
"Loculated pleural effusion (Absent)": 129,
|
132 |
+
"Pleural scarring (Absent)": 130,
|
133 |
+
"Hydropneumothorax (Absent)": 131,
|
134 |
+
"Pleural Other (Absent)": 132,
|
135 |
+
"Cardiomegaly (Absent)": 133,
|
136 |
+
"Pericardial effusion (Absent)": 134,
|
137 |
+
"Inferior mediastinal mass (Absent)": 135,
|
138 |
+
"Superior mediastinal mass (Absent)": 136,
|
139 |
+
"Tortuous Aorta (Absent)": 137,
|
140 |
+
"Calcification of the Aorta (Absent)": 138,
|
141 |
+
"Enlarged pulmonary artery (Absent)": 139,
|
142 |
+
"Hernia (Absent)": 140,
|
143 |
+
"Pneumomediastinum (Absent)": 141,
|
144 |
+
"Tracheal deviation (Absent)": 142,
|
145 |
+
"Acute humerus fracture (Absent)": 143,
|
146 |
+
"Acute rib fracture (Absent)": 144,
|
147 |
+
"Acute clavicle fracture (Absent)": 145,
|
148 |
+
"Acute scapula fracture (Absent)": 146,
|
149 |
+
"Compression fracture (Absent)": 147,
|
150 |
+
"Shoulder dislocation (Absent)": 148,
|
151 |
+
"Subcutaneous Emphysema (Absent)": 149,
|
152 |
+
"Suboptimal central line (Absent)": 150,
|
153 |
+
"Suboptimal endotracheal tube (Absent)": 151,
|
154 |
+
"Suboptimal nasogastric tube (Absent)": 152,
|
155 |
+
"Suboptimal pulmonary arterial catheter (Absent)": 153,
|
156 |
+
"Pleural tube (Absent)": 154,
|
157 |
+
"PICC line (Absent)": 155,
|
158 |
+
"Port catheter (Absent)": 156,
|
159 |
+
"Pacemaker (Absent)": 157,
|
160 |
+
"Implantable defibrillator (Absent)": 158,
|
161 |
+
"LVAD (Absent)": 159,
|
162 |
+
"Intraaortic balloon pump (Absent)": 160,
|
163 |
+
"Pneumoperitoneum (Absent)": 161,
|
164 |
+
"No Finding": 162
|
165 |
+
}
|
factual/SRRBert/srr_bert.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from transformers import BertForSequenceClassification, BertTokenizer
|
7 |
+
from tqdm import tqdm
|
8 |
+
import re
|
9 |
+
import nltk
|
10 |
+
|
11 |
+
|
12 |
+
def srr_bert_parse_sentences(text):
|
13 |
+
# Handle numbers followed by a dot, not followed by a digit (to avoid decimals like 3.5)
|
14 |
+
|
15 |
+
# Case 1: Number at beginning of text
|
16 |
+
text = re.sub(r'^\s*\d+\.(?!\d)\s*', '', text)
|
17 |
+
|
18 |
+
# Case 2: Number after a period, like "word.2."
|
19 |
+
text = re.sub(r'(\w)\.(\d+)\.(?!\d)\s*', r'\1. ', text)
|
20 |
+
|
21 |
+
# Case 3: Number attached to a word, like "word2."
|
22 |
+
text = re.sub(r'(\w)(\d+)\.(?!\d)\s*', r'\1. ', text)
|
23 |
+
|
24 |
+
# Case 4: Number after space following a word, like "word 2."
|
25 |
+
text = re.sub(r'(\w)\s+\d+\.(?!\d)\s*', r'\1. ', text)
|
26 |
+
|
27 |
+
# Case 5: Standalone number in the middle, like ". 2. word"
|
28 |
+
text = re.sub(r'([.!?])\s*\d+\.(?!\d)\s*', r'\1 ', text)
|
29 |
+
|
30 |
+
# Add space after periods followed immediately by uppercase letter (new sentence without space)
|
31 |
+
text = re.sub(r'\.([A-Z])', r'. \1', text)
|
32 |
+
|
33 |
+
# Make sure the text ends with a period
|
34 |
+
if not text.strip().endswith(('.', '!', '?')):
|
35 |
+
text = text.strip() + '.'
|
36 |
+
|
37 |
+
# Tokenize into sentences
|
38 |
+
sentences = nltk.sent_tokenize(text)
|
39 |
+
|
40 |
+
return sentences
|
41 |
+
|
42 |
+
|
43 |
+
class SRRBert(nn.Module):
|
44 |
+
# Supported model types and their configs
|
45 |
+
MODEL_CONFIGS = {
|
46 |
+
"leaves": {
|
47 |
+
"model_path": "StanfordAIMI/SRR-BERT-Leaves",
|
48 |
+
"mapping_file": "leaves_mapping.json"
|
49 |
+
},
|
50 |
+
"upper": {
|
51 |
+
"model_path": "StanfordAIMI/SRR-BERT-Upper",
|
52 |
+
"mapping_file": "upper_mapping.json"
|
53 |
+
},
|
54 |
+
"leaves_with_statuses": {
|
55 |
+
"model_path": "StanfordAIMI/SRR-BERT-Leaves-with-Statuses",
|
56 |
+
"mapping_file": "leaves_with_statuses_mapping.json"
|
57 |
+
},
|
58 |
+
"upper_with_statuses": {
|
59 |
+
"model_path": "StanfordAIMI/SRRG-BERT-Upper-with-Statuses",
|
60 |
+
"mapping_file": "upper_with_statuses_mapping.json"
|
61 |
+
},
|
62 |
+
}
|
63 |
+
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
model_type: str = "leaves",
|
67 |
+
batch_size: int = 4,
|
68 |
+
tqdm_enable: bool = False
|
69 |
+
):
|
70 |
+
super().__init__()
|
71 |
+
if model_type not in self.MODEL_CONFIGS:
|
72 |
+
raise ValueError(
|
73 |
+
f"model_type must be one of {list(self.MODEL_CONFIGS.keys())}"
|
74 |
+
)
|
75 |
+
config = self.MODEL_CONFIGS[model_type]
|
76 |
+
|
77 |
+
# Load mapping
|
78 |
+
mapping_path = os.path.join(
|
79 |
+
os.path.dirname(__file__),
|
80 |
+
config["mapping_file"]
|
81 |
+
)
|
82 |
+
with open(mapping_path, 'r') as f:
|
83 |
+
self.mapping = json.load(f)
|
84 |
+
|
85 |
+
# Device setup
|
86 |
+
self.device = torch.device(
|
87 |
+
'cuda' if torch.cuda.is_available() else 'cpu'
|
88 |
+
)
|
89 |
+
|
90 |
+
# Load model
|
91 |
+
self.model = BertForSequenceClassification.from_pretrained(
|
92 |
+
config["model_path"],
|
93 |
+
num_labels=len(self.mapping)
|
94 |
+
)
|
95 |
+
self.model.to(self.device)
|
96 |
+
self.model.eval()
|
97 |
+
|
98 |
+
# Tokenizer
|
99 |
+
self.tokenizer = BertTokenizer.from_pretrained(
|
100 |
+
"microsoft/BiomedVLP-CXR-BERT-general"
|
101 |
+
)
|
102 |
+
|
103 |
+
# Settings
|
104 |
+
self.batch_size = batch_size
|
105 |
+
self.tqdm_enable = tqdm_enable
|
106 |
+
|
107 |
+
def map_predictions_to_labels(self, outputs):
|
108 |
+
inverted_mapping = {v: k for k, v in self.mapping.items()}
|
109 |
+
all_labels = []
|
110 |
+
for output in outputs:
|
111 |
+
labels = [inverted_mapping[i] for i, flag in enumerate(output) if flag == 1]
|
112 |
+
all_labels.append(labels)
|
113 |
+
return all_labels
|
114 |
+
|
115 |
+
def forward(self, sentences):
|
116 |
+
# Batch sentences
|
117 |
+
batches = [
|
118 |
+
sentences[i:i + self.batch_size]
|
119 |
+
for i in range(0, len(sentences), self.batch_size)
|
120 |
+
]
|
121 |
+
outputs = []
|
122 |
+
with torch.no_grad():
|
123 |
+
for batch in tqdm(
|
124 |
+
batches, desc="Predicting", disable=not self.tqdm_enable
|
125 |
+
):
|
126 |
+
inputs = self.tokenizer.batch_encode_plus(
|
127 |
+
batch,
|
128 |
+
add_special_tokens=True,
|
129 |
+
max_length=512,
|
130 |
+
padding="max_length",
|
131 |
+
truncation=True,
|
132 |
+
return_attention_mask=True,
|
133 |
+
return_tensors="pt",
|
134 |
+
)
|
135 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
136 |
+
logits = self.model(**inputs).logits
|
137 |
+
preds = (torch.sigmoid(logits) > 0.5).cpu().numpy().astype(int)
|
138 |
+
outputs.append(preds)
|
139 |
+
|
140 |
+
outputs = np.concatenate(outputs, axis=0)
|
141 |
+
return outputs, self.map_predictions_to_labels(outputs)
|
142 |
+
|
143 |
+
|
144 |
+
if __name__ == "__main__":
|
145 |
+
example_sentences = [
|
146 |
+
"Layering pleural effusions",
|
147 |
+
"Moderate pulmonary edema.",
|
148 |
+
"Chronic fracture and dislocation involving the left humeral surgical neck and glenoid.",
|
149 |
+
"Stable cardiomegaly.",
|
150 |
+
]
|
151 |
+
|
152 |
+
# Initialize model (choose one of: leaves, upper, leaves_with_statuses, upper_with_statuses)
|
153 |
+
model = SRRBert(
|
154 |
+
model_type="leaves",
|
155 |
+
batch_size=4,
|
156 |
+
tqdm_enable=True
|
157 |
+
)
|
158 |
+
outputs, labels = model(example_sentences)
|
159 |
+
print("Raw outputs:", outputs)
|
160 |
+
print("Predicted labels:", labels)
|
factual/SRRBert/upper_mapping.json
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Pleural Effusion": 0,
|
3 |
+
"Upper abdominal finding": 1,
|
4 |
+
"Widened cardiac silhouette": 2,
|
5 |
+
"Lung Finding": 3,
|
6 |
+
"No Finding": 4,
|
7 |
+
"Widened aortic contour": 5,
|
8 |
+
"Pleural Thickening": 6,
|
9 |
+
"Vascular finding": 7,
|
10 |
+
"Consolidation": 8,
|
11 |
+
"Pneumothorax": 9,
|
12 |
+
"Subdiaphragmatic gas": 10,
|
13 |
+
"Masslike opacity": 11,
|
14 |
+
"Chest wall finding": 12,
|
15 |
+
"Focal air space opacity": 13,
|
16 |
+
"Segmental collapse": 14,
|
17 |
+
"Fracture": 15,
|
18 |
+
"Mediastinal mass": 16,
|
19 |
+
"Solitary masslike opacity": 17,
|
20 |
+
"Support Devices": 18,
|
21 |
+
"Mediastinal finding": 19,
|
22 |
+
"Pleural finding": 20,
|
23 |
+
"Air space opacity": 21,
|
24 |
+
"Diffuse air space opacity": 22,
|
25 |
+
"Multiple masslike opacities": 23,
|
26 |
+
"Musculoskeletal finding": 24
|
27 |
+
}
|
28 |
+
|
factual/SRRBert/upper_with_statuses_mapping.json
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"Pleural Effusion (Present)": 0,
|
3 |
+
"Upper abdominal finding (Present)": 1,
|
4 |
+
"Widened cardiac silhouette (Present)": 2,
|
5 |
+
"Lung Finding (Present)": 3,
|
6 |
+
"Widened aortic contour (Present)": 4,
|
7 |
+
"Pleural Thickening (Present)": 5,
|
8 |
+
"Vascular finding (Present)": 6,
|
9 |
+
"Consolidation (Present)": 7,
|
10 |
+
"Pneumothorax (Present)": 8,
|
11 |
+
"Subdiaphragmatic gas (Present)": 9,
|
12 |
+
"Masslike opacity (Present)": 10,
|
13 |
+
"Chest wall finding (Present)": 11,
|
14 |
+
"Focal air space opacity (Present)": 12,
|
15 |
+
"Segmental collapse (Present)": 13,
|
16 |
+
"Fracture (Present)": 14,
|
17 |
+
"Mediastinal mass (Present)": 15,
|
18 |
+
"Solitary masslike opacity (Present)": 16,
|
19 |
+
"Support Devices (Present)": 17,
|
20 |
+
"Mediastinal finding (Present)": 18,
|
21 |
+
"Pleural finding (Present)": 19,
|
22 |
+
"Air space opacity (Present)": 20,
|
23 |
+
"Diffuse air space opacity (Present)": 21,
|
24 |
+
"Multiple masslike opacities (Present)": 22,
|
25 |
+
"Musculoskeletal finding (Present)": 23,
|
26 |
+
"Pleural Effusion (Uncertain)": 24,
|
27 |
+
"Upper abdominal finding (Uncertain)": 25,
|
28 |
+
"Widened cardiac silhouette (Uncertain)": 26,
|
29 |
+
"Lung Finding (Uncertain)": 27,
|
30 |
+
"Widened aortic contour (Uncertain)": 28,
|
31 |
+
"Pleural Thickening (Uncertain)": 29,
|
32 |
+
"Vascular finding (Uncertain)": 30,
|
33 |
+
"Consolidation (Uncertain)": 31,
|
34 |
+
"Pneumothorax (Uncertain)": 32,
|
35 |
+
"Subdiaphragmatic gas (Uncertain)": 33,
|
36 |
+
"Masslike opacity (Uncertain)": 34,
|
37 |
+
"Chest wall finding (Uncertain)": 35,
|
38 |
+
"Focal air space opacity (Uncertain)": 36,
|
39 |
+
"Segmental collapse (Uncertain)": 37,
|
40 |
+
"Fracture (Uncertain)": 38,
|
41 |
+
"Mediastinal mass (Uncertain)": 39,
|
42 |
+
"Solitary masslike opacity (Uncertain)": 40,
|
43 |
+
"Support Devices (Uncertain)": 41,
|
44 |
+
"Mediastinal finding (Uncertain)": 42,
|
45 |
+
"Pleural finding (Uncertain)": 43,
|
46 |
+
"Air space opacity (Uncertain)": 44,
|
47 |
+
"Diffuse air space opacity (Uncertain)": 45,
|
48 |
+
"Multiple masslike opacities (Uncertain)": 46,
|
49 |
+
"Musculoskeletal finding (Uncertain)": 47,
|
50 |
+
"Pleural Effusion (Absent)": 48,
|
51 |
+
"Upper abdominal finding (Absent)": 49,
|
52 |
+
"Widened cardiac silhouette (Absent)": 50,
|
53 |
+
"Lung Finding (Absent)": 51,
|
54 |
+
"Widened aortic contour (Absent)": 52,
|
55 |
+
"Pleural Thickening (Absent)": 53,
|
56 |
+
"Vascular finding (Absent)": 54,
|
57 |
+
"Consolidation (Absent)": 55,
|
58 |
+
"Pneumothorax (Absent)": 56,
|
59 |
+
"Subdiaphragmatic gas (Absent)": 57,
|
60 |
+
"Masslike opacity (Absent)": 58,
|
61 |
+
"Chest wall finding (Absent)": 59,
|
62 |
+
"Focal air space opacity (Absent)": 60,
|
63 |
+
"Segmental collapse (Absent)": 61,
|
64 |
+
"Fracture (Absent)": 62,
|
65 |
+
"Mediastinal mass (Absent)": 63,
|
66 |
+
"Solitary masslike opacity (Absent)": 64,
|
67 |
+
"Support Devices (Absent)": 65,
|
68 |
+
"Mediastinal finding (Absent)": 66,
|
69 |
+
"Pleural finding (Absent)": 67,
|
70 |
+
"Air space opacity (Absent)": 68,
|
71 |
+
"Diffuse air space opacity (Absent)": 69,
|
72 |
+
"Multiple masslike opacities (Absent)": 70,
|
73 |
+
"Musculoskeletal finding (Absent)": 71,
|
74 |
+
"No Finding": 72
|
75 |
+
}
|
76 |
+
|
factual/__init__.py
ADDED
File without changes
|
factual/f1chexbert.py
ADDED
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
"""CheXbert evaluation utilities – **device‑safe end‑to‑end**
|
3 |
+
|
4 |
+
This is a drop‑in replacement for your previous `f1chexbert.py` **and** for the helper
|
5 |
+
`SemanticEmbeddingScorer`. All tensors – model weights *and* inputs – are created on
|
6 |
+
exactly the same device so the ``Expected all tensors to be on the same device``
|
7 |
+
run‑time error disappears. The public API stays identical, so the rest of your
|
8 |
+
pipeline does not need to change.
|
9 |
+
"""
|
10 |
+
|
11 |
+
from __future__ import annotations
|
12 |
+
|
13 |
+
import os
|
14 |
+
import warnings
|
15 |
+
import logging
|
16 |
+
from typing import List, Sequence, Tuple, Union
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import numpy as np
|
21 |
+
from transformers import (
|
22 |
+
AutoConfig,
|
23 |
+
BertModel,
|
24 |
+
BertTokenizer,
|
25 |
+
)
|
26 |
+
from sklearn.metrics import (
|
27 |
+
accuracy_score,
|
28 |
+
classification_report,
|
29 |
+
)
|
30 |
+
from sklearn.metrics._classification import _check_targets
|
31 |
+
from sklearn.utils.sparsefuncs import count_nonzero
|
32 |
+
from huggingface_hub import hf_hub_download
|
33 |
+
from appdirs import user_cache_dir
|
34 |
+
|
35 |
+
# -----------------------------------------------------------------------------
|
36 |
+
# GLOBALS & UTILITIES
|
37 |
+
# -----------------------------------------------------------------------------
|
38 |
+
|
39 |
+
CACHE_DIR = user_cache_dir("chexbert")
|
40 |
+
warnings.filterwarnings("ignore")
|
41 |
+
logging.getLogger("urllib3").setLevel(logging.ERROR)
|
42 |
+
|
43 |
+
# Helper ----------------------------------------------------------------------
|
44 |
+
|
45 |
+
def _generate_attention_masks(batch_ids: torch.LongTensor) -> torch.FloatTensor:
|
46 |
+
"""Create a padding mask: 1 for real tokens, 0 for pads."""
|
47 |
+
# batch_ids shape: (B, L)
|
48 |
+
lengths = (batch_ids != 0).sum(dim=1) # (B,)
|
49 |
+
max_len = batch_ids.size(1)
|
50 |
+
idxs = torch.arange(max_len, device=batch_ids.device).unsqueeze(0) # (1, L)
|
51 |
+
return (idxs < lengths.unsqueeze(1)).float() # (B, L)
|
52 |
+
|
53 |
+
# -----------------------------------------------------------------------------
|
54 |
+
# MODEL COMPONENTS
|
55 |
+
# -----------------------------------------------------------------------------
|
56 |
+
|
57 |
+
class BertLabeler(nn.Module):
|
58 |
+
"""BERT backbone + 14 small classification heads (CheXbert)."""
|
59 |
+
|
60 |
+
def __init__(self, *, device: Union[str, torch.device]):
|
61 |
+
super().__init__()
|
62 |
+
|
63 |
+
if isinstance(device, str):
|
64 |
+
self.device = torch.device(device)
|
65 |
+
else:
|
66 |
+
self.device = device
|
67 |
+
|
68 |
+
# 1) Backbone on *CPU* first – we'll move to correct device after weights load
|
69 |
+
config = AutoConfig.from_pretrained("bert-base-uncased")
|
70 |
+
self.bert = BertModel(config)
|
71 |
+
|
72 |
+
hidden = self.bert.config.hidden_size
|
73 |
+
# 13 heads with 4‑way logits, + 1 head with 2‑way logits
|
74 |
+
self.linear_heads = nn.ModuleList([nn.Linear(hidden, 4) for _ in range(13)])
|
75 |
+
self.linear_heads.append(nn.Linear(hidden, 2))
|
76 |
+
|
77 |
+
self.dropout = nn.Dropout(0.1)
|
78 |
+
|
79 |
+
# 2) Load checkpoint weights directly onto CPU first -------------------
|
80 |
+
ckpt_path = hf_hub_download(
|
81 |
+
repo_id="StanfordAIMI/RRG_scorers",
|
82 |
+
filename="chexbert.pth",
|
83 |
+
cache_dir=CACHE_DIR,
|
84 |
+
)
|
85 |
+
state = torch.load(ckpt_path, map_location="cpu")["model_state_dict"]
|
86 |
+
state = {k.replace("module.", ""): v for k, v in state.items()}
|
87 |
+
self.load_state_dict(state, strict=True)
|
88 |
+
|
89 |
+
# 3) NOW move the entire module (recursively) to `self.device` ----------
|
90 |
+
self.to(self.device)
|
91 |
+
|
92 |
+
# freeze ---------------------------------------------------------------
|
93 |
+
for p in self.parameters():
|
94 |
+
p.requires_grad = False
|
95 |
+
|
96 |
+
# ---------------------------------------------------------------------
|
97 |
+
# forward helpers
|
98 |
+
# ---------------------------------------------------------------------
|
99 |
+
|
100 |
+
@torch.no_grad()
|
101 |
+
def cls_logits(self, input_ids: torch.LongTensor) -> List[torch.Tensor]:
|
102 |
+
"""Returns a list of logits for each head (no softmax)."""
|
103 |
+
attn = _generate_attention_masks(input_ids)
|
104 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attn)
|
105 |
+
cls_repr = self.dropout(outputs.last_hidden_state[:, 0])
|
106 |
+
return [head(cls_repr) for head in self.linear_heads]
|
107 |
+
|
108 |
+
@torch.no_grad()
|
109 |
+
def cls_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
|
110 |
+
"""Returns pooled [CLS] representations (B, hidden_size)."""
|
111 |
+
attn = _generate_attention_masks(input_ids)
|
112 |
+
outputs = self.bert(input_ids=input_ids, attention_mask=attn)
|
113 |
+
return outputs.last_hidden_state[:, 0] # (B, hidden)
|
114 |
+
|
115 |
+
# -----------------------------------------------------------------------------
|
116 |
+
# F1‑CheXbert evaluator
|
117 |
+
# -----------------------------------------------------------------------------
|
118 |
+
|
119 |
+
class F1CheXbert(nn.Module):
|
120 |
+
"""Generate CheXbert labels + handy evaluation utilities."""
|
121 |
+
|
122 |
+
CONDITION_NAMES = [
|
123 |
+
"Enlarged Cardiomediastinum",
|
124 |
+
"Cardiomegaly",
|
125 |
+
"Lung Opacity",
|
126 |
+
"Lung Lesion",
|
127 |
+
"Edema",
|
128 |
+
"Consolidation",
|
129 |
+
"Pneumonia",
|
130 |
+
"Atelectasis",
|
131 |
+
"Pneumothorax",
|
132 |
+
"Pleural Effusion",
|
133 |
+
"Pleural Other",
|
134 |
+
"Fracture",
|
135 |
+
"Support Devices",
|
136 |
+
]
|
137 |
+
NO_FINDING = "No Finding"
|
138 |
+
TARGET_NAMES = CONDITION_NAMES + [NO_FINDING]
|
139 |
+
|
140 |
+
TOP5 = [
|
141 |
+
"Cardiomegaly",
|
142 |
+
"Edema",
|
143 |
+
"Consolidation",
|
144 |
+
"Atelectasis",
|
145 |
+
"Pleural Effusion",
|
146 |
+
]
|
147 |
+
|
148 |
+
def __init__(
|
149 |
+
self,
|
150 |
+
*,
|
151 |
+
refs_filename: str | None = None,
|
152 |
+
hyps_filename: str | None = None,
|
153 |
+
device: Union[str, torch.device] = "cpu",
|
154 |
+
):
|
155 |
+
super().__init__()
|
156 |
+
|
157 |
+
# Resolve device -------------------------------------------------------
|
158 |
+
if isinstance(device, str):
|
159 |
+
self.device = torch.device(device)
|
160 |
+
else:
|
161 |
+
self.device = device
|
162 |
+
|
163 |
+
self.refs_filename = refs_filename
|
164 |
+
self.hyps_filename = hyps_filename
|
165 |
+
|
166 |
+
# HuggingFace tokenizer (always CPU, we just move tensors later) -------
|
167 |
+
self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
168 |
+
|
169 |
+
# backbone + heads ------------------------------------------------------
|
170 |
+
self.model = BertLabeler(device=self.device).eval()
|
171 |
+
|
172 |
+
# indices for the TOP‑5 label subset -----------------------------------
|
173 |
+
self.top5_idx = [self.TARGET_NAMES.index(n) for n in self.TOP5]
|
174 |
+
|
175 |
+
# ---------------------------------------------------------------------
|
176 |
+
# Public helpers
|
177 |
+
# ---------------------------------------------------------------------
|
178 |
+
|
179 |
+
@torch.no_grad()
|
180 |
+
def get_embeddings(self, reports: Sequence[str]) -> List[np.ndarray]:
|
181 |
+
"""Return list[np.ndarray] of pooled [CLS] vectors for each report."""
|
182 |
+
# Tokenise *as a batch* for efficiency
|
183 |
+
encoding = self.tokenizer(
|
184 |
+
reports,
|
185 |
+
padding=True,
|
186 |
+
truncation=True,
|
187 |
+
max_length=512,
|
188 |
+
return_tensors="pt",
|
189 |
+
)
|
190 |
+
input_ids = encoding.input_ids.to(self.device)
|
191 |
+
# (B, hidden)
|
192 |
+
cls = self.model.cls_embeddings(input_ids)
|
193 |
+
return [v.cpu().numpy() for v in cls]
|
194 |
+
|
195 |
+
@torch.no_grad()
|
196 |
+
def get_label(self, report: str, mode: str = "rrg") -> List[int]:
|
197 |
+
"""Return 14‑dim binary vector for the given report."""
|
198 |
+
input_ids = self.tokenizer(report, truncation=True, max_length=512, return_tensors="pt").input_ids.to(self.device)
|
199 |
+
preds = [head.argmax(dim=1).item() for head in self.model.cls_logits(input_ids)]
|
200 |
+
|
201 |
+
binary = []
|
202 |
+
if mode == "rrg":
|
203 |
+
for c in preds:
|
204 |
+
binary.append(1 if c in {1, 3} else 0)
|
205 |
+
elif mode == "classification":
|
206 |
+
for c in preds:
|
207 |
+
if c == 1:
|
208 |
+
binary.append(1)
|
209 |
+
elif c == 2:
|
210 |
+
binary.append(0)
|
211 |
+
elif c == 3:
|
212 |
+
binary.append(-1)
|
213 |
+
else:
|
214 |
+
binary.append(0)
|
215 |
+
else:
|
216 |
+
raise ValueError(f"Unknown mode: {mode}")
|
217 |
+
return binary
|
218 |
+
|
219 |
+
# ---------------------------------------------------------------------
|
220 |
+
# Full evaluator – unchanged logic but simplified I/O
|
221 |
+
# ---------------------------------------------------------------------
|
222 |
+
|
223 |
+
def forward(self, hyps: List[str], refs: List[str]):
|
224 |
+
"""Return (accuracy, per‑example‑accuracy, full classification reports)."""
|
225 |
+
# Reference labels -----------------------------------------------------
|
226 |
+
if self.refs_filename and os.path.exists(self.refs_filename):
|
227 |
+
with open(self.refs_filename) as f:
|
228 |
+
refs_chexbert = [eval(line) for line in f]
|
229 |
+
else:
|
230 |
+
refs_chexbert = [self.get_label(r) for r in refs]
|
231 |
+
if self.refs_filename:
|
232 |
+
with open(self.refs_filename, "w") as f:
|
233 |
+
f.write("\n".join(map(str, refs_chexbert)))
|
234 |
+
|
235 |
+
# Hypothesis labels ----------------------------------------------------
|
236 |
+
hyps_chexbert = [self.get_label(h) for h in hyps]
|
237 |
+
if self.hyps_filename:
|
238 |
+
with open(self.hyps_filename, "w") as f:
|
239 |
+
f.write("\n".join(map(str, hyps_chexbert)))
|
240 |
+
|
241 |
+
# TOP‑5 subset arrays --------------------------------------------------
|
242 |
+
refs5 = [np.array(r)[self.top5_idx] for r in refs_chexbert]
|
243 |
+
hyps5 = [np.array(h)[self.top5_idx] for h in hyps_chexbert]
|
244 |
+
|
245 |
+
# overall accuracy -----------------------------------------------------
|
246 |
+
accuracy = accuracy_score(refs5, hyps5)
|
247 |
+
_, y_true, y_pred = _check_targets(refs5, hyps5)
|
248 |
+
pe_accuracy = (count_nonzero(y_true - y_pred, axis=1) == 0).astype(float)
|
249 |
+
|
250 |
+
# full classification reports -----------------------------------------
|
251 |
+
cr = classification_report(refs_chexbert, hyps_chexbert, target_names=self.TARGET_NAMES, output_dict=True)
|
252 |
+
cr5 = classification_report(refs5, hyps5, target_names=self.TOP5, output_dict=True)
|
253 |
+
|
254 |
+
return accuracy, pe_accuracy, cr, cr5
|
factual/f1temporal.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Temporal Entity F1
|
2 |
+
# Adopted from https://github.com/X-iZhang/Libra/blob/main/libra/eval/temporal_f1.py
|
3 |
+
|
4 |
+
import re
|
5 |
+
import stanza
|
6 |
+
import argparse
|
7 |
+
from typing import List, Union
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
# Initialize the pipeline with the radiology NER model explicitly specified
|
12 |
+
nlp = stanza.Pipeline(
|
13 |
+
lang='en',
|
14 |
+
package='radiology',
|
15 |
+
processors={'tokenize': 'default', 'ner': 'radiology'},
|
16 |
+
logging_level='ERROR', # Only output warnings or more severe messages
|
17 |
+
verbose=False # Suppress additional information during pipeline initialization
|
18 |
+
)
|
19 |
+
|
20 |
+
# Keywords used for radiology-related entity extraction
|
21 |
+
# Reference: Learning to Exploit Temporal Structure for Biomedical Vision-Language Processing (CVPR2023)
|
22 |
+
# https://arxiv.org/pdf/2301.04558
|
23 |
+
|
24 |
+
KEYWORDS = {
|
25 |
+
"bigger", "change", "cleared", "constant", "decrease", "decreased", "decreasing", "elevated", "elevation",
|
26 |
+
"enlarged", "enlargement", "enlarging", "expanded", "greater", "growing", "improved", "improvement",
|
27 |
+
"improving", "increase", "increased", "increasing", "larger", "new", "persistence", "persistent",
|
28 |
+
"persisting", "progression", "progressive", "reduced", "removal", "resolution", "resolved", "resolving",
|
29 |
+
"smaller", "stability", "stable", "stably", "unchanged", "unfolded", "worse", "worsen", "worsened",
|
30 |
+
"worsening", "unaltered"
|
31 |
+
}
|
32 |
+
|
33 |
+
def clean_text(text: str) -> str:
|
34 |
+
"""
|
35 |
+
Clean the input text by removing special characters and redundant spaces or newlines.
|
36 |
+
|
37 |
+
Args:
|
38 |
+
text (str): Input text.
|
39 |
+
|
40 |
+
Returns:
|
41 |
+
str: Cleaned text.
|
42 |
+
"""
|
43 |
+
# Remove special characters and redundant newlines
|
44 |
+
text = re.sub(r'\n+', ' ', text) # Replace multiple newlines with a single space
|
45 |
+
text = re.sub(r'[_-]+', ' ', text) # Replace underscores and dashes with spaces
|
46 |
+
text = re.sub(r'\(___, __, __\)', '', text) # Remove irrelevant underscore patterns
|
47 |
+
text = re.sub(r'---, ---, ---', '', text) # Remove dashed patterns
|
48 |
+
text = re.sub(r'\(__, __, ___\)', '', text) # Remove similar underscore patterns
|
49 |
+
text = re.sub(r'[_-]+', ' ', text) # Replace underscores and dashes again (if any remain)
|
50 |
+
text = re.sub(r'[^\w\s.,:;()-]', '', text) # Remove non-alphanumeric characters except common punctuation
|
51 |
+
|
52 |
+
# Remove extra spaces
|
53 |
+
text = re.sub(r'\s{2,}', ' ', text).strip()
|
54 |
+
return text
|
55 |
+
|
56 |
+
def extract_entities(text: str, keywords: set) -> set:
|
57 |
+
"""
|
58 |
+
Extract entities from the given text based on Stanza NER and provided keywords.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
text (str): Input text.
|
62 |
+
keywords (set): Set of keywords to extract entities.
|
63 |
+
|
64 |
+
Returns:
|
65 |
+
set: Set of matched entities found in the text.
|
66 |
+
"""
|
67 |
+
# Use Stanza NER to extract entities tagged as "OBSERVATION" or "OBSERVATION_MODIFIER"
|
68 |
+
doc = nlp(text)
|
69 |
+
stanza_entities = {ent.text.lower() for ent in doc.entities if ent.type in {"OBSERVATION", "OBSERVATION_MODIFIER"}}
|
70 |
+
|
71 |
+
# Filter Stanza entities to include only those present in keywords
|
72 |
+
matched_stanza_entities = {entity for entity in stanza_entities if entity in keywords}
|
73 |
+
|
74 |
+
# Clean the text before extracting entities
|
75 |
+
text = clean_text(text)
|
76 |
+
|
77 |
+
# Create a regex pattern that matches any of the keywords as whole words
|
78 |
+
pattern = r'\b(' + '|'.join(re.escape(word) for word in keywords) + r')\b'
|
79 |
+
|
80 |
+
# Find all matches using regex
|
81 |
+
keyword_matches = {match.group().lower() for match in re.finditer(pattern, text.lower())}
|
82 |
+
|
83 |
+
# Combine Stanza entities and regex matches
|
84 |
+
return matched_stanza_entities | keyword_matches
|
85 |
+
|
86 |
+
def calculate_tem_score(prediction_text: str, reference_text: Union[str, List[str]], epsilon: float = 1e-10) -> float:
|
87 |
+
"""
|
88 |
+
Calculate the Temporal Entity Matching (TEM) score (similar to F1-score).
|
89 |
+
|
90 |
+
Args:
|
91 |
+
reference_text (Union[str, List[str]]): Reference text or a list of reference texts.
|
92 |
+
prediction_text (str): Prediction text.
|
93 |
+
epsilon (float): Small value to avoid division by zero.
|
94 |
+
|
95 |
+
Returns:
|
96 |
+
float: TEM score.
|
97 |
+
"""
|
98 |
+
if isinstance(reference_text, list):
|
99 |
+
reference_entities = set()
|
100 |
+
for ref in reference_text:
|
101 |
+
reference_entities.update(extract_entities(ref, KEYWORDS))
|
102 |
+
else:
|
103 |
+
reference_entities = extract_entities(reference_text, KEYWORDS)
|
104 |
+
|
105 |
+
prediction_entities = extract_entities(prediction_text, KEYWORDS)
|
106 |
+
|
107 |
+
if len(reference_entities) == 0:
|
108 |
+
if len(prediction_entities) == 0:
|
109 |
+
return {
|
110 |
+
"f1": 1.0,
|
111 |
+
"prediction_entities": prediction_entities,
|
112 |
+
"reference_entities": reference_entities
|
113 |
+
} # Perfect match when both are empty
|
114 |
+
else:
|
115 |
+
return {
|
116 |
+
"f1": epsilon,
|
117 |
+
"prediction_entities": prediction_entities,
|
118 |
+
"reference_entities": reference_entities
|
119 |
+
} # Minimal score when reference is empty but prediction is not
|
120 |
+
|
121 |
+
# Calculate intersection of entities
|
122 |
+
true_positives = len(prediction_entities & reference_entities)
|
123 |
+
|
124 |
+
# Calculate precision and recall with epsilon to avoid division by zero
|
125 |
+
precision = (true_positives + epsilon) / (len(prediction_entities) + epsilon)
|
126 |
+
recall = (true_positives + epsilon) / (len(reference_entities) + epsilon)
|
127 |
+
|
128 |
+
# Calculate TEM score (F1 score)
|
129 |
+
tem_score = (2 * precision * recall) / (precision + recall + epsilon)
|
130 |
+
|
131 |
+
return {
|
132 |
+
"f1": tem_score,
|
133 |
+
"prediction_entities": prediction_entities,
|
134 |
+
"reference_entities": reference_entities
|
135 |
+
}
|
136 |
+
|
137 |
+
def F1Temporal(predictions: List[str], references: List[Union[str, List[str]]], epsilon: float = 1e-10) -> dict:
|
138 |
+
"""
|
139 |
+
Calculate the average TEM score over a list of reference and prediction texts.
|
140 |
+
|
141 |
+
Args:
|
142 |
+
references (List[Union[str, List[str]]]): List of reference texts or lists of reference texts.
|
143 |
+
predictions (List[str]): List of prediction texts.
|
144 |
+
epsilon (float): Small value to avoid division by zero.
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
float: Average TEM score.
|
148 |
+
"""
|
149 |
+
assert len(references) == len(predictions), "Reference and prediction lists must have the same length."
|
150 |
+
|
151 |
+
tem_scores = []
|
152 |
+
prediction_entities = []
|
153 |
+
reference_entities = []
|
154 |
+
|
155 |
+
for pred, ref in zip(predictions, references):
|
156 |
+
result = calculate_tem_score(pred, ref, epsilon)
|
157 |
+
tem_scores.append(result["f1"])
|
158 |
+
prediction_entities.append(result["prediction_entities"])
|
159 |
+
reference_entities.append(result["reference_entities"])
|
160 |
+
|
161 |
+
average_f1 = sum(tem_scores) / len(tem_scores)
|
162 |
+
|
163 |
+
return {
|
164 |
+
"f1": average_f1,
|
165 |
+
"prediction_entities": prediction_entities,
|
166 |
+
"reference_entities": reference_entities
|
167 |
+
}
|
factual/green_score/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .green import GREEN
|
factual/green_score/green.py
ADDED
@@ -0,0 +1,465 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import torch
|
3 |
+
import torch.distributed as dist
|
4 |
+
import pandas as pd
|
5 |
+
from datasets import Dataset
|
6 |
+
from datasets.distributed import split_dataset_by_node
|
7 |
+
import os
|
8 |
+
from tqdm import tqdm
|
9 |
+
import numpy as np
|
10 |
+
import time
|
11 |
+
import sys
|
12 |
+
import warnings
|
13 |
+
import torch.nn as nn
|
14 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
15 |
+
from transformers.utils import logging
|
16 |
+
|
17 |
+
# Import necessary functions (ensure these are available in your environment)
|
18 |
+
from factual.green_score.utils import (
|
19 |
+
gather_processes,
|
20 |
+
make_prompt,
|
21 |
+
clean_responses,
|
22 |
+
compute_largest_cluster,
|
23 |
+
flatten_values_lists_of_list_dicts_to_dict,
|
24 |
+
)
|
25 |
+
|
26 |
+
# Set the logging level for the transformers library to ERROR to suppress benign warnings
|
27 |
+
logging.get_logger("transformers").setLevel(logging.ERROR)
|
28 |
+
|
29 |
+
def get_rank():
|
30 |
+
if not dist.is_initialized():
|
31 |
+
return 0
|
32 |
+
return dist.get_rank()
|
33 |
+
|
34 |
+
|
35 |
+
def is_main_process():
|
36 |
+
return get_rank() == 0
|
37 |
+
|
38 |
+
|
39 |
+
def tqdm_on_main(*args, **kwargs):
|
40 |
+
if is_main_process():
|
41 |
+
print("==== Beginning Inference ====")
|
42 |
+
return tqdm(*args, **kwargs)
|
43 |
+
else:
|
44 |
+
return kwargs.get("iterable", None)
|
45 |
+
|
46 |
+
|
47 |
+
class GREEN:
|
48 |
+
def __init__(self, model_name, output_dir=".", cpu=False):
|
49 |
+
super().__init__()
|
50 |
+
warnings.filterwarnings(
|
51 |
+
"ignore", message="A decoder-only architecture is being used*"
|
52 |
+
)
|
53 |
+
from sklearn.exceptions import ConvergenceWarning
|
54 |
+
|
55 |
+
warnings.filterwarnings(
|
56 |
+
"ignore",
|
57 |
+
category=ConvergenceWarning,
|
58 |
+
message="Number of distinct clusters.*",
|
59 |
+
)
|
60 |
+
warnings.filterwarnings(
|
61 |
+
"ignore",
|
62 |
+
category=FutureWarning,
|
63 |
+
module="transformers.tokenization_utils_base",
|
64 |
+
)
|
65 |
+
self.cpu = cpu
|
66 |
+
self.model_name = model_name.split("/")[-1]
|
67 |
+
self.output_dir = output_dir
|
68 |
+
self.batch_size = 4
|
69 |
+
self.max_length = 2048
|
70 |
+
self.categories = [
|
71 |
+
"Clinically Significant Errors",
|
72 |
+
"Clinically Insignificant Errors",
|
73 |
+
"Matched Findings",
|
74 |
+
]
|
75 |
+
self.sub_categories = [
|
76 |
+
"(a) False report of a finding in the candidate",
|
77 |
+
"(b) Missing a finding present in the reference",
|
78 |
+
"(c) Misidentification of a finding's anatomic location/position",
|
79 |
+
"(d) Misassessment of the severity of a finding",
|
80 |
+
"(e) Mentioning a comparison that isn't in the reference",
|
81 |
+
"(f) Omitting a comparison detailing a change from a prior study",
|
82 |
+
]
|
83 |
+
self.prompts = None
|
84 |
+
self.completions = None
|
85 |
+
self.green_scores = None
|
86 |
+
self.error_counts = None
|
87 |
+
|
88 |
+
if torch.cuda.is_available() and torch.cuda.device_count() > 1 and not self.cpu:
|
89 |
+
if not dist.is_initialized():
|
90 |
+
dist.init_process_group(
|
91 |
+
backend="nccl",
|
92 |
+
)
|
93 |
+
torch.cuda.set_device(dist.get_rank())
|
94 |
+
if dist.get_rank() == 0:
|
95 |
+
print(
|
96 |
+
"Distributed training with", torch.cuda.device_count(), "GPUs"
|
97 |
+
)
|
98 |
+
|
99 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
100 |
+
model_name,
|
101 |
+
trust_remote_code=False if "Phi" in model_name else True,
|
102 |
+
device_map=(
|
103 |
+
{"": "cuda:{}".format(torch.cuda.current_device())}
|
104 |
+
if not self.cpu
|
105 |
+
else {"": "cpu"}
|
106 |
+
),
|
107 |
+
torch_dtype=torch.float16,
|
108 |
+
)
|
109 |
+
self.model.eval()
|
110 |
+
|
111 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
112 |
+
model_name,
|
113 |
+
add_eos_token=True,
|
114 |
+
use_fast=True,
|
115 |
+
trust_remote_code=True,
|
116 |
+
padding_side="left",
|
117 |
+
)
|
118 |
+
|
119 |
+
# Set up chat template for chat-style prompts
|
120 |
+
chat_template = (
|
121 |
+
"{% for message in messages %}\n"
|
122 |
+
"{% if message['from'] == 'human' %}\n"
|
123 |
+
"{{ '<|user|>\n' + message['value'] + eos_token }}\n"
|
124 |
+
"{% elif message['from'] == 'system' %}\n"
|
125 |
+
"{{ '<|system|>\n' + message['value'] + eos_token }}\n"
|
126 |
+
"{% elif message['from'] == 'gpt' %}\n"
|
127 |
+
"{{ '<|assistant|>\n' + message['value'] + eos_token }}\n"
|
128 |
+
"{% endif %}\n"
|
129 |
+
"{% if loop.last and add_generation_prompt %}\n"
|
130 |
+
"{{ '<|assistant|>' }}\n"
|
131 |
+
"{% endif %}\n"
|
132 |
+
"{% endfor %}"
|
133 |
+
)
|
134 |
+
|
135 |
+
self.tokenizer.chat_template = chat_template
|
136 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
137 |
+
self.tokenizer.clean_up_tokenization_spaces = True
|
138 |
+
self.tokenizer.padding_side = "left"
|
139 |
+
|
140 |
+
def __call__(self, refs, hyps):
|
141 |
+
print("Processing data...making prompts")
|
142 |
+
dataset = Dataset.from_dict({"reference": refs, "prediction": hyps})
|
143 |
+
|
144 |
+
dataset = self.process_data(dataset)
|
145 |
+
print("Done.")
|
146 |
+
|
147 |
+
self.dataset = dataset
|
148 |
+
|
149 |
+
t = time.time()
|
150 |
+
|
151 |
+
mean, std, green_scores, summary, results_df = self.infer()
|
152 |
+
|
153 |
+
t = time.time() - t
|
154 |
+
print("Seconds per example: ", t / len(refs))
|
155 |
+
|
156 |
+
if not is_main_process():
|
157 |
+
print(f"Rank {dist.get_rank()} exiting.")
|
158 |
+
dist.destroy_process_group()
|
159 |
+
sys.exit()
|
160 |
+
|
161 |
+
return mean, std, green_scores, summary, results_df
|
162 |
+
|
163 |
+
def process_data(self, dataset):
|
164 |
+
def prompting(examples):
|
165 |
+
return {
|
166 |
+
"prompt": [
|
167 |
+
make_prompt(r, p)
|
168 |
+
for r, p in zip(examples["reference"], examples["prediction"])
|
169 |
+
]
|
170 |
+
}
|
171 |
+
|
172 |
+
dataset = dataset.map(prompting, batched=True)
|
173 |
+
return dataset
|
174 |
+
|
175 |
+
@torch.inference_mode()
|
176 |
+
def infer(self):
|
177 |
+
if torch.cuda.is_available() and torch.cuda.device_count() > 1 and not self.cpu:
|
178 |
+
dataset_dist = split_dataset_by_node(
|
179 |
+
self.dataset,
|
180 |
+
rank=get_rank(),
|
181 |
+
world_size=int(os.environ["WORLD_SIZE"]),
|
182 |
+
)
|
183 |
+
print("Distributed dataset created on rank: ", int(os.environ["RANK"]))
|
184 |
+
else:
|
185 |
+
dataset_dist = self.dataset
|
186 |
+
|
187 |
+
local_completions = []
|
188 |
+
local_references = []
|
189 |
+
|
190 |
+
for batch in tqdm_on_main(
|
191 |
+
iterable=dataset_dist.iter(batch_size=self.batch_size),
|
192 |
+
total=len(dataset_dist) // self.batch_size,
|
193 |
+
):
|
194 |
+
local_references.extend(batch["prompt"])
|
195 |
+
local_completions.extend(self.get_response(batch))
|
196 |
+
|
197 |
+
if torch.cuda.is_available() and torch.cuda.device_count() > 1 and not self.cpu:
|
198 |
+
self.completions, self.prompts = gather_processes(
|
199 |
+
local_completions, local_references
|
200 |
+
)
|
201 |
+
else:
|
202 |
+
self.completions = local_completions
|
203 |
+
self.prompts = local_references
|
204 |
+
|
205 |
+
if is_main_process():
|
206 |
+
print("==== End Inference ====")
|
207 |
+
|
208 |
+
if len(self.completions) != len(self.prompts):
|
209 |
+
print("Length of prompts and completions are not equal!")
|
210 |
+
|
211 |
+
return self.process_results()
|
212 |
+
|
213 |
+
def tokenize_batch_as_chat(self, batch):
|
214 |
+
local_rank = int(os.environ.get("LOCAL_RANK", 0)) if not self.cpu else "cpu"
|
215 |
+
batch = [
|
216 |
+
self.tokenizer.apply_chat_template(
|
217 |
+
i, tokenize=False, add_generation_prompt=True
|
218 |
+
)
|
219 |
+
for i in batch
|
220 |
+
]
|
221 |
+
|
222 |
+
batch = self.tokenizer.batch_encode_plus(
|
223 |
+
batch,
|
224 |
+
return_tensors="pt",
|
225 |
+
padding=True,
|
226 |
+
truncation=True,
|
227 |
+
max_length=self.max_length,
|
228 |
+
).to(local_rank)
|
229 |
+
|
230 |
+
return batch
|
231 |
+
|
232 |
+
def get_response(self, batch):
|
233 |
+
assert "prompt" in batch.keys(), "prompt is not in batch keys"
|
234 |
+
|
235 |
+
batch = [
|
236 |
+
[{"from": "human", "value": prompt}, {"from": "gpt", "value": ""}]
|
237 |
+
for prompt in batch["prompt"]
|
238 |
+
]
|
239 |
+
|
240 |
+
batch = self.tokenize_batch_as_chat(batch)
|
241 |
+
|
242 |
+
outputs = self.model.generate(
|
243 |
+
input_ids=batch["input_ids"],
|
244 |
+
attention_mask=batch["attention_mask"],
|
245 |
+
eos_token_id=self.tokenizer.eos_token_id,
|
246 |
+
pad_token_id=self.tokenizer.pad_token_id,
|
247 |
+
max_length=2048,
|
248 |
+
do_sample=False,
|
249 |
+
temperature=None,
|
250 |
+
top_p=None,
|
251 |
+
)
|
252 |
+
|
253 |
+
responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
254 |
+
|
255 |
+
response_list = []
|
256 |
+
if isinstance(responses, list):
|
257 |
+
for response in responses:
|
258 |
+
response = clean_responses(response)
|
259 |
+
response_list.append(response)
|
260 |
+
else:
|
261 |
+
responses = clean_responses(responses)
|
262 |
+
response_list.append(responses)
|
263 |
+
|
264 |
+
return response_list
|
265 |
+
|
266 |
+
def process_results(self):
|
267 |
+
self.green_scores = [
|
268 |
+
self.compute_green(response) for response in self.completions
|
269 |
+
]
|
270 |
+
self.error_counts = pd.DataFrame(
|
271 |
+
[self.compute_error_count(response) for response in self.completions],
|
272 |
+
columns=self.sub_categories + ["Matched Findings"],
|
273 |
+
)
|
274 |
+
|
275 |
+
results_df = pd.DataFrame(
|
276 |
+
{
|
277 |
+
"reference": self.dataset["reference"],
|
278 |
+
"predictions": self.dataset["prediction"],
|
279 |
+
"green_analysis": self.completions,
|
280 |
+
"green_score": self.green_scores,
|
281 |
+
**self.error_counts,
|
282 |
+
}
|
283 |
+
)
|
284 |
+
|
285 |
+
mean, std, summary = self.compute_summary()
|
286 |
+
|
287 |
+
return mean, std, self.green_scores, summary, results_df
|
288 |
+
|
289 |
+
def compute_error_count(self, response):
|
290 |
+
_, sig_errors = self.parse_error_counts(response, self.categories[0])
|
291 |
+
matched_findings, _ = self.parse_error_counts(response, self.categories[2])
|
292 |
+
return sig_errors + [matched_findings]
|
293 |
+
|
294 |
+
def compute_green(self, response):
|
295 |
+
sig_present, sig_errors = self.parse_error_counts(response, self.categories[0])
|
296 |
+
matched_findings, _ = self.parse_error_counts(response, self.categories[2])
|
297 |
+
|
298 |
+
if matched_findings == 0:
|
299 |
+
return 0
|
300 |
+
|
301 |
+
if sig_present is None or matched_findings is None:
|
302 |
+
return None
|
303 |
+
|
304 |
+
return matched_findings / (matched_findings + sum(sig_errors))
|
305 |
+
|
306 |
+
def parse_error_counts(self, text, category, for_reward=False):
|
307 |
+
if category not in self.categories:
|
308 |
+
raise ValueError(
|
309 |
+
f"Category {category} is not a valid category. Please choose from {self.categories}."
|
310 |
+
)
|
311 |
+
|
312 |
+
pattern = rf"\[{category}\]:\s*(.*?)(?:\n\s*\n|\Z)"
|
313 |
+
category_text = re.search(pattern, text, re.DOTALL)
|
314 |
+
|
315 |
+
sum_counts = 0
|
316 |
+
sub_counts = [0 for i in range(6)]
|
317 |
+
|
318 |
+
if not category_text:
|
319 |
+
if for_reward:
|
320 |
+
return None, None
|
321 |
+
return sum_counts, sub_counts
|
322 |
+
if category_text.group(1).startswith("No"):
|
323 |
+
return sum_counts, sub_counts
|
324 |
+
|
325 |
+
if category == "Matched Findings":
|
326 |
+
counts = re.findall(r"^\b\d+\b(?=\.)", category_text.group(1))
|
327 |
+
if len(counts) > 0:
|
328 |
+
sum_counts = int(counts[0])
|
329 |
+
return sum_counts, sub_counts
|
330 |
+
else:
|
331 |
+
sub_categories = [s.split(" ", 1)[0] + " " for s in self.sub_categories]
|
332 |
+
matches = sorted(re.findall(r"\([a-f]\) .*", category_text.group(1)))
|
333 |
+
|
334 |
+
if len(matches) == 0:
|
335 |
+
matches = sorted(re.findall(r"\([1-6]\) .*", category_text.group(1)))
|
336 |
+
sub_categories = [
|
337 |
+
f"({i})" + " " for i in range(1, len(self.sub_categories) + 1)
|
338 |
+
]
|
339 |
+
|
340 |
+
for position, sub_category in enumerate(sub_categories):
|
341 |
+
for match in range(len(matches)):
|
342 |
+
if matches[match].startswith(sub_category):
|
343 |
+
count = re.findall(r"(?<=: )\b\d+\b(?=\.)", matches[match])
|
344 |
+
if len(count) > 0:
|
345 |
+
sub_counts[position] = int(count[0])
|
346 |
+
return sum(sub_counts), sub_counts
|
347 |
+
|
348 |
+
def parse_error_sentences(self, response, category):
|
349 |
+
if category not in self.categories:
|
350 |
+
raise ValueError(
|
351 |
+
f"Category {category} is not a valid category. Please choose from {self.categories}."
|
352 |
+
)
|
353 |
+
pattern = rf"\[{category}\]:\s*(.*?)(?:\n\s*\n|\Z)"
|
354 |
+
category_text = re.search(pattern, response, re.DOTALL)
|
355 |
+
sub_category_dict_sentences = {}
|
356 |
+
for sub_category in self.sub_categories:
|
357 |
+
sub_category_dict_sentences[sub_category] = []
|
358 |
+
|
359 |
+
if not category_text:
|
360 |
+
return sub_category_dict_sentences
|
361 |
+
if category_text.group(1).startswith("No"):
|
362 |
+
return sub_category_dict_sentences
|
363 |
+
|
364 |
+
if category == "Matched Findings":
|
365 |
+
return (
|
366 |
+
category_text.group(1).rsplit(":", 1)[-1].rsplit(".", 1)[-1].split(";")
|
367 |
+
)
|
368 |
+
|
369 |
+
matches = sorted(re.findall(r"\([a-f]\) .*", category_text.group(1)))
|
370 |
+
|
371 |
+
if len(matches) == 0:
|
372 |
+
matches = sorted(re.findall(r"\([1-6]\) .*", category_text.group(1)))
|
373 |
+
self.sub_categories = [
|
374 |
+
f"({i})" + " " for i in range(1, len(self.sub_categories) + 1)
|
375 |
+
]
|
376 |
+
|
377 |
+
for position, sub_category in enumerate(self.sub_categories):
|
378 |
+
for match in range(len(matches)):
|
379 |
+
if matches[match].startswith(sub_category):
|
380 |
+
sentences_list = (
|
381 |
+
matches[match].rsplit(":", 1)[-1].split(".", 1)[-1].split(";")
|
382 |
+
)
|
383 |
+
sub_category_dict_sentences[self.sub_categories[position]] = (
|
384 |
+
sentences_list
|
385 |
+
)
|
386 |
+
|
387 |
+
return sub_category_dict_sentences
|
388 |
+
|
389 |
+
def compute_sentences(self, response):
|
390 |
+
return self.parse_error_sentences(response, self.categories[0])
|
391 |
+
|
392 |
+
def get_representative_sentences(self, responses):
|
393 |
+
list_sentences = []
|
394 |
+
for i in responses:
|
395 |
+
sentences = self.compute_sentences(i)
|
396 |
+
list_sentences.append(sentences)
|
397 |
+
|
398 |
+
dict_sentences = flatten_values_lists_of_list_dicts_to_dict(list_sentences)
|
399 |
+
|
400 |
+
result_sentences_dict = {}
|
401 |
+
|
402 |
+
for i in self.sub_categories:
|
403 |
+
sentences = dict_sentences[i]
|
404 |
+
sentences = [i for i in sentences if i.strip() != ""]
|
405 |
+
_, sentences_of_largest_cluster = compute_largest_cluster(sentences)
|
406 |
+
result_sentences_dict[i] = sentences_of_largest_cluster
|
407 |
+
|
408 |
+
return result_sentences_dict
|
409 |
+
|
410 |
+
def compute_accuracy(self, responses):
|
411 |
+
counts = []
|
412 |
+
for response in responses:
|
413 |
+
_, sig_errors = self.parse_error_counts(response, self.categories[0])
|
414 |
+
counts.append(sig_errors)
|
415 |
+
|
416 |
+
counts = np.array(counts)
|
417 |
+
|
418 |
+
dict_acc = {}
|
419 |
+
for i in range(len(self.sub_categories)):
|
420 |
+
error_counts = counts[:, i]
|
421 |
+
accuracy = np.mean(error_counts == 0)
|
422 |
+
dict_acc[self.sub_categories[i]] = accuracy
|
423 |
+
|
424 |
+
return dict_acc
|
425 |
+
|
426 |
+
def compute_summary(self):
|
427 |
+
print("Computing summary ...")
|
428 |
+
representative_sentences = self.get_representative_sentences(self.completions)
|
429 |
+
accuracies = self.compute_accuracy(self.completions)
|
430 |
+
mean = np.mean(self.green_scores)
|
431 |
+
std = np.std(self.green_scores)
|
432 |
+
|
433 |
+
summary = f"\n-------------{self.model_name}----------------\n [Summary]: Green average {mean} and standard deviation {std} \n [Clinically Significant Errors Analyses]: <accuracy>. <representative error>\n\n"
|
434 |
+
for idx, sub_category in enumerate(self.sub_categories):
|
435 |
+
accuracy = accuracies[sub_category]
|
436 |
+
sentences = representative_sentences[sub_category]
|
437 |
+
summary += f"{sub_category}: {accuracy}. \n {sentences} \n\n"
|
438 |
+
summary += "----------------------------------\n"
|
439 |
+
|
440 |
+
return mean, std, summary
|
441 |
+
|
442 |
+
|
443 |
+
if __name__ == "__main__":
|
444 |
+
refs = [
|
445 |
+
"Interstitial opacities without changes.",
|
446 |
+
"Interval development of segmental heterogeneous airspace opacities throughout the lungs . No significant pneumothorax or pleural effusion . Bilateral calcified pleural plaques are scattered throughout the lungs . The heart is not significantly enlarged .",
|
447 |
+
"Lung volumes are low, causing bronchovascular crowding. The cardiomediastinal silhouette is unremarkable. No focal consolidation, pleural effusion, or pneumothorax detected. Within the limitations of chest radiography, osseous structures are unremarkable.",
|
448 |
+
]
|
449 |
+
hyps = [
|
450 |
+
"Interstitial opacities at bases without changes.",
|
451 |
+
"Interval development of segmental heterogeneous airspace opacities throughout the lungs . No significant pneumothorax or pleural effusion . Bilateral calcified pleural plaques are scattered throughout the lungs . The heart is not significantly enlarged .",
|
452 |
+
"Endotracheal and nasogastric tubes have been removed. Changes of median sternotomy, with continued leftward displacement of the fourth inferiomost sternal wire. There is continued moderate-to-severe enlargement of the cardiac silhouette. Pulmonary aeration is slightly improved, with residual left lower lobe atelectasis. Stable central venous congestion and interstitial pulmonary edema. Small bilateral pleural effusions are unchanged.",
|
453 |
+
]
|
454 |
+
|
455 |
+
model_name = "StanfordAIMI/GREEN-radllama2-7b"
|
456 |
+
|
457 |
+
green_scorer = GREEN(model_name, output_dir=".")
|
458 |
+
mean, std, green_score_list, summary, result_df = green_scorer(refs, hyps)
|
459 |
+
print(green_score_list)
|
460 |
+
print(summary)
|
461 |
+
# for index, row in result_df.iterrows():
|
462 |
+
# print(f"Row {index}:\n")
|
463 |
+
# for col_name in result_df.columns:
|
464 |
+
# print(f"{col_name}: {row[col_name]}\n")
|
465 |
+
# print('-' * 80)
|
factual/green_score/utils.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.distributed as dist
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
from sklearn.metrics import silhouette_score
|
6 |
+
from sklearn.cluster import KMeans
|
7 |
+
from sklearn import preprocessing
|
8 |
+
from sentence_transformers import SentenceTransformer
|
9 |
+
from scipy.spatial import distance
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
# A dictionary to store rewards for pairs of reference and hypothesis reports
|
13 |
+
|
14 |
+
|
15 |
+
def compute_largest_cluster(sentences):
|
16 |
+
"""
|
17 |
+
Computes the largest cluster of sentences using K-means clustering, finds the sentences within the largest cluster, and orders them by their distance to the cluster center.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
sentences (list): List of sentences to be clustered.
|
21 |
+
|
22 |
+
Returns:
|
23 |
+
tuple: A tuple containing:
|
24 |
+
- embeddings (ndarray): Normalized embeddings of the input sentences.
|
25 |
+
- sentences_of_largest_cluster (list): Sentences in the largest cluster, ordered by their proximity
|
26 |
+
to the cluster center.
|
27 |
+
"""
|
28 |
+
if len(sentences) == 0:
|
29 |
+
return None, None
|
30 |
+
embeddings, kmeans = compute_kmeans(sentences)
|
31 |
+
cluster_sizes = np.bincount(kmeans.labels_)
|
32 |
+
largest_cluster_idx = np.argmax(cluster_sizes)
|
33 |
+
cluster_member_ids = np.where(kmeans.labels_ == largest_cluster_idx)[0]
|
34 |
+
sentences_of_largest_cluster = [sentences[i] for i in cluster_member_ids]
|
35 |
+
|
36 |
+
largest_cluster_mean = kmeans.cluster_centers_[largest_cluster_idx]
|
37 |
+
embeddings_of_largest_cluster = [embeddings[i] for i in cluster_member_ids]
|
38 |
+
distances = distance.cdist(
|
39 |
+
embeddings_of_largest_cluster, [largest_cluster_mean], "cosine"
|
40 |
+
).flatten()
|
41 |
+
closest_point_indices = np.argsort(distances)[0]
|
42 |
+
|
43 |
+
sentences_of_largest_cluster = sentences_of_largest_cluster[closest_point_indices]
|
44 |
+
|
45 |
+
return embeddings, sentences_of_largest_cluster
|
46 |
+
|
47 |
+
|
48 |
+
def compute_kmeans(sentences):
|
49 |
+
"""
|
50 |
+
Computes K-means clustering for a list of sentences by generating their embeddings, normalizing the embeddings, and determining the optimal number of clusters using binary search.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
sentences (list): List of sentences to be clustered.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
tuple: A tuple containing:
|
57 |
+
- embeddings (ndarray): Normalized embeddings of the input sentences.
|
58 |
+
- kmeans (KMeans): The KMeans object with the optimal number of clusters determined.
|
59 |
+
"""
|
60 |
+
# sentence embeddings
|
61 |
+
model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
|
62 |
+
embeddings = model.encode(sentences)
|
63 |
+
# normalize the embeddings for equivalent computation of the cosine distance
|
64 |
+
embeddings = preprocessing.normalize(embeddings)
|
65 |
+
# compute the number of clusters with binary search
|
66 |
+
kmeans = binary_search_optimal_kmeans(embeddings, min_k=0, max_k=len(sentences))
|
67 |
+
return embeddings, kmeans
|
68 |
+
|
69 |
+
|
70 |
+
def binary_search_optimal_kmeans(data, min_k, max_k):
|
71 |
+
"""
|
72 |
+
Finds the optimal k for KMeans clustering using binary search on the silhouette score.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
data (list): cluster data.
|
76 |
+
min_k: minimum k for binary search
|
77 |
+
max_k: maximum k for binary search
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
list: List of cleaned response strings.
|
81 |
+
"""
|
82 |
+
best_k = min_k
|
83 |
+
best_score = -1
|
84 |
+
best_kmeans = KMeans(n_clusters=1, random_state=42).fit(
|
85 |
+
data
|
86 |
+
) # start with 1 cluster for len(data) < 2
|
87 |
+
|
88 |
+
while min_k <= max_k:
|
89 |
+
mid_k = (min_k + max_k) // 2
|
90 |
+
if mid_k < 2:
|
91 |
+
break
|
92 |
+
|
93 |
+
kmeans = KMeans(n_clusters=mid_k, random_state=42).fit(data)
|
94 |
+
labels = kmeans.labels_
|
95 |
+
score = silhouette_score(data, labels)
|
96 |
+
|
97 |
+
if score > best_score:
|
98 |
+
best_score = score
|
99 |
+
best_k = mid_k
|
100 |
+
best_kmeans = kmeans # Update the best KMeans model
|
101 |
+
min_k = mid_k + 1
|
102 |
+
else:
|
103 |
+
max_k = mid_k - 1
|
104 |
+
|
105 |
+
return best_kmeans
|
106 |
+
|
107 |
+
|
108 |
+
def flatten_values_lists_of_list_dicts_to_dict(item):
|
109 |
+
"""
|
110 |
+
Flattens a list of dictionaries containing lists of values into a single dictionary.
|
111 |
+
|
112 |
+
Args:
|
113 |
+
item (list): List of dictionaries, where each dictionary's values are lists. If any element of the list is itself a list, the function will consider only the first dictionary in that sublist.
|
114 |
+
|
115 |
+
Returns:
|
116 |
+
dict: A dictionary where each key corresponds to the keys in the input dictionaries, and each value is a flattened list of all values associated with that key across all input dictionaries.
|
117 |
+
"""
|
118 |
+
|
119 |
+
result = {}
|
120 |
+
for i in item:
|
121 |
+
if isinstance(i, list):
|
122 |
+
i = i[0]
|
123 |
+
for key, lists in i.items():
|
124 |
+
if key not in result:
|
125 |
+
result[key] = []
|
126 |
+
result[key].extend(lists)
|
127 |
+
|
128 |
+
return result
|
129 |
+
|
130 |
+
|
131 |
+
def gather_processes(local_candidates, local_references=None):
|
132 |
+
world_size = int(os.environ.get("WORLD_SIZE", "1"))
|
133 |
+
local_rank = int(os.environ.get("RANK", "0"))
|
134 |
+
global_candidates_list = None
|
135 |
+
global_references_list = None
|
136 |
+
|
137 |
+
if local_rank == 0:
|
138 |
+
# Initialize the gather list only on the root process
|
139 |
+
global_candidates_list = [None for _ in range(world_size)]
|
140 |
+
global_references_list = [None for _ in range(world_size)]
|
141 |
+
try:
|
142 |
+
dist.gather_object(local_candidates, global_candidates_list, dst=0)
|
143 |
+
|
144 |
+
if not local_references is None:
|
145 |
+
dist.gather_object(local_references, global_references_list, dst=0)
|
146 |
+
|
147 |
+
except Exception as e:
|
148 |
+
print(f"Error during result gathering: {e}")
|
149 |
+
|
150 |
+
if local_rank != 0:
|
151 |
+
# Exit the process
|
152 |
+
# print(f"Rank {dist.get_rank()} exiting.")
|
153 |
+
dist.destroy_process_group() # Clean up the distributed processing group
|
154 |
+
sys.exit() # Exit the process
|
155 |
+
|
156 |
+
# Flatten the gathered list
|
157 |
+
candidates_list = []
|
158 |
+
for i in global_candidates_list:
|
159 |
+
candidates_list.extend(i)
|
160 |
+
|
161 |
+
if not global_references_list[0] is None:
|
162 |
+
references_list = []
|
163 |
+
for i in global_references_list:
|
164 |
+
references_list.extend(i)
|
165 |
+
print(f"References list: {len(references_list)}")
|
166 |
+
return candidates_list, references_list
|
167 |
+
|
168 |
+
return candidates_list
|
169 |
+
|
170 |
+
|
171 |
+
def clean_responses(response):
|
172 |
+
if "[Explanation]:" in response:
|
173 |
+
if "<|assistant|>" in response:
|
174 |
+
response = response.split("<|assistant|>")[-1]
|
175 |
+
if (
|
176 |
+
"[Explanation]:\n <Explanation>\n" or "[Explanation]:\n<Explanation>"
|
177 |
+
) in response:
|
178 |
+
response = response.split("[Explanation]:")[1]
|
179 |
+
else:
|
180 |
+
response = response.split("[Explanation]:")[-1]
|
181 |
+
if "<|assistant|>" in response:
|
182 |
+
response = response.split("<|assistant|>")[-1]
|
183 |
+
return response.replace("</s>", "").replace("<unk>", "")
|
184 |
+
|
185 |
+
|
186 |
+
def make_prompt(text1, text2, max_len=300):
|
187 |
+
"""
|
188 |
+
Creates a prompt for evaluating the accuracy of a candidate radiology report in comparison to a reference radiology report.
|
189 |
+
|
190 |
+
Args:
|
191 |
+
text1 (str): Reference radiology report.
|
192 |
+
text2 (str): Candidate radiology report.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
str: Formatted prompt string.
|
196 |
+
"""
|
197 |
+
text1 = " ".join(text1.split()[:max_len])
|
198 |
+
text2 = " ".join(text2.split()[:max_len])
|
199 |
+
prompt = f"Objective: Evaluate the accuracy of a candidate radiology report in comparison to a reference radiology report composed by expert radiologists.\n\n Process Overview: You will be presented with:\n\n 1. The criteria for making a judgment.\n 2. The reference radiology report.\n 3. The candidate radiology report.\n 4. The desired format for your assessment.\n\n 1. Criteria for Judgment:\n\n For each candidate report, determine:\n\n The count of clinically significant errors.\n The count of clinically insignificant errors.\n\n Errors can fall into one of these categories:\n\n a) False report of a finding in the candidate.\n b) Missing a finding present in the reference.\n c) Misidentification of a finding's anatomic location/position.\n d) Misassessment of the severity of a finding.\n e) Mentioning a comparison that isn't in the reference.\n f) Omitting a comparison detailing a change from a prior study.\n Note: Concentrate on the clinical findings rather than the report's writing style. Evaluate only the findings that appear in both reports.\n\n 2. Reference Report:\n {text1}\n\n 3. Candidate Report:\n {text2}\n\n 4. Reporting Your Assessment:\n\n Follow this specific format for your output, even if no errors are found:\n ```\n [Explanation]:\n <Explanation>\n\n [Clinically Significant Errors]:\n (a) <Error Type>: <The number of errors>. <Error 1>; <Error 2>; ...; <Error n>\n ....\n (f) <Error Type>: <The number of errors>. <Error 1>; <Error 2>; ...; <Error n>\n\n [Clinically Insignificant Errors]:\n (a) <Error Type>: <The number of errors>. <Error 1>; <Error 2>; ...; <Error n>\n ....\n (f) <Error Type>: <The number of errors>. <Error 1>; <Error 2>; ...; <Error n>\n\n [Matched Findings]:\n <The number of matched findings>. <Finding 1>; <Finding 2>; ...; <Finding n>\n ```\n"
|
200 |
+
return prompt
|
nlg/__init__.py
ADDED
File without changes
|
nlg/bertscore/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__author__ = 'tylin'
|
nlg/bertscore/bertscore.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from bert_score import BERTScorer
|
4 |
+
|
5 |
+
|
6 |
+
class BertScore(nn.Module):
|
7 |
+
def __init__(self,
|
8 |
+
model_type='distilbert-base-uncased',
|
9 |
+
num_layers=5,
|
10 |
+
rescale_with_baseline=True,
|
11 |
+
idf=False,
|
12 |
+
):
|
13 |
+
super(BertScore, self).__init__()
|
14 |
+
with torch.no_grad():
|
15 |
+
self.bert_scorer = BERTScorer(model_type=model_type,
|
16 |
+
num_layers=num_layers,
|
17 |
+
batch_size=64,
|
18 |
+
nthreads=4,
|
19 |
+
all_layers=False,
|
20 |
+
idf=idf,
|
21 |
+
device=None,
|
22 |
+
lang='en',
|
23 |
+
rescale_with_baseline=rescale_with_baseline,
|
24 |
+
baseline_path=None)
|
25 |
+
|
26 |
+
def forward(self, refs, hyps):
|
27 |
+
p, r, f = self.bert_scorer.score(
|
28 |
+
cands=hyps,
|
29 |
+
refs=refs,
|
30 |
+
verbose=False,
|
31 |
+
batch_size=64,
|
32 |
+
)
|
33 |
+
return torch.mean(f).item(), f.tolist()
|
34 |
+
|
35 |
+
|
36 |
+
if __name__ == '__main__':
|
37 |
+
x, y = (BertScore()(
|
38 |
+
hyps=[
|
39 |
+
"nothing to do lol",
|
40 |
+
"nothing to do x",
|
41 |
+
'there are moderate bilateral pleural effusions with overlying atelectasis, underlying consolidation not excluded. mild prominence of the interstitial markings suggests mild pulmonary edema. the cardiac silhouette is mildly enlarged. the mediastinal contours are unremarkable. there is no evidence of pneumothorax.'
|
42 |
+
],
|
43 |
+
refs=[
|
44 |
+
'heart size is moderately enlarged. the mediastinal and hilar contours are unchanged. there is no pulmonary edema. small left pleural effusion is present. patchy opacities in the lung bases likely reflect atelectasis. no pneumothorax is seen. there are no acute osseous abnormalities.',
|
45 |
+
'heart size is mildly enlarged. the mediastinal and hilar contours are normal. there is mild pulmonary edema. moderate bilateral pleural effusions are present, left greater than right. bibasilar airspace opacities likely reflect atelectasis. no pneumothorax is seen. there are no acute osseous abnormalities.',
|
46 |
+
'heart size is mildly enlarged. the mediastinal and hilar contours are normal. there is mild pulmonary edema. moderate bilateral pleural effusions are present, left greater than right. bibasilar airspace opacities likely reflect atelectasis. no pneumothorax is seen. there are no acute osseous abnormalities.'
|
47 |
+
])
|
48 |
+
)
|
49 |
+
print(x)
|
50 |
+
print(y)
|
nlg/bleu/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
__author__ = 'tylin'
|
nlg/bleu/bleu.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
#
|
3 |
+
# File Name : bleu.py
|
4 |
+
#
|
5 |
+
# Description : Wrapper for BLEU scorer.
|
6 |
+
#
|
7 |
+
# Creation Date : 06-01-2015
|
8 |
+
# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
|
9 |
+
# Authors : Hao Fang <[email protected]> and Tsung-Yi Lin <[email protected]>
|
10 |
+
|
11 |
+
import torch.nn as nn
|
12 |
+
from .bleu_scorer import BleuScorer
|
13 |
+
|
14 |
+
|
15 |
+
class Bleu(nn.Module):
|
16 |
+
def __init__(self, n=4, **kwargs):
|
17 |
+
# default compute Blue score up to 4
|
18 |
+
super().__init__()
|
19 |
+
self._n = n
|
20 |
+
|
21 |
+
def forward(self, gts, res):
|
22 |
+
return self.compute_score(gts, res)
|
23 |
+
|
24 |
+
def compute_score(self, gts, res):
|
25 |
+
res = {i: [v] for i, v in enumerate(res)}
|
26 |
+
gts = {i: [v] for i, v in enumerate(gts)}
|
27 |
+
bleu_scorer = BleuScorer(n=self._n)
|
28 |
+
|
29 |
+
for id in sorted(gts.keys()):
|
30 |
+
hypo = res[id]
|
31 |
+
ref = gts[id]
|
32 |
+
|
33 |
+
# Sanity check.
|
34 |
+
assert (type(hypo) is list)
|
35 |
+
assert (len(hypo) == 1)
|
36 |
+
assert (type(ref) is list)
|
37 |
+
assert (len(ref) >= 1)
|
38 |
+
|
39 |
+
bleu_scorer += (hypo[0], ref)
|
40 |
+
|
41 |
+
# score, scores = bleu_scorer.compute_score(option='shortest')
|
42 |
+
score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
|
43 |
+
# score, scores = bleu_scorer.compute_score(option='average', verbose=1)
|
44 |
+
|
45 |
+
# return (bleu, bleu_info)
|
46 |
+
return score[self._n-1], scores[self._n-1]
|
47 |
+
|
48 |
+
def method(self):
|
49 |
+
return "Bleu"
|
nlg/bleu/bleu_scorer.py
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# bleu_scorer.py
|
4 |
+
# David Chiang <[email protected]>
|
5 |
+
|
6 |
+
# Copyright (c) 2004-2006 University of Maryland. All rights
|
7 |
+
# reserved. Do not redistribute without permission from the
|
8 |
+
# author. Not for commercial use.
|
9 |
+
|
10 |
+
# Modified by:
|
11 |
+
# Hao Fang <[email protected]>
|
12 |
+
# Tsung-Yi Lin <[email protected]>
|
13 |
+
|
14 |
+
'''Provides:
|
15 |
+
cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
|
16 |
+
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
|
17 |
+
'''
|
18 |
+
|
19 |
+
import copy
|
20 |
+
import sys, math, re
|
21 |
+
from collections import defaultdict
|
22 |
+
|
23 |
+
import six
|
24 |
+
from six.moves import xrange as range
|
25 |
+
|
26 |
+
|
27 |
+
def precook(s, n=4, out=False):
|
28 |
+
"""Takes a string as input and returns an object that can be given to
|
29 |
+
either cook_refs or cook_test. This is optional: cook_refs and cook_test
|
30 |
+
can take string arguments as well."""
|
31 |
+
words = s.split()
|
32 |
+
counts = defaultdict(int)
|
33 |
+
for k in range(1,n+1):
|
34 |
+
for i in range(len(words)-k+1):
|
35 |
+
ngram = tuple(words[i:i+k])
|
36 |
+
counts[ngram] += 1
|
37 |
+
return (len(words), counts)
|
38 |
+
|
39 |
+
def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
|
40 |
+
'''Takes a list of reference sentences for a single segment
|
41 |
+
and returns an object that encapsulates everything that BLEU
|
42 |
+
needs to know about them.'''
|
43 |
+
|
44 |
+
reflen = []
|
45 |
+
maxcounts = {}
|
46 |
+
for ref in refs:
|
47 |
+
rl, counts = precook(ref, n)
|
48 |
+
reflen.append(rl)
|
49 |
+
for (ngram,count) in six.iteritems(counts):
|
50 |
+
maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
|
51 |
+
|
52 |
+
# Calculate effective reference sentence length.
|
53 |
+
if eff == "shortest":
|
54 |
+
reflen = min(reflen)
|
55 |
+
elif eff == "average":
|
56 |
+
reflen = float(sum(reflen))/len(reflen)
|
57 |
+
|
58 |
+
## lhuang: N.B.: leave reflen computaiton to the very end!!
|
59 |
+
|
60 |
+
## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
|
61 |
+
|
62 |
+
return (reflen, maxcounts)
|
63 |
+
|
64 |
+
def cook_test(test, reflen_refmaxcounts, eff=None, n=4):
|
65 |
+
'''Takes a test sentence and returns an object that
|
66 |
+
encapsulates everything that BLEU needs to know about it.'''
|
67 |
+
|
68 |
+
reflen, refmaxcounts = reflen_refmaxcounts
|
69 |
+
testlen, counts = precook(test, n, True)
|
70 |
+
|
71 |
+
result = {}
|
72 |
+
|
73 |
+
# Calculate effective reference sentence length.
|
74 |
+
|
75 |
+
if eff == "closest":
|
76 |
+
result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
|
77 |
+
else: ## i.e., "average" or "shortest" or None
|
78 |
+
result["reflen"] = reflen
|
79 |
+
|
80 |
+
result["testlen"] = testlen
|
81 |
+
|
82 |
+
result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
|
83 |
+
|
84 |
+
result['correct'] = [0]*n
|
85 |
+
for (ngram, count) in six.iteritems(counts):
|
86 |
+
result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
|
87 |
+
|
88 |
+
return result
|
89 |
+
|
90 |
+
class BleuScorer(object):
|
91 |
+
"""Bleu scorer.
|
92 |
+
"""
|
93 |
+
|
94 |
+
__slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
|
95 |
+
# special_reflen is used in oracle (proportional effective ref len for a node).
|
96 |
+
|
97 |
+
def copy(self):
|
98 |
+
''' copy the refs.'''
|
99 |
+
new = BleuScorer(n=self.n)
|
100 |
+
new.ctest = copy.copy(self.ctest)
|
101 |
+
new.crefs = copy.copy(self.crefs)
|
102 |
+
new._score = None
|
103 |
+
return new
|
104 |
+
|
105 |
+
def __init__(self, test=None, refs=None, n=4, special_reflen=None):
|
106 |
+
''' singular instance '''
|
107 |
+
|
108 |
+
self.n = n
|
109 |
+
self.crefs = []
|
110 |
+
self.ctest = []
|
111 |
+
self.cook_append(test, refs)
|
112 |
+
self.special_reflen = special_reflen
|
113 |
+
|
114 |
+
def cook_append(self, test, refs):
|
115 |
+
'''called by constructor and __iadd__ to avoid creating new instances.'''
|
116 |
+
|
117 |
+
if refs is not None:
|
118 |
+
self.crefs.append(cook_refs(refs))
|
119 |
+
if test is not None:
|
120 |
+
cooked_test = cook_test(test, self.crefs[-1])
|
121 |
+
self.ctest.append(cooked_test) ## N.B.: -1
|
122 |
+
else:
|
123 |
+
self.ctest.append(None) # lens of crefs and ctest have to match
|
124 |
+
|
125 |
+
self._score = None ## need to recompute
|
126 |
+
|
127 |
+
def ratio(self, option=None):
|
128 |
+
self.compute_score(option=option)
|
129 |
+
return self._ratio
|
130 |
+
|
131 |
+
def score_ratio(self, option=None):
|
132 |
+
'''return (bleu, len_ratio) pair'''
|
133 |
+
return (self.fscore(option=option), self.ratio(option=option))
|
134 |
+
|
135 |
+
def score_ratio_str(self, option=None):
|
136 |
+
return "%.4f (%.2f)" % self.score_ratio(option)
|
137 |
+
|
138 |
+
def reflen(self, option=None):
|
139 |
+
self.compute_score(option=option)
|
140 |
+
return self._reflen
|
141 |
+
|
142 |
+
def testlen(self, option=None):
|
143 |
+
self.compute_score(option=option)
|
144 |
+
return self._testlen
|
145 |
+
|
146 |
+
def retest(self, new_test):
|
147 |
+
if type(new_test) is str:
|
148 |
+
new_test = [new_test]
|
149 |
+
assert len(new_test) == len(self.crefs), new_test
|
150 |
+
self.ctest = []
|
151 |
+
for t, rs in zip(new_test, self.crefs):
|
152 |
+
self.ctest.append(cook_test(t, rs))
|
153 |
+
self._score = None
|
154 |
+
|
155 |
+
return self
|
156 |
+
|
157 |
+
def rescore(self, new_test):
|
158 |
+
''' replace test(s) with new test(s), and returns the new score.'''
|
159 |
+
|
160 |
+
return self.retest(new_test).compute_score()
|
161 |
+
|
162 |
+
def size(self):
|
163 |
+
assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
|
164 |
+
return len(self.crefs)
|
165 |
+
|
166 |
+
def __iadd__(self, other):
|
167 |
+
'''add an instance (e.g., from another sentence).'''
|
168 |
+
|
169 |
+
if type(other) is tuple:
|
170 |
+
## avoid creating new BleuScorer instances
|
171 |
+
self.cook_append(other[0], other[1])
|
172 |
+
else:
|
173 |
+
assert self.compatible(other), "incompatible BLEUs."
|
174 |
+
self.ctest.extend(other.ctest)
|
175 |
+
self.crefs.extend(other.crefs)
|
176 |
+
self._score = None ## need to recompute
|
177 |
+
|
178 |
+
return self
|
179 |
+
|
180 |
+
def compatible(self, other):
|
181 |
+
return isinstance(other, BleuScorer) and self.n == other.n
|
182 |
+
|
183 |
+
def single_reflen(self, option="average"):
|
184 |
+
return self._single_reflen(self.crefs[0][0], option)
|
185 |
+
|
186 |
+
def _single_reflen(self, reflens, option=None, testlen=None):
|
187 |
+
|
188 |
+
if option == "shortest":
|
189 |
+
reflen = min(reflens)
|
190 |
+
elif option == "average":
|
191 |
+
reflen = float(sum(reflens))/len(reflens)
|
192 |
+
elif option == "closest":
|
193 |
+
reflen = min((abs(l-testlen), l) for l in reflens)[1]
|
194 |
+
else:
|
195 |
+
assert False, "unsupported reflen option %s" % option
|
196 |
+
|
197 |
+
return reflen
|
198 |
+
|
199 |
+
def recompute_score(self, option=None, verbose=0):
|
200 |
+
self._score = None
|
201 |
+
return self.compute_score(option, verbose)
|
202 |
+
|
203 |
+
def compute_score(self, option=None, verbose=0):
|
204 |
+
n = self.n
|
205 |
+
small = 1e-9
|
206 |
+
tiny = 1e-15 ## so that if guess is 0 still return 0
|
207 |
+
bleu_list = [[] for _ in range(n)]
|
208 |
+
|
209 |
+
if self._score is not None:
|
210 |
+
return self._score
|
211 |
+
|
212 |
+
if option is None:
|
213 |
+
option = "average" if len(self.crefs) == 1 else "closest"
|
214 |
+
|
215 |
+
self._testlen = 0
|
216 |
+
self._reflen = 0
|
217 |
+
totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
|
218 |
+
|
219 |
+
# for each sentence
|
220 |
+
for comps in self.ctest:
|
221 |
+
testlen = comps['testlen']
|
222 |
+
self._testlen += testlen
|
223 |
+
|
224 |
+
if self.special_reflen is None: ## need computation
|
225 |
+
reflen = self._single_reflen(comps['reflen'], option, testlen)
|
226 |
+
else:
|
227 |
+
reflen = self.special_reflen
|
228 |
+
|
229 |
+
self._reflen += reflen
|
230 |
+
|
231 |
+
for key in ['guess','correct']:
|
232 |
+
for k in range(n):
|
233 |
+
totalcomps[key][k] += comps[key][k]
|
234 |
+
|
235 |
+
# append per image bleu score
|
236 |
+
bleu = 1.
|
237 |
+
for k in range(n):
|
238 |
+
bleu *= (float(comps['correct'][k]) + tiny) \
|
239 |
+
/(float(comps['guess'][k]) + small)
|
240 |
+
bleu_list[k].append(bleu ** (1./(k+1)))
|
241 |
+
ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
|
242 |
+
if ratio < 1:
|
243 |
+
for k in range(n):
|
244 |
+
bleu_list[k][-1] *= math.exp(1 - 1/ratio)
|
245 |
+
|
246 |
+
if verbose > 1:
|
247 |
+
print(comps, reflen)
|
248 |
+
|
249 |
+
totalcomps['reflen'] = self._reflen
|
250 |
+
totalcomps['testlen'] = self._testlen
|
251 |
+
|
252 |
+
bleus = []
|
253 |
+
bleu = 1.
|
254 |
+
for k in range(n):
|
255 |
+
bleu *= float(totalcomps['correct'][k] + tiny) \
|
256 |
+
/ (totalcomps['guess'][k] + small)
|
257 |
+
bleus.append(bleu ** (1./(k+1)))
|
258 |
+
ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
|
259 |
+
if ratio < 1:
|
260 |
+
for k in range(n):
|
261 |
+
bleus[k] *= math.exp(1 - 1/ratio)
|
262 |
+
|
263 |
+
if verbose > 0:
|
264 |
+
print(totalcomps)
|
265 |
+
print("ratio:", ratio)
|
266 |
+
|
267 |
+
self._score = bleus
|
268 |
+
return self._score, bleu_list
|
nlg/radevalbertscore.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from bert_score import score
|
3 |
+
|
4 |
+
def _get_default_device():
|
5 |
+
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
6 |
+
|
7 |
+
class RadEvalBERTScorer:
|
8 |
+
"""
|
9 |
+
Wrapper around bert_score for radiology reports using a custom BERT model.
|
10 |
+
"""
|
11 |
+
def __init__(self,
|
12 |
+
model_type: str = "IAMJB/RadEvalModernBERT",
|
13 |
+
num_layers: int = None,
|
14 |
+
use_fast_tokenizer: bool = True,
|
15 |
+
rescale_with_baseline: bool = False,
|
16 |
+
device: torch.device = None):
|
17 |
+
self.model_type = model_type
|
18 |
+
self.num_layers = num_layers
|
19 |
+
self.use_fast_tokenizer = use_fast_tokenizer
|
20 |
+
self.rescale_with_baseline = rescale_with_baseline
|
21 |
+
self.device = device or _get_default_device()
|
22 |
+
|
23 |
+
def score(self, refs: list[str], hyps: list[str]) -> float:
|
24 |
+
"""
|
25 |
+
Compute BERTScore F1 between reference and hypothesis texts.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
refs: list of reference sentences.
|
29 |
+
hyps: list of hypothesis sentences (predictions).
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
Mean F1 score as a float.
|
33 |
+
"""
|
34 |
+
# bert_score expects cands (hypotheses) first, then refs
|
35 |
+
P, R, F1 = score(
|
36 |
+
cands=hyps,
|
37 |
+
refs=refs,
|
38 |
+
model_type=self.model_type,
|
39 |
+
num_layers=self.num_layers,
|
40 |
+
use_fast_tokenizer=self.use_fast_tokenizer,
|
41 |
+
rescale_with_baseline=self.rescale_with_baseline,
|
42 |
+
device=self.device
|
43 |
+
)
|
44 |
+
# Return the mean F1 over all pairs
|
45 |
+
return F1.mean().item(), F1
|
46 |
+
|
47 |
+
if __name__ == "__main__":
|
48 |
+
# Example usage
|
49 |
+
refs = ["Chronic mild to moderate cardiomegaly and pulmonary venous hypertension."]
|
50 |
+
hyps = ["Mild left basal atelectasis; no pneumonia."]
|
51 |
+
scorer = RadiologyBERTScorer(num_layers=23)
|
52 |
+
f1_score = scorer.score(refs, hyps)
|
53 |
+
print(f"Mean F1 score: {f1_score:.4f}")
|
nlg/rouge/rouge.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
from rouge_score import rouge_scorer
|
3 |
+
from six.moves import zip_longest
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
class Rouge(nn.Module):
|
8 |
+
def __init__(self, rouges, **kwargs):
|
9 |
+
super().__init__()
|
10 |
+
rouges = [r.replace('rougel', 'rougeL') for r in rouges]
|
11 |
+
self.scorer = rouge_scorer.RougeScorer(rouges, use_stemmer=True)
|
12 |
+
self.rouges = rouges
|
13 |
+
|
14 |
+
def forward(self, refs, hyps):
|
15 |
+
scores = []
|
16 |
+
for target_rec, prediction_rec in zip_longest(refs, hyps):
|
17 |
+
if target_rec is None or prediction_rec is None:
|
18 |
+
raise ValueError("Must have equal number of lines across target and "
|
19 |
+
"prediction.")
|
20 |
+
scores.append(self.scorer.score(target_rec, prediction_rec))
|
21 |
+
f1_rouge = [s[self.rouges[0]].fmeasure for s in scores]
|
22 |
+
return np.mean(f1_rouge), f1_rouge
|
23 |
+
|
24 |
+
|
25 |
+
class Rouge1(Rouge):
|
26 |
+
def __init__(self, **kwargs):
|
27 |
+
super(Rouge1, self).__init__(rouges=['rouge1'])
|
28 |
+
|
29 |
+
|
30 |
+
class Rouge2(Rouge):
|
31 |
+
def __init__(self, **kwargs):
|
32 |
+
super(Rouge2, self).__init__(rouges=['rouge2'])
|
33 |
+
|
34 |
+
|
35 |
+
class RougeL(Rouge):
|
36 |
+
def __init__(self, **kwargs):
|
37 |
+
super(RougeL, self).__init__(rouges=['rougeL'])
|
utils.py
ADDED
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ---------------------------------------------------------------
|
2 |
+
# This file includes code adapted from:
|
3 |
+
# https://github.com/jbdel/RadEval/blob/null-hypothesis/utils.py
|
4 |
+
# Original author: Justin Xu
|
5 |
+
# ---------------------------------------------------------------
|
6 |
+
import re
|
7 |
+
import nltk
|
8 |
+
import random
|
9 |
+
from typing import List, Dict, Tuple, Callable, Optional
|
10 |
+
from collections import defaultdict
|
11 |
+
|
12 |
+
nltk.download("punkt_tab", quiet=True)
|
13 |
+
|
14 |
+
def clean_numbered_list(text):
|
15 |
+
"""
|
16 |
+
Clean a report if it's a numbered list by:
|
17 |
+
1. Adding proper spacing between numbered items
|
18 |
+
2. Removing the numbered list markers
|
19 |
+
3. Adding spaces after periods between sentences
|
20 |
+
"""
|
21 |
+
# First, separate numbered items that are stuck together without spaces
|
22 |
+
# Example: "textx.2. text2" -> "texty. 2. text2"
|
23 |
+
text = re.sub(r'\.(\d+\.)', r'. \1', text)
|
24 |
+
|
25 |
+
# Handle patterns where there's no period between numbered entries
|
26 |
+
# Example: "1. item1 2. item2" -> "1. item1. 2. item2"
|
27 |
+
text = re.sub(r'(\d+\.\s*[^.]+?)\s+(?=\d+\.)', r'\1. ', text)
|
28 |
+
|
29 |
+
# Then remove the numbered list markers
|
30 |
+
# But avoid removing decimal numbers in measurements like "3.5 cm"
|
31 |
+
text = re.sub(r'(?<!\d)\d+\.\s*', '', text)
|
32 |
+
|
33 |
+
# Add spaces after periods between sentences if missing
|
34 |
+
# Example: "sentence1.sentence2" -> "sentence1. sentence2"
|
35 |
+
# But don't split decimal numbers like "3.5 cm"
|
36 |
+
text = re.sub(r'\.([A-Za-z])', r'. \1', text)
|
37 |
+
return nltk.sent_tokenize(text)
|
38 |
+
|
39 |
+
class PairedTest:
|
40 |
+
"""
|
41 |
+
Paired significance testing for comparing radiology report generation systems.
|
42 |
+
|
43 |
+
Supports paired approximate randomization (AR).
|
44 |
+
"""
|
45 |
+
|
46 |
+
def __init__(self,
|
47 |
+
systems: Dict[str, List[str]],
|
48 |
+
metrics: Dict[str, Callable],
|
49 |
+
references: Optional[List[str]],
|
50 |
+
n_samples: int = 10000,
|
51 |
+
n_jobs: int = 1,
|
52 |
+
seed: int = 12345):
|
53 |
+
"""
|
54 |
+
Args:
|
55 |
+
systems: Dictionary mapping system names to their generated reports
|
56 |
+
metrics: Dictionary mapping metric names to metric functions
|
57 |
+
references: List of reference reports
|
58 |
+
n_samples: Number of resampling trials (default: 10000)
|
59 |
+
n_jobs: Number of parallel jobs (default: 1)
|
60 |
+
seed: Random seed for reproducibility
|
61 |
+
"""
|
62 |
+
self.systems = systems
|
63 |
+
self.metrics = metrics
|
64 |
+
self.references = references
|
65 |
+
self.n_samples = n_samples
|
66 |
+
self.n_jobs = n_jobs
|
67 |
+
self.seed = seed
|
68 |
+
|
69 |
+
random.seed(seed)
|
70 |
+
|
71 |
+
if not systems:
|
72 |
+
raise ValueError("At least one system is required")
|
73 |
+
|
74 |
+
system_lengths = [len(outputs) for outputs in systems.values()]
|
75 |
+
if len(set(system_lengths)) > 1:
|
76 |
+
raise ValueError("All systems must have the same number of outputs")
|
77 |
+
|
78 |
+
if references and len(references) != system_lengths[0]:
|
79 |
+
raise ValueError("References must have same length as system outputs")
|
80 |
+
|
81 |
+
self.n_instances = system_lengths[0]
|
82 |
+
|
83 |
+
def __call__(self) -> Tuple[Dict[str, str], Dict[str, Dict[str, float]]]:
|
84 |
+
"""
|
85 |
+
Run the paired significance test.
|
86 |
+
|
87 |
+
Returns:
|
88 |
+
Tuple of (signatures, scores) where:
|
89 |
+
- signatures: Dict mapping metric names to signature strings
|
90 |
+
- scores: Dict mapping system names to metric scores and p-values
|
91 |
+
"""
|
92 |
+
# Calculate baseline scores for all systems and metrics
|
93 |
+
baseline_scores = self._calculate_baseline_scores()
|
94 |
+
|
95 |
+
# Get baseline system (first system)
|
96 |
+
baseline_name = list(self.systems.keys())[0]
|
97 |
+
|
98 |
+
scores = {}
|
99 |
+
signatures = {}
|
100 |
+
|
101 |
+
# Calculate scores and p-values for each system
|
102 |
+
for system_name in self.systems.keys():
|
103 |
+
scores[system_name] = {}
|
104 |
+
|
105 |
+
for metric_name in self.metrics.keys():
|
106 |
+
score = baseline_scores[system_name][metric_name]
|
107 |
+
scores[system_name][metric_name] = score
|
108 |
+
|
109 |
+
if system_name != baseline_name:
|
110 |
+
p_value = self._calculate_p_value(
|
111 |
+
baseline_name, system_name, metric_name, baseline_scores
|
112 |
+
)
|
113 |
+
scores[system_name][f'{metric_name}_pvalue'] = p_value
|
114 |
+
|
115 |
+
for metric_name in self.metrics.keys():
|
116 |
+
signatures[metric_name] = f"{metric_name}|{'ar'}:{self.n_samples}|seed:{self.seed}"
|
117 |
+
|
118 |
+
return signatures, scores
|
119 |
+
|
120 |
+
def _calculate_baseline_scores(self) -> Dict[str, Dict[str, float]]:
|
121 |
+
"""Calculate baseline scores for all systems and metrics."""
|
122 |
+
scores = defaultdict(dict)
|
123 |
+
|
124 |
+
for system_name, outputs in self.systems.items():
|
125 |
+
for metric_name, metric_func in self.metrics.items():
|
126 |
+
if self.references:
|
127 |
+
score = metric_func(outputs, self.references)
|
128 |
+
else:
|
129 |
+
score = metric_func(outputs)
|
130 |
+
|
131 |
+
if isinstance(score, dict):
|
132 |
+
if 'score' in score:
|
133 |
+
scores[system_name][metric_name] = score['score']
|
134 |
+
else:
|
135 |
+
scores[system_name][metric_name] = list(score.values())[0]
|
136 |
+
elif isinstance(score, (tuple, list)):
|
137 |
+
scores[system_name][metric_name] = score[0]
|
138 |
+
else:
|
139 |
+
scores[system_name][metric_name] = score
|
140 |
+
|
141 |
+
return scores
|
142 |
+
|
143 |
+
def _calculate_p_value(self,
|
144 |
+
baseline_name: str,
|
145 |
+
system_name: str,
|
146 |
+
metric_name: str,
|
147 |
+
baseline_scores: Dict[str, Dict[str, float]]) -> float:
|
148 |
+
"""Calculate p-value using AR test"""
|
149 |
+
|
150 |
+
baseline_outputs = self.systems[baseline_name]
|
151 |
+
system_outputs = self.systems[system_name]
|
152 |
+
metric_func = self.metrics[metric_name]
|
153 |
+
|
154 |
+
baseline_score = baseline_scores[baseline_name][metric_name]
|
155 |
+
system_score = baseline_scores[system_name][metric_name]
|
156 |
+
original_delta = abs(system_score - baseline_score)
|
157 |
+
|
158 |
+
return self._approximate_randomization_test(
|
159 |
+
baseline_outputs, system_outputs, metric_func, original_delta
|
160 |
+
)
|
161 |
+
|
162 |
+
def _approximate_randomization_test(self,
|
163 |
+
baseline_outputs: List[str],
|
164 |
+
system_outputs: List[str],
|
165 |
+
metric_func: Callable,
|
166 |
+
original_delta: float) -> float:
|
167 |
+
"""
|
168 |
+
Perform AR test.
|
169 |
+
|
170 |
+
For each trial, randomly swap outputs between systems and calculate
|
171 |
+
the score difference. P-value is the proportion of trials where
|
172 |
+
the randomized delta >= original delta.
|
173 |
+
"""
|
174 |
+
count_greater = 0
|
175 |
+
|
176 |
+
for _ in range(self.n_samples):
|
177 |
+
randomized_baseline = []
|
178 |
+
randomized_system = []
|
179 |
+
|
180 |
+
for i in range(self.n_instances):
|
181 |
+
if random.random() < 0.5:
|
182 |
+
# Don't swap
|
183 |
+
randomized_baseline.append(baseline_outputs[i])
|
184 |
+
randomized_system.append(system_outputs[i])
|
185 |
+
else:
|
186 |
+
# Swap
|
187 |
+
randomized_baseline.append(system_outputs[i])
|
188 |
+
randomized_system.append(baseline_outputs[i])
|
189 |
+
|
190 |
+
if self.references:
|
191 |
+
rand_baseline_score = metric_func(randomized_baseline, self.references)
|
192 |
+
rand_system_score = metric_func(randomized_system, self.references)
|
193 |
+
else:
|
194 |
+
rand_baseline_score = metric_func(randomized_baseline)
|
195 |
+
rand_system_score = metric_func(randomized_system)
|
196 |
+
|
197 |
+
if isinstance(rand_baseline_score, dict):
|
198 |
+
rand_baseline_score = rand_baseline_score.get('score', list(rand_baseline_score.values())[0])
|
199 |
+
elif isinstance(rand_baseline_score, (tuple, list)):
|
200 |
+
rand_baseline_score = rand_baseline_score[0]
|
201 |
+
|
202 |
+
if isinstance(rand_system_score, dict):
|
203 |
+
rand_system_score = rand_system_score.get('score', list(rand_system_score.values())[0])
|
204 |
+
elif isinstance(rand_system_score, (tuple, list)):
|
205 |
+
rand_system_score = rand_system_score[0]
|
206 |
+
|
207 |
+
rand_delta = abs(rand_system_score - rand_baseline_score)
|
208 |
+
|
209 |
+
if rand_delta >= original_delta:
|
210 |
+
count_greater += 1
|
211 |
+
|
212 |
+
return count_greater / self.n_samples
|
213 |
+
|
214 |
+
|
215 |
+
def print_significance_results(scores: Dict[str, Dict[str, float]],
|
216 |
+
signatures: Dict[str, str],
|
217 |
+
baseline_name: str,
|
218 |
+
significance_level: float = 0.05):
|
219 |
+
"""
|
220 |
+
Args:
|
221 |
+
scores: Dictionary of system scores and p-values
|
222 |
+
signatures: Dictionary of metric signatures
|
223 |
+
baseline_name: Name of the baseline system
|
224 |
+
significance_level: Significance threshold (default: 0.05)
|
225 |
+
"""
|
226 |
+
assert baseline_name in scores, f"Baseline system '{baseline_name}' not found in scores."
|
227 |
+
|
228 |
+
metric_names = [name for name in signatures.keys()]
|
229 |
+
system_names = list(scores.keys())
|
230 |
+
|
231 |
+
print("=" * 80)
|
232 |
+
print("PAIRED SIGNIFICANCE TEST RESULTS")
|
233 |
+
print("=" * 80)
|
234 |
+
|
235 |
+
header = f"{'System':<40}"
|
236 |
+
for metric in metric_names:
|
237 |
+
header += f"{metric:>15}"
|
238 |
+
print(header)
|
239 |
+
print("-" * len(header))
|
240 |
+
|
241 |
+
baseline_row = f"Baseline: {baseline_name:<32}"
|
242 |
+
for metric in metric_names:
|
243 |
+
score = scores[baseline_name][metric]
|
244 |
+
baseline_row += f"{score:>12.4f} "
|
245 |
+
print(baseline_row)
|
246 |
+
print("-" * len(header))
|
247 |
+
|
248 |
+
for system_name in system_names:
|
249 |
+
if system_name == baseline_name:
|
250 |
+
continue
|
251 |
+
|
252 |
+
system_row = f"{system_name:<40}"
|
253 |
+
for metric in metric_names:
|
254 |
+
score = scores[system_name].get(metric, 0.0)
|
255 |
+
if isinstance(score, float):
|
256 |
+
system_row += f"{score:>12.4f} "
|
257 |
+
else:
|
258 |
+
system_row += f"{str(score):>12} "
|
259 |
+
print(system_row)
|
260 |
+
|
261 |
+
# P-value row
|
262 |
+
pvalue_row = " " * 40
|
263 |
+
for metric in metric_names:
|
264 |
+
pvalue_key = f"{metric}_pvalue"
|
265 |
+
if pvalue_key in scores[system_name]:
|
266 |
+
p_val = scores[system_name][pvalue_key]
|
267 |
+
significance_marker = "*" if p_val < significance_level else ""
|
268 |
+
pvalue_row += f"(p={p_val:.4f}){significance_marker:<2}".rjust(15)
|
269 |
+
else:
|
270 |
+
pvalue_row += " " * 15
|
271 |
+
print(pvalue_row)
|
272 |
+
print("-" * len(header))
|
273 |
+
|
274 |
+
# Footer
|
275 |
+
print(f"- Significance level: {significance_level}")
|
276 |
+
print("- '*' indicates significant difference (p < significance level)")
|
277 |
+
print("- Null hypothesis: systems are essentially the same")
|
278 |
+
print("- Significant results suggest systems are meaningfully different\n")
|
279 |
+
|
280 |
+
print("METRIC SIGNATURES:")
|
281 |
+
for metric, signature in signatures.items():
|
282 |
+
print(f"- {metric}: {signature}")
|
283 |
+
|
284 |
+
|
285 |
+
def compare_systems(systems: Dict[str, List[str]],
|
286 |
+
metrics: Dict[str, Callable],
|
287 |
+
references: Optional[List[str]] = None,
|
288 |
+
n_samples: int = 10000,
|
289 |
+
significance_level: float = 0.05,
|
290 |
+
seed: int = 12345,
|
291 |
+
print_results: bool = True) -> Tuple[Dict[str, str], Dict[str, Dict[str, float]]]:
|
292 |
+
"""
|
293 |
+
Args:
|
294 |
+
systems: Dictionary mapping system names to their generated reports
|
295 |
+
metrics: Dictionary mapping metric names to metric functions
|
296 |
+
references: Optional list of reference reports
|
297 |
+
n_samples: Number of resampling trials
|
298 |
+
significance_level: Significance threshold for printing results
|
299 |
+
seed: Random seed for reproducibility
|
300 |
+
print_results: Whether to print formatted results
|
301 |
+
|
302 |
+
Returns:
|
303 |
+
Tuple of (signatures, scores)
|
304 |
+
|
305 |
+
Example:
|
306 |
+
```python
|
307 |
+
systems = {
|
308 |
+
'baseline_model': baseline_reports,
|
309 |
+
'new_model': new_model_reports,
|
310 |
+
'other_model': other_model_reports
|
311 |
+
}
|
312 |
+
|
313 |
+
metrics = {
|
314 |
+
'bleu': lambda hyp, ref: bleu_score(hyp, ref),
|
315 |
+
'rouge': lambda hyp, ref: rouge_score(hyp, ref),
|
316 |
+
'bertscore': lambda hyp, ref: bert_score(hyp, ref)
|
317 |
+
'custom_metric': lambda hyp, ref: custom_metric(hyp, ref)
|
318 |
+
}
|
319 |
+
|
320 |
+
signatures, scores = compare_systems(
|
321 |
+
systems, metrics, references,
|
322 |
+
n_samples=10000
|
323 |
+
)
|
324 |
+
```
|
325 |
+
"""
|
326 |
+
|
327 |
+
paired_test = PairedTest(
|
328 |
+
systems=systems,
|
329 |
+
metrics=metrics,
|
330 |
+
references=references,
|
331 |
+
n_samples=n_samples,
|
332 |
+
seed=seed
|
333 |
+
)
|
334 |
+
|
335 |
+
signatures, scores = paired_test()
|
336 |
+
|
337 |
+
if print_results:
|
338 |
+
baseline_name = list(systems.keys())[0]
|
339 |
+
print_significance_results(scores, signatures, baseline_name, significance_level)
|
340 |
+
|
341 |
+
return signatures, scores
|