X-iZhang committed on
Commit bad8293 · verified · 1 Parent(s): a39b152
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ RadEval_banner.png filter=lfs diff=lfs merge=lfs -text
RadEval.py ADDED
@@ -0,0 +1,424 @@
1
+ from collections import defaultdict
2
+ import stanza
3
+ import warnings
4
+ import logging
5
+ import os
6
+ import re
7
+ from nlg.rouge.rouge import Rouge
8
+ from nlg.bleu.bleu import Bleu
9
+ from nlg.bertscore.bertscore import BertScore
10
+ from radgraph import F1RadGraph
11
+ from factual.green_score import GREEN
12
+ from factual.RaTEScore import RaTEScore
13
+ from factual.f1temporal import F1Temporal
14
+ from torch import nn
15
+ import pandas as pd
16
+ import numpy as np
17
+ from sklearn.metrics import classification_report
18
+ from sklearn.exceptions import UndefinedMetricWarning
19
+ import json
20
+ from factual.f1chexbert import F1CheXbert
21
+ import nltk
22
+ from utils import clean_numbered_list
23
+ from factual.RadCliQv1.radcliq import CompositeMetric
24
+ from factual.SRRBert.srr_bert import SRRBert, srr_bert_parse_sentences
25
+ from nlg.radevalbertscore import RadEvalBERTScorer
26
+ # Suppress warnings
27
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
28
+ warnings.filterwarnings('ignore')
29
+ logging.basicConfig(level=logging.ERROR)
30
+
31
+
32
+
33
+
34
+ class RadEval():
35
+ def __init__(self,
36
+ do_radgraph=False,
37
+ do_green=False,
38
+ do_bleu=False,
39
+ do_rouge=False,
40
+ do_bertscore=False,
41
+ do_srr_bert=False,
42
+ do_chexbert=False,
43
+ do_ratescore=False,
44
+ do_radcliq=False,
45
+ do_radeval_bertsore=False,
46
+ do_temporal=False,
47
+ do_details=False,
48
+ ):
49
+ super(RadEval, self).__init__()
50
+
51
+ self.do_radgraph = do_radgraph
52
+ self.do_green = do_green
53
+ self.do_bleu = do_bleu
54
+ self.do_rouge = do_rouge
55
+ self.do_bertscore = do_bertscore
56
+ self.do_srr_bert = do_srr_bert
57
+ self.do_chexbert = do_chexbert
58
+ self.do_ratescore = do_ratescore
59
+ self.do_radcliq = do_radcliq
60
+ self.do_temporal = do_temporal
61
+ self.do_radeval_bertsore = do_radeval_bertsore
62
+ self.do_details = do_details
63
+
64
+ # Initialize scorers only once
65
+ if self.do_radgraph:
66
+ self.radgraph_scorer = F1RadGraph(reward_level="all", model_type="radgraph-xl")
67
+ if self.do_bleu:
68
+ self.bleu_scorer = Bleu()
69
+ self.bleu_scorer_1 = Bleu(n=1)
70
+ self.bleu_scorer_2 = Bleu(n=2)
71
+ self.bleu_scorer_3 = Bleu(n=3)
72
+ if self.do_bertscore:
73
+ self.bertscore_scorer = BertScore(model_type='distilbert-base-uncased',
74
+ num_layers=5)
75
+ if self.do_green:
76
+ # Initialize green scorer here if needed
77
+ self.green_scorer = GREEN("StanfordAIMI/GREEN-radllama2-7b",
78
+ output_dir=".")
79
+
80
+ if self.do_rouge:
81
+ self.rouge_scorers = {
82
+ "rouge1": Rouge(rouges=["rouge1"]),
83
+ "rouge2": Rouge(rouges=["rouge2"]),
84
+ "rougeL": Rouge(rouges=["rougeL"])
85
+ }
86
+
87
+ if self.do_srr_bert:
88
+ nltk.download('punkt_tab', quiet=True)
89
+ self.srr_bert_scorer = SRRBert(model_type="leaves_with_statuses")
90
+
91
+
92
+ if self.do_chexbert:
93
+ self.chexbert_scorer = F1CheXbert()
94
+
95
+ if self.do_ratescore:
96
+ self.ratescore_scorer = RaTEScore()
97
+
98
+ if self.do_radcliq:
99
+ self.radcliq_scorer = CompositeMetric()
100
+
101
+ if self.do_temporal:
102
+ stanza.download('en', package='radiology', processors={'ner': 'radiology'})
103
+ self.F1Temporal = F1Temporal
104
+
105
+ if self.do_radeval_bertsore:
106
+ self.radeval_bertsore = RadEvalBERTScorer(
107
+ model_type="IAMJB/RadEvalModernBERT",
108
+ num_layers=22,
109
+ use_fast_tokenizer=True,
110
+ rescale_with_baseline=False)
111
+ # Store the metric keys
112
+ self.metric_keys = []
113
+ if self.do_radgraph:
114
+ self.metric_keys.extend(["radgraph_simple", "radgraph_partial", "radgraph_complete"])
115
+ if self.do_bleu:
116
+ self.metric_keys.append("bleu")
117
+ if self.do_green:
118
+ self.metric_keys.append("green")
119
+ if self.do_bertscore:
120
+ self.metric_keys.append("bertscore")
121
+ if self.do_rouge:
122
+ self.metric_keys.extend(self.rouge_scorers.keys())
123
+ if self.do_srr_bert:
124
+ self.metric_keys.extend(["samples_avg_precision", "samples_avg_recall", "samples_avg_f1-score"])
125
+
126
+ if self.do_chexbert:
127
+ self.metric_keys.extend([
128
+ "chexbert-5_micro avg_f1-score",
129
+ "chexbert-all_micro avg_f1-score",
130
+ "chexbert-5_macro avg_f1-score",
131
+ "chexbert-all_macro avg_f1-score"
132
+ ])
133
+
134
+ if self.do_ratescore:
135
+ self.metric_keys.append("ratescore")
136
+ if self.do_radcliq:
137
+ self.metric_keys.append("radcliqv1")
138
+ if self.do_temporal:
139
+ self.metric_keys.append("temporal_f1")
140
+ if self.do_radeval_bertsore:
141
+ self.metric_keys.append("radeval_bertsore")
142
+
143
+ def __call__(self, refs, hyps):
144
+ if not (isinstance(hyps, list) and isinstance(refs, list)):
145
+ raise TypeError("hyps and refs must be of type list")
146
+ if len(hyps) != len(refs):
147
+ raise ValueError("hyps and refs lists don't have the same size")
148
+ if len(refs) == 0:
149
+ return {}
150
+
151
+ scores = self.compute_scores(refs=refs, hyps=hyps)
152
+ return scores
153
+
154
+ def compute_scores(self, refs, hyps):
155
+ if not (isinstance(hyps, list) and isinstance(refs, list)):
156
+ raise TypeError("hyps and refs must be of type list")
157
+ if len(hyps) != len(refs):
158
+ raise ValueError("hyps and refs lists don't have the same size")
159
+
160
+ scores = {}
161
+ if self.do_radgraph:
162
+ radgraph_scores = self.radgraph_scorer(refs=refs, hyps=hyps)
163
+
164
+ if self.do_details:
165
+ f1_scores = radgraph_scores[0]
166
+ individual_scores = radgraph_scores[1]
167
+ hyps_entities = radgraph_scores[2]
168
+ refs_entities = radgraph_scores[3]
169
+
170
+ scores["radgraph"] = {
171
+ "radgraph_simple": f1_scores[0],
172
+ "radgraph_partial": f1_scores[1],
173
+ "radgraph_complete": f1_scores[2],
174
+ "reward_list": individual_scores,
175
+ "hypothesis_annotation_lists": hyps_entities,
176
+ "reference_annotation_lists": refs_entities
177
+ }
178
+
179
+ else:
180
+ radgraph_scores = radgraph_scores[0]
181
+ scores["radgraph_simple"] = radgraph_scores[0]
182
+ scores["radgraph_partial"] = radgraph_scores[1]
183
+ scores["radgraph_complete"] = radgraph_scores[2]
184
+
185
+ if self.do_bleu:
186
+ if self.do_details:
187
+ bleu_1_score = self.bleu_scorer_1(refs, hyps)[0]
188
+ bleu_2_score = self.bleu_scorer_2(refs, hyps)[0]
189
+ bleu_3_score = self.bleu_scorer_3(refs, hyps)[0]
190
+ bleu_4_score = self.bleu_scorer(refs, hyps)[0]
191
+
192
+ scores["bleu"] = {
193
+ "bleu_1": bleu_1_score,
194
+ "bleu_2": bleu_2_score,
195
+ "bleu_3": bleu_3_score,
196
+ "bleu_4": bleu_4_score
197
+ }
198
+ else:
199
+ scores["bleu"] = self.bleu_scorer(refs, hyps)[0]
200
+
201
+ if self.do_bertscore:
202
+ if self.do_details:
203
+ bertscore_scores, sample_scores = self.bertscore_scorer(refs, hyps)
204
+ scores["bertscore"] = {
205
+ "mean_score": bertscore_scores,
206
+ "sample_scores": sample_scores
207
+ }
208
+ else:
209
+ scores["bertscore"] = self.bertscore_scorer(refs, hyps)[0]
210
+
211
+ if self.do_green:
212
+ # Use the initialized green scorer
213
+ mean, std, sample_scores, summary, _ = self.green_scorer(refs, hyps)
214
+ if self.do_details:
215
+ scores["green"] = {
216
+ "mean": mean,
217
+ "std": std,
218
+ "sample_scores": sample_scores,
219
+ "summary": summary
220
+ }
221
+ else:
222
+ scores["green"] = mean
223
+
224
+ if self.do_rouge:
225
+ if self.do_details:
226
+ rouge_scores = {}
227
+ for key, scorer in self.rouge_scorers.items():
228
+ mean, sample_scores = scorer(refs, hyps)
229
+ rouge_scores[key] = {
230
+ "mean_score": mean,
231
+ "sample_scores": sample_scores
232
+ }
233
+
234
+ scores["rouge"] = rouge_scores
235
+ else:
236
+ for key, scorer in self.rouge_scorers.items():
237
+ scores[key] = scorer(refs, hyps)[0]
238
+
239
+ if self.do_srr_bert:
240
+ # Clean reports before tokenization
241
+ parsed_refs = [srr_bert_parse_sentences(ref) for ref in refs]
242
+ parsed_hyps = [srr_bert_parse_sentences(hyp) for hyp in hyps]
243
+
244
+
245
+ section_level_hyps_pred = []
246
+ section_level_refs_pred = []
247
+ for parsed_hyp, parsed_ref in zip(parsed_hyps, parsed_refs):
248
+ outputs, _ = self.srr_bert_scorer(sentences=parsed_ref + parsed_hyp)
249
+
250
+ refs_preds = outputs[:len(parsed_ref)]
251
+ hyps_preds = outputs[len(parsed_ref):]
252
+
253
+ merged_refs_preds = np.any(refs_preds, axis=0).astype(int)
254
+ merged_hyps_preds = np.any(hyps_preds, axis=0).astype(int)
255
+
256
+ section_level_hyps_pred.append(merged_hyps_preds)
257
+ section_level_refs_pred.append(merged_refs_preds)
258
+
259
+ label_names = [label for label, idx in sorted(self.srr_bert_scorer.mapping.items(), key=lambda x: x[1])]
260
+ classification_dict = classification_report(section_level_refs_pred,
261
+ section_level_hyps_pred,
262
+ target_names=label_names,
263
+ output_dict=True,
264
+ zero_division=0)
265
+
266
+ if self.do_details:
267
+ label_scores = {}
268
+ for label in label_names:
269
+ if label in classification_dict:
270
+ f1 = classification_dict[label]["f1-score"]
271
+ support = classification_dict[label]["support"]
272
+ if f1 > 0 or support > 0:
273
+ label_scores[label] = {
274
+ "f1-score": f1,
275
+ "precision": classification_dict[label]["precision"],
276
+ "recall": classification_dict[label]["recall"],
277
+ "support": support
278
+ }
279
+
280
+ scores["srr_bert"] = {
281
+ "srr_bert_weighted_f1": classification_dict["weighted avg"]["f1-score"],
282
+ "srr_bert_weighted_precision": classification_dict["weighted avg"]["precision"],
283
+ "srr_bert_weighted_recall": classification_dict["weighted avg"]["recall"],
284
+ "label_scores": label_scores
285
+ }
286
+ else:
287
+ scores["srr_bert_weighted_f1"] = classification_dict["weighted avg"]["f1-score"]
288
+ scores["srr_bert_weighted_precision"] = classification_dict["weighted avg"]["precision"]
289
+ scores["srr_bert_weighted_recall"] = classification_dict["weighted avg"]["recall"]
290
+
291
+
292
+
293
+ if self.do_chexbert:
294
+ accuracy, accuracy_per_sample, chexbert_all, chexbert_5 = self.chexbert_scorer(hyps, refs)
295
+ if self.do_details:
296
+ chexbert_5_labels = {
297
+ k: v["f1-score"]
298
+ for k, v in list(chexbert_5.items())[:-4]
299
+ }
300
+
301
+ chexbert_all_labels = {
302
+ k: v["f1-score"]
303
+ for k, v in list(chexbert_all.items())[:-4]
304
+ }
305
+
306
+ scores["chexbert"] = {
307
+ "chexbert-5_micro avg_f1-score": chexbert_5["micro avg"]["f1-score"],
308
+ "chexbert-all_micro avg_f1-score": chexbert_all["micro avg"]["f1-score"],
309
+ "chexbert-5_macro avg_f1-score": chexbert_5["macro avg"]["f1-score"],
310
+ "chexbert-all_macro avg_f1-score": chexbert_all["macro avg"]["f1-score"],
311
+ "chexbert-5_weighted_f1": chexbert_5["weighted avg"]["f1-score"],
312
+ "chexbert-all_weighted_f1": chexbert_all["weighted avg"]["f1-score"],
313
+ "label_scores_f1-score": {
314
+ "chexbert-5": chexbert_5_labels,
315
+ "chexbert_all": chexbert_all_labels
316
+ }
317
+ }
318
+ else:
319
+ scores["chexbert-5_micro avg_f1-score"] = chexbert_5["micro avg"]["f1-score"]
320
+ scores["chexbert-all_micro avg_f1-score"] = chexbert_all["micro avg"]["f1-score"]
321
+ scores["chexbert-5_macro avg_f1-score"] = chexbert_5["macro avg"]["f1-score"]
322
+ scores["chexbert-all_macro avg_f1-score"] = chexbert_all["macro avg"]["f1-score"]
323
+ scores["chexbert-5_weighted_f1"] = chexbert_5["weighted avg"]["f1-score"]
324
+ scores["chexbert-all_weighted_f1"] = chexbert_all["weighted avg"]["f1-score"]
325
+
326
+ if self.do_ratescore:
327
+ rate_score, pred_pairs_raw ,gt_pairs_raw = self.ratescore_scorer.compute_score(candidate_list=hyps, reference_list=refs)
328
+ f1_ratescore = float(np.mean(rate_score))
329
+ if self.do_details:
330
+ pred_pairs = [
331
+ {ent: label for ent, label in sample}
332
+ for sample in pred_pairs_raw
333
+ ]
334
+ gt_pairs = [
335
+ {ent: label for ent, label in sample}
336
+ for sample in gt_pairs_raw
337
+ ]
338
+ scores["ratescore"] = {
339
+ "f1-score": f1_ratescore,
340
+ "hyps_pairs": pred_pairs,
341
+ "refs_pairs": gt_pairs
342
+ }
343
+ else:
344
+ scores["ratescore"] = f1_ratescore
345
+
346
+ if self.do_radcliq:
347
+ mean_scores, detail_scores = self.radcliq_scorer.predict(refs, hyps)
348
+ if self.do_details:
349
+ scores["radcliq-v1"] = {
350
+ "mean_score": mean_scores,
351
+ "sample_scores": detail_scores.tolist()
352
+ }
353
+ else:
354
+ scores["radcliq-v1"] = mean_scores
355
+
356
+ if self.do_temporal:
357
+ temporal_scores = self.F1Temporal(predictions=hyps, references=refs)
358
+ if self.do_details:
359
+ hyp_entities = [
360
+ sorted(list(group)) if group else []
361
+ for group in temporal_scores.get("prediction_entities", [])
362
+ ]
363
+ ref_entities = [
364
+ sorted(list(group)) if group else []
365
+ for group in temporal_scores.get("reference_entities", [])
366
+ ]
367
+ scores["temporal_f1"] = {
368
+ "f1-score": temporal_scores["f1"],
369
+ "hyps_entities": hyp_entities,
370
+ "refs_entities": ref_entities
371
+ }
372
+ else:
373
+ scores["temporal_f1"] = temporal_scores["f1"]
374
+
375
+ if self.do_radeval_bertsore:
376
+ radeval_bertsores = self.radeval_bertsore.score(refs=refs, hyps=hyps)
377
+ if self.do_details:
378
+ scores["radeval_bertsore"] = {
379
+ "f1-score": radeval_bertsores[0],
380
+ "sample_scores": radeval_bertsores[1].tolist()
381
+ }
382
+ else:
383
+ scores["radeval_bertsore"] = radeval_bertsores[0]
384
+
385
+ return scores
386
+
387
+
388
+ def main():
389
+ refs = [
390
+ "No acute cardiopulmonary process.",
391
+ "No radiographic findings to suggest pneumonia.",
392
+ "1.Status post median sternotomy for CABG with stable cardiac enlargement and calcification of the aorta consistent with atherosclerosis.Relatively lower lung volumes with no focal airspace consolidation appreciated.Crowding of the pulmonary vasculature with possible minimal perihilar edema, but no overt pulmonary edema.No pleural effusions or pneumothoraces.",
393
+ "1. Left PICC tip appears to terminate in the distal left brachiocephalic vein.2. Mild pulmonary vascular congestion.3. Interval improvement in aeration of the lung bases with residual streaky opacity likely reflective of atelectasis.Interval resolution of the left pleural effusion.",
394
+ "No definite acute cardiopulmonary process.Enlarged cardiac silhouette could be accentuated by patient's positioning.",
395
+ "Increased mild pulmonary edema and left basal atelectasis.",
396
+ ]
397
+
398
+ hyps = [
399
+ "No acute cardiopulmonary process.",
400
+ "No radiographic findings to suggest pneumonia.",
401
+ "Status post median sternotomy for CABG with stable cardiac enlargement and calcification of the aorta consistent with atherosclerosis.",
402
+ "Relatively lower lung volumes with no focal airspace consolidation appreciated.",
403
+ "Crowding of the pulmonary vasculature with possible minimal perihilar edema, but no overt pulmonary edema.",
404
+ "No pleural effusions or pneumothoraces.",
405
+ ]
406
+
407
+ evaluator = RadEval(do_radgraph=True,
408
+ do_green=False,
409
+ do_bleu=True,
410
+ do_rouge=True,
411
+ do_bertscore=True,
412
+ do_srr_bert=True,
413
+ do_chexbert=True,
414
+ do_temporal=True,
415
+ do_ratescore=True,
416
+ do_radcliq=True,
417
+ do_radeval_bertsore=True)
418
+
419
+ results = evaluator(refs=refs, hyps=hyps)
420
+ print(json.dumps(results, indent=4))
421
+
422
+
423
+ if __name__ == '__main__':
424
+ main()
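A minimal usage sketch for the evaluator above (the import path is hypothetical and depends on how the package is installed; the flags and output keys are taken from compute_scores):

from RadEval import RadEval

# Enable only the n-gram metrics; do_details=True switches each enabled metric
# from a single float to a nested dictionary of sub-scores.
evaluator = RadEval(do_bleu=True, do_rouge=True, do_details=True)
results = evaluator(refs=["No pleural effusion."], hyps=["No effusion is seen."])
# results["bleu"]  -> {"bleu_1": ..., "bleu_2": ..., "bleu_3": ..., "bleu_4": ...}
# results["rouge"] -> {"rouge1": {"mean_score": ..., "sample_scores": [...]}, ...}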
RadEval_banner.png ADDED

Git LFS Details

  • SHA256: 2ddd4245bbadd24dae93ac925134bd2aea5548b94b8b55c8c8a21c8fd0338709
  • Pointer size: 132 Bytes
  • Size of remote file: 1.18 MB
__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .RadEval import RadEval
2
+ from .utils import compare_systems
factual/RaTEScore/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .score import *
2
+ from .scorer import RaTEScore
factual/RaTEScore/score.py ADDED
@@ -0,0 +1,83 @@
1
+ import torch
2
+ import medspacy
3
+ nlp = medspacy.load(medspacy_enable=["medspacy_pyrush", "medspacy_context"])
4
+
5
+ from .utils import sentence_split, post_process
6
+
7
+ def run_ner(texts, idx2label, tokenizer, model, device, batch_size):
8
+
9
+ clean_text_list, is_start_list = sentence_split(texts)
10
+
11
+ predicted_labels = []
12
+
13
+ for i in range(0, len(clean_text_list), batch_size):
14
+ batch_text = clean_text_list[i:i+batch_size]
15
+
16
+ inputs = tokenizer(batch_text,
17
+ max_length=512,
18
+ padding=True,
19
+ truncation=True,
20
+ return_tensors="pt").to(device)
21
+
22
+ with torch.no_grad():
23
+ outputs = model(**inputs)
24
+
25
+ predicted_labels.extend(torch.argmax(outputs.logits, dim=2).tolist())
26
+
27
+ inputs = tokenizer(clean_text_list,
28
+ max_length=512,
29
+ padding=True,
30
+ truncation=True,
31
+ return_tensors="pt")
32
+
33
+ save_pairs = []
34
+
35
+ pad_token_id = tokenizer.pad_token_id
36
+
37
+ for i, is_start in enumerate(is_start_list):
38
+
39
+ predicted_entities = [idx2label[label] for label in predicted_labels[i]]
40
+
41
+ non_pad_mask = inputs["input_ids"][i] != pad_token_id
42
+ non_pad_length = non_pad_mask.sum().item()
43
+ non_pad_input_ids = inputs["input_ids"][i][:non_pad_length]
44
+
45
+ tokenized_text = tokenizer.convert_ids_to_tokens(non_pad_input_ids)
46
+
47
+ if is_start:
48
+ save_pair = post_process(tokenized_text, predicted_entities, tokenizer)
49
+ else:
50
+ save_pair = post_process(tokenized_text, predicted_entities, tokenizer)
51
+ save_pairs[-1].extend(save_pair)
52
+ continue
53
+
54
+ save_pairs.append(save_pair)
55
+
56
+ return save_pairs
57
+
58
+
59
+ def process_embedding(pair, eval_tokenizer, eval_model, device):
60
+ entities = [pair[0] for pair in pair]
61
+ types = [pair[1] for pair in pair]
62
+
63
+ if len(entities) == 0:
64
+ embeds_word = torch.tensor([])
65
+ else:
66
+ embeds_word = torch.tensor([]).to(device)
67
+
68
+ with torch.no_grad():
69
+ # tokenize the queries
70
+ encoded = eval_tokenizer(
71
+ entities,
72
+ truncation=True,
73
+ padding=True,
74
+ return_tensors='pt',
75
+ max_length=30,
76
+ ).to(device)
77
+
78
+ # encode the queries (use the [CLS] last hidden states as the representations)
79
+ embeds_word = torch.cat((embeds_word.to('cpu'),
80
+ eval_model(**encoded).last_hidden_state[:, 0, :].to('cpu')), dim=0)
81
+
82
+ return embeds_word, types
83
+
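For orientation, a sketch of the shape run_ner returns (the entity string and type label below are illustrative; the tokenizer/model wiring lives in scorer.py):

# run_ner(["Left pleural effusion."], idx2label, tokenizer, model, device, batch_size=1)
#   -> [[("left pleural effusion", "ABNORMALITY"), ...]]   # one list of (entity, type) pairs per report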
factual/RaTEScore/scorer.py ADDED
@@ -0,0 +1,146 @@
1
+ import torch
2
+ import json
3
+ import numpy as np
4
+ from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoModelForTokenClassification
5
+ import pandas as pd
6
+ import os
7
+
8
+ from .score import run_ner, process_embedding
9
+ from .utils import compute
10
+
11
+
12
+ DEFAULT_MATRIX_LONG = {"abnormality_abnormality": 0.4276119164393705, "abnormality_anatomy": 0.6240929990607657, "abnormality_disease": 0.0034478181112993847, "abnormality_non-abnormality": 0.5431049700217344, "abnormality_non-disease": 0.27005425386213877, "anatomy_abnormality": 0.7487824274337533, "anatomy_anatomy": 0.2856134859160784, "anatomy_disease": 0.4592143222158069, "anatomy_non-abnormality": 0.02097055139911715, "anatomy_non-disease": 0.00013736314126696204, "disease_abnormality": 0.8396510075734789, "disease_anatomy": 0.9950209388542061, "disease_disease": 0.8460555030578727, "disease_non-abnormality": 0.9820689020512646, "disease_non-disease": 0.3789136708096537, "non-abnormality_abnormality": 0.16546764653692908, "non-abnormality_anatomy": 0.018670610691852826, "non-abnormality_disease": 0.719397354576018, "non-abnormality_non-abnormality": 0.0009357166071730684, "non-abnormality_non-disease": 0.0927333564267591, "non-disease_abnormality": 0.7759420231214385, "non-disease_anatomy": 0.1839139293714062, "non-disease_disease": 0.10073046076318157, "non-disease_non-abnormality": 0.03860183811876373, "non-disease_non-disease": 0.34065681486566446, "neg_weight":0.8716553966489615}
13
+ DEFAULT_MATRIX_SHORT = {"abnormality_abnormality": 0.4070293318365468, "abnormality_anatomy": 0.6952639610605605, "abnormality_disease": 0.28342529466226446, "abnormality_non-abnormality": 0.9479148658006686, "abnormality_non-disease": 0.23875064111146294, "anatomy_abnormality": 0.5829759950441763, "anatomy_anatomy": 0.7709590751917746, "anatomy_disease": 0.0006059634829551632, "anatomy_non-abnormality": 0.794672584951181, "anatomy_non-disease": 0.27982942400798977, "disease_abnormality": 0.8840397619834857, "disease_anatomy": 0.9637659445696822, "disease_disease": 0.19018958438059513, "disease_non-abnormality": 0.6962283914800402, "disease_non-disease": 0.943727057946997, "non-abnormality_abnormality": 0.1712744286898638, "non-abnormality_anatomy": 0.4485149671497294, "non-abnormality_disease": 0.00045065329822896076, "non-abnormality_non-abnormality": 0.0007887930317199857, "non-abnormality_non-disease": 0.8555432840895761, "non-disease_abnormality": 0.9555801066212176, "non-disease_anatomy": 0.13122106162635216, "non-disease_disease": 0.6072996585919443, "non-disease_non-abnormality": 0.05650711141169969, "non-disease_non-disease": 0.3214769399791204, "neg_weight":0.3611577852354489}
14
+
15
+
16
+ class RaTEScore:
17
+ def __init__(self,
18
+ bert_model="Angelakeke/RaTE-NER-Deberta",
19
+ eval_model='FremyCompany/BioLORD-2023-C',
20
+ batch_size=1,
21
+ use_gpu=True,
22
+ visualization_path=None,
23
+ affinity_matrix="long",
24
+ ):
25
+ """ RaTEScore is a novel, entity-aware metric to assess the quality of medical reports generated by AI models.
26
+ It emphasizes crucial medical entities such as diagnostic outcomes and anatomical details, and is robust
27
+ against complex medical synonyms and sensitive to negation expressions. The evaluations demonstrate that
28
+ RaTEScore aligns more closely with human preference than existing metrics.
29
+
30
+ Args:
31
+ bert_model (str, optional): Medical entity recognition module. Defaults to "Angelakeke/RaTE-NER-Deberta".
32
+ eval_model (str, optional): Synonym disambiguation encoding module. Defaults to 'FremyCompany/BioLORD-2023-C'.
33
+ batch_size (int, optional): Batch size to choose. Defaults to 1.
34
+ use_gpu (bool, optional): Whether to run on GPU. Defaults to True.
35
+ visualization_path (str, optional): Path for saving the visualized entity pairs as a JSON file. Defaults to None.
36
+ affinity_matrix (str, optional): Pre-searched type weight matrix; can be adjusted to reflect human rating bias.
37
+ Defaults to 'long'.
38
+
39
+ """
40
+
41
+ # if use_gpu
42
+ if use_gpu:
43
+ self.device = torch.device('cuda')
44
+ else:
45
+ self.device = torch.device('cpu')
46
+
47
+ # load the Medical entity recognition module
48
+ self.tokenizer = AutoTokenizer.from_pretrained(bert_model)
49
+ self.model = AutoModelForTokenClassification.from_pretrained(bert_model).eval().to(self.device)
50
+
51
+ # load the Synonym disambiguation module
52
+ self.eval_tokenizer = AutoTokenizer.from_pretrained(eval_model)
53
+ self.eval_model = AutoModel.from_pretrained(eval_model).eval().to(self.device)
54
+
55
+ # load the weight matrix
56
+ if isinstance(affinity_matrix, str):
57
+ # Choose the appropriate matrix based on the argument
58
+ if affinity_matrix.lower() == "long":
59
+ self.matrix_path = DEFAULT_MATRIX_LONG
60
+ elif affinity_matrix.lower() == "short":
61
+ self.matrix_path = DEFAULT_MATRIX_SHORT
62
+ else:
63
+ # Assume it's a file path
64
+ try:
65
+ with open(affinity_matrix, 'r') as f:
66
+ self.matrix_path = json.load(f)
67
+ except Exception as e:
68
+ raise ValueError(f"Failed to load affinity matrix from {affinity_matrix}: {e}")
69
+ else:
70
+ raise ValueError("affinity_matrix must be a string")
71
+
72
+ self.affinity_matrix = {(k.split('_')[0].upper(), k.split('_')[1].upper()):v for k,v in self.matrix_path.items()}
73
+
74
+ # load the label file
75
+ self.config = AutoConfig.from_pretrained(bert_model)
76
+ self.label2idx = self.config.label2id
77
+ self.idx2label = self.config.id2label
78
+
79
+ # save the input
80
+ self.batch_size = batch_size
81
+
82
+ if visualization_path:
83
+ self.visualization_path = visualization_path
84
+ if not os.path.exists(os.path.dirname(visualization_path)):
85
+ os.makedirs(os.path.dirname(visualization_path))
86
+ else:
87
+ self.visualization_path = None
88
+
89
+
90
+ def compute_score(self, candidate_list, reference_list):
91
+ '''Compute the RaTEScore for the candidate and reference reports.
92
+
93
+ Args:
94
+ candidate_list (list): list of candidate reports
95
+ reference_list (list): list of reference reports
96
+ '''
97
+
98
+ # check if candidate and reference are list
99
+ if not isinstance(candidate_list, list):
100
+ raise ValueError("candidate must be a list")
101
+ if not isinstance(reference_list, list):
102
+ raise ValueError("reference must be a list")
103
+
104
+ assert len(candidate_list) == len(reference_list), "candidate and reference must have the same length"
105
+
106
+ # check if candidate and reference are list of strings
107
+ if not all(isinstance(x, str) for x in candidate_list):
108
+ raise ValueError("candidate must be a list of strings")
109
+
110
+ gt_pairs = run_ner(reference_list, self.idx2label, self.tokenizer, self.model, self.device, self.batch_size)
111
+ pred_pairs = run_ner(candidate_list, self.idx2label, self.tokenizer, self.model, self.device, self.batch_size)
112
+
113
+ rate_score = []
114
+
115
+ for gt_pair, pred_pair in zip(gt_pairs, pred_pairs):
116
+
117
+ # process the embedding for gt
118
+ gt_embeds_word, gt_types = process_embedding(gt_pair, self.eval_tokenizer, self.eval_model, self.device)
119
+
120
+ # process the embedding for pred
121
+ pred_embeds_word, pred_types = process_embedding(pred_pair, self.eval_tokenizer, self.eval_model, self.device)
122
+
123
+ # compute the score, if the length of gt or pred is 0, the score is 0.5
124
+ if len(gt_embeds_word) == 0 or len(pred_embeds_word) == 0:
125
+ rate_score.append(0.5)
126
+ continue
127
+
128
+ precision_score = compute(gt_embeds_word, pred_embeds_word, gt_types, pred_types, self.affinity_matrix)
129
+ recall_score = compute(pred_embeds_word, gt_embeds_word, pred_types, gt_types, self.affinity_matrix)
130
+
131
+ if precision_score + recall_score == 0:
132
+ rate_score.append(0)
133
+ else:
134
+ rate_score.append(2*precision_score*recall_score/(precision_score+recall_score))
135
+
136
+ if self.visualization_path:
137
+ save_file = pd.DataFrame({
138
+ 'candidate': candidate_list,
139
+ 'reference': reference_list,
140
+ 'candidate_entities': pred_pairs,
141
+ 'reference_entities': gt_pairs,
142
+ 'rate_score': rate_score
143
+ })
144
+ save_file.to_json(os.path.join(self.visualization_path, 'rate_score.json'), lines=True, orient='records')
145
+
146
+ return rate_score, pred_pairs ,gt_pairs
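A minimal usage sketch for the scorer above (module path as laid out in this commit; the report strings are illustrative):

from factual.RaTEScore import RaTEScore

scorer = RaTEScore(batch_size=4, use_gpu=False)
scores, pred_pairs, gt_pairs = scorer.compute_score(
    candidate_list=["No pneumothorax detected."],
    reference_list=["No evidence of pneumothorax following chest tube removal."],
)
print(scores)  # one RaTEScore value per report pair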
factual/RaTEScore/utils.py ADDED
@@ -0,0 +1,143 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import medspacy
4
+ nlp = medspacy.load(medspacy_enable=["medspacy_pyrush", "medspacy_context"])
5
+
6
+ def sentence_split(text_list):
7
+ """
8
+ split sentences by medspacy
9
+ """
10
+ clean_text_list = []
11
+ is_start_list = []
12
+
13
+ for text in text_list:
14
+
15
+ doc = nlp(text)
16
+
17
+ is_start = 1
18
+
19
+ for sent in doc.sents:
20
+ sent = str(sent).strip()
21
+ # # check if the sentence has no words
22
+ if len(sent.split()) == 0:
23
+ continue
24
+ if len(sent) < 3:
25
+ continue
26
+ is_start_list.append(is_start)
27
+ clean_text_list.append(sent)
28
+ is_start = 0
29
+
30
+ return clean_text_list, is_start_list
31
+
32
+ def post_process(tokenized_text, predicted_entities, tokenizer):
33
+ entity_spans = []
34
+ start = end = None
35
+ entity_type = None
36
+
37
+ for i, (token, label) in enumerate(zip(tokenized_text, predicted_entities[:len(tokenized_text)])):
38
+ if token in ["[CLS]", "[SEP]"]:
39
+ continue
40
+ if label != "O" and i < len(predicted_entities) - 1:
41
+ if label.startswith("B-") and predicted_entities[i+1].startswith("I-"):
42
+ start = i
43
+ entity_type = label[2:]
44
+ elif label.startswith("B-") and predicted_entities[i+1].startswith("B-"):
45
+ start = i
46
+ end = i
47
+ entity_spans.append((start, end, label[2:]))
48
+ start = i
49
+ entity_type = label[2:]
50
+ elif label.startswith("B-") and predicted_entities[i+1].startswith("O"):
51
+ start = i
52
+ end = i
53
+ entity_spans.append((start, end, label[2:]))
54
+ start = end = None
55
+ entity_type = None
56
+ elif label.startswith("I-") and predicted_entities[i+1].startswith("B-"):
57
+ end = i
58
+ if start is not None:
59
+ entity_spans.append((start, end, entity_type))
60
+ start = i
61
+ entity_type = label[2:]
62
+ elif label.startswith("I-") and predicted_entities[i+1].startswith("O"):
63
+ end = i
64
+ if start is not None:
65
+ entity_spans.append((start, end, entity_type))
66
+ start = end = None
67
+ entity_type = None
68
+
69
+ # Handle the last entity if a span is still open at the end
70
+ if start is not None and end is None:
71
+ end = len(tokenized_text) - 2
72
+ entity_spans.append((start, end, entity_type))
73
+
74
+ # Collect the extracted (entity, type) pairs
75
+ save_pair = []
76
+ for start, end, entity_type in entity_spans:
77
+ entity_str = tokenizer.convert_tokens_to_string(tokenized_text[start:end+1])
78
+ # print(f"Entity: {entity_str}, Type: {entity_type}")
79
+ save_pair.append((entity_str, entity_type))
80
+
81
+ return save_pair
82
+
83
+
84
+ def topk_similarity(embeddings1, embeddings2, k=1):
85
+ """
86
+ Compute the top-k similarity between two sets of embeddings using PyTorch.
87
+ """
88
+
89
+ ### Normalize the embeddings to use cosine similarity
90
+ embeddings1 = F.normalize(embeddings1, p=2, dim=1)
91
+ embeddings2 = F.normalize(embeddings2, p=2, dim=1)
92
+
93
+ topk_values = []
94
+ topk_indices = []
95
+
96
+ ### Iterate over each embedding in the first set
97
+ for emb1 in embeddings1:
98
+
99
+ ### Calculate cosine similarity between this embedding and all embeddings in the second set
100
+ similarities = torch.matmul(embeddings2, emb1)
101
+
102
+ ### Find the top-k highest similarity values
103
+ values, indices = torch.topk(similarities, k, largest=True)
104
+
105
+ topk_values.append(values[0])
106
+ topk_indices.append(indices[0])
107
+
108
+ return topk_indices, topk_values
109
+
110
+ def compute(gt_embeds_word, pred_embeds_word, gt_types, pred_types, weight_matrix):
111
+ neg_class = [('NON-DISEASE', 'DISEASE'),
112
+ ('NON-ABNORMALITY', 'ABNORMALITY'),
113
+ ('DISEASE', 'NON-DISEASE'),
114
+ ('ABNORMALITY', 'NON-ABNORMALITY'),
115
+ ('NON-DISEASE', 'ABNORMALITY'),
116
+ ('NON-ABNORMALITY', 'DISEASE'),
117
+ ('DISEASE', 'NON-ABNORMALITY'),
118
+ ('ABNORMALITY', 'NON-DISEASE'),]
119
+ neg_weight = weight_matrix[("NEG", "WEIGHT")]
120
+ topk_indices, topk_values = topk_similarity(gt_embeds_word, pred_embeds_word, k=1)
121
+
122
+
123
+ for i in range(len(topk_indices)):
124
+ topk_indices[i] = topk_indices[i].cpu().numpy().tolist()
125
+ topk_values[i] = topk_values[i].cpu().numpy().tolist()
126
+
127
+ # map the indices to type
128
+ topk_map = [pred_types[i] for i in topk_indices]
129
+
130
+ weight_score = [weight_matrix[(gt_type, pred_type)] for gt_type, pred_type in zip(gt_types, topk_map)]
131
+ type_score = [neg_weight if (gt_type, pred_type) in neg_class else 1 for gt_type, pred_type in zip(gt_types, topk_map)]
132
+
133
+ weighted_avg_score = 0
134
+ weighted_sum = 0
135
+ for score, weight, type in zip(topk_values, weight_score, type_score):
136
+ weighted_avg_score += score*weight*type
137
+ weighted_sum += weight
138
+ if weighted_sum != 0:
139
+ RaTE = weighted_avg_score/weighted_sum
140
+ else:
141
+ RaTE = 0
142
+
143
+ return RaTE
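A toy illustration of the top-1 cosine matching used by compute (the vectors are made up; topk_similarity is defined above):

import torch
from factual.RaTEScore.utils import topk_similarity

a = torch.tensor([[1.0, 0.0], [0.0, 1.0]])
b = torch.tensor([[0.8, 0.6], [0.0, 1.0]])
indices, values = topk_similarity(a, b, k=1)
# indices -> nearest row of b for each row of a; values -> its cosine similarity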
factual/RadCliQv1/radcliq.py ADDED
@@ -0,0 +1,213 @@
1
+ import numpy as np
2
+ import torch
3
+ from nlg.bertscore.bertscore import BertScore
4
+ from radgraph import RadGraph
5
+ from factual.f1chexbert import F1CheXbert
6
+ from sklearn.preprocessing import StandardScaler
7
+ from nlg.bleu.bleu import Bleu
8
+
9
+
10
+ def radcliq_bertscore(refs, hyps, model_type='distilroberta-base'):
11
+ """
12
+ Computes BERTScore for each pair of reference and hypothesis.
13
+
14
+ Returns:
15
+ np.ndarray of shape (N,) with the BERTScore F1 values per pair.
16
+ """
17
+ # https://github.com/rajpurkarlab/CXR-Report-Metric/blob/9c9ecad39be6cb2be8e75be1d1c50ef8888a3e40/CXRMetric/run_eval.py#L103
18
+ scorer = BertScore(
19
+ model_type=model_type,
20
+ rescale_with_baseline=True,
21
+ idf=False,
22
+ num_layers=None
23
+ )
24
+ _, scores = scorer(refs, hyps)
25
+ # scores is a list of torch.Tensor, convert to numpy
26
+ return np.array([float(s) for s in scores])
27
+
28
+
29
+ def compute_f1(test_set, retrieved_set):
30
+ """Helper to compute F1 between two sets of items."""
31
+ tp = len(test_set & retrieved_set)
32
+ fp = len(retrieved_set) - tp
33
+ fn = len(test_set) - tp
34
+ precision = tp / (tp + fp) if (tp + fp) else 0.0
35
+ recall = tp / (tp + fn) if (tp + fn) else 0.0
36
+ return 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
37
+
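# Worked example (hypothetical sets): test={"a","b","c"}, retrieved={"b","c","d"}
#   tp=2, fp=1, fn=1 -> precision=2/3, recall=2/3 -> F1=2/3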
38
+
39
+ def extract_entities(output):
40
+ """Extracts set of (tokens, label) tuples from RadGraph output."""
41
+ return {(tuple(ent["tokens"]), ent["label"]) for ent in output.get("entities", {}).values()}
42
+
43
+
44
+ def extract_relations(output):
45
+ """Extracts set of (src, tgt, relation) tuples from RadGraph output."""
46
+ rels = set()
47
+ entities = output.get("entities", {})
48
+ for ent in entities.values():
49
+ src = (tuple(ent["tokens"]), ent["label"])
50
+ for rel_type, tgt_idx in ent.get("relations", []):
51
+ tgt_ent = entities.get(tgt_idx)
52
+ if tgt_ent:
53
+ tgt = (tuple(tgt_ent["tokens"]), tgt_ent["label"])
54
+ rels.add((src, tgt, rel_type))
55
+ return rels
56
+
57
+
58
+ def radcliq_radgraph_scores(refs, hyps, model_name='radgraph'):
59
+ """
60
+ Computes entity and relation F1 via RadGraph for each report pair and returns their average.
61
+
62
+ Returns:
63
+ np.ndarray of shape (N,) with (entity_f1 + relation_f1)/2 per pair.
64
+ """
65
+ rad = RadGraph(model_type=model_name)
66
+ gt_outputs = rad(refs)
67
+ pred_outputs = rad(hyps)
68
+ scores = []
69
+ for i in range(len(refs)):
70
+ gt_out = gt_outputs.get(str(i), {})
71
+ pred_out = pred_outputs.get(str(i), {})
72
+
73
+ ents_gt = extract_entities(gt_out)
74
+ ents_pred = extract_entities(pred_out)
75
+ rels_gt = extract_relations(gt_out)
76
+ rels_pred = extract_relations(pred_out)
77
+
78
+ ent_f1 = compute_f1(ents_gt, ents_pred)
79
+ rel_f1 = compute_f1(rels_gt, rels_pred)
80
+ scores.append((ent_f1 + rel_f1) / 2)
81
+ return np.array(scores)
82
+
83
+
84
+ def semantic_embedding_scores(refs, hyps, device='cpu'):
85
+ """
86
+ Computes per-pair cosine similarity between embeddings from CheXbert labeler.
87
+
88
+ Returns:
89
+ np.ndarray of shape (N,) with cosine similarities per pair.
90
+ """
91
+ if len(refs) != len(hyps):
92
+ raise ValueError(f"refs ({len(refs)}) and hyps ({len(hyps)}) must be same length")
93
+ labeler = F1CheXbert(device=device)
94
+ gt_embs = np.vstack(labeler.get_embeddings(refs))
95
+ pred_embs = np.vstack(labeler.get_embeddings(hyps))
96
+ # https://github.com/rajpurkarlab/CXR-Report-Metric/blob/9c9ecad39be6cb2be8e75be1d1c50ef8888a3e40/CXRMetric/run_eval.py#L126
97
+ dot = np.einsum("nd,nd->n", gt_embs, pred_embs)
98
+ norms = np.linalg.norm(gt_embs, axis=1) * np.linalg.norm(pred_embs, axis=1)
99
+ with np.errstate(divide='ignore', invalid='ignore'):
100
+ sims = np.where(norms > 0, dot / norms, 0.0)
101
+ return sims
102
+
103
+
104
+ def radcliq_scores(refs, hyps,
105
+ bert_model='distilroberta-base',
106
+ radgraph_model='radgraph'):
107
+ """
108
+ Computes BERTScore, RadGraph score, semantic embedding similarity, and BLEU for each ref-hyp pair.
109
+
110
+ Args:
111
+ refs: List of reference report strings.
112
+ hyps: List of hypothesis report strings.
113
+ (the device for the embedding model is selected automatically: 'cuda' if available, else 'cpu')
114
+ bert_model: HuggingFace model name for BERTScore.
115
+ radgraph_model: Model name for RadGraph inference.
116
+
117
+ Returns:
118
+ Dict with keys 'bertscore', 'radgraph', 'semb_score', 'bleu_score', each mapping to a numpy array of shape (N,).
119
+ """
120
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
121
+ # BERTScore
122
+ bert_scores = radcliq_bertscore(refs, hyps, model_type=bert_model)
123
+ # RadGraph
124
+ rad_scores = radcliq_radgraph_scores(refs, hyps, model_name=radgraph_model)
125
+ # Semantic embeddings
126
+ sem_scores = semantic_embedding_scores(refs, hyps, device=device)
127
+
128
+ # BLEU
129
+ bleu_scorer = Bleu()
130
+ bleu_scores = bleu_scorer(refs, hyps)[1]
131
+
132
+ return {
133
+ 'bertscore': bert_scores,
134
+ 'radgraph': rad_scores,
135
+ 'semb_score': sem_scores,
136
+ 'bleu_score': bleu_scores
137
+ }
138
+
139
+
140
+
141
+ class CompositeMetric:
142
+ def __init__(self):
143
+ scaler = StandardScaler(with_mean=True, with_std=True)
144
+ # learnt parameters, inferred from
145
+ # https://github.com/rajpurkarlab/CXR-Report-Metric/blob/main/CXRMetric/run_eval.py#L219
146
+ scaler.mean_ = np.array([0.53792312, 0.61757256, 0.76479421, 0.44738335])
147
+ scaler.scale_ = np.array([0.30282584, 0.22430938, 0.25394391, 0.29892717])
148
+ scaler.var_ = np.array([0.09170349, 0.05031470, 0.06448751, 0.08935745])
149
+ scaler.n_samples_seen_ = 160 # integer
150
+ scaler.n_features_in_ = 4 # integer
151
+
152
+ self.scaler = scaler
153
+ self.coefs = np.array([
154
+ -3.77083683e-01, # radgraph weight
155
+ -3.70300100e-01, # bertscore weight
156
+ -2.52616218e-01, # s-emb weight
157
+ 4.31504841e-12, # bleu weight
158
+ 2.46655256e-10 # intercept / bias
159
+ ])
160
+ self.cols = ["radgraph", "bertscore", "semb_score", "bleu_score"]
161
+
162
+ def predict(self, X):
163
+ Xn = self.scaler.transform(X)
164
+ Xn = np.hstack([Xn, np.ones((Xn.shape[0], 1))])
165
+ return Xn @ self.coefs
166
+
167
+ def _build_matrix(self, metrics: dict[str, np.ndarray]) -> np.ndarray:
168
+ """Stack features in the canonical column order."""
169
+ return np.column_stack([metrics[c] for c in self.cols])
170
+
171
+ def predict(self, refs, hyps) -> np.ndarray:
172
+ """
173
+ Args
174
+ ----
175
+ metrics : dict returned by `radcliq_scores`
176
+
177
+ Returns
178
+ -------
179
+ np.ndarray of shape (N,) – RadCliQ-v1 score for each ref/hyp pair.
180
+ """
181
+ metrics = radcliq_scores(refs, hyps)
182
+
183
+ X = self._build_matrix(metrics)
184
+
185
+ Xn = self.scaler.transform(X)
186
+
187
+ # Append bias term
188
+ Xn = np.hstack([Xn, np.ones((Xn.shape[0], 1))])
189
+ scores = Xn @ self.coefs
190
+
191
+ return 1/scores.mean(), scores
192
+
193
+ if __name__ == "__main__":
194
+ refs = [
195
+ "No evidence of pneumothorax following chest tube removal.",
196
+ "There is a left pleural effusion.",
197
+ "There is a left pleural effusion."
198
+ ]
199
+ hyps = [
200
+ "No pneumothorax detected.",
201
+ "Left pleural effusion is present.",
202
+ "No pneumothorax detected.",
203
+ ]
204
+
205
+ # The four underlying metrics are computed inside CompositeMetric.predict
206
+
207
+ # Get the RadCliQ-v1 composite
208
+ radcliq = CompositeMetric()
209
+ mean_scores, detail_scores = radcliq.predict(refs, hyps)
210
+ for i, s in enumerate(detail_scores, 1):
211
+ print(f"Pair {i}: RadCliQ-v1 = {s:.4f}")
212
+
213
+ print(f"RadCliQ-v1 score: {mean_scores:.4f}")
factual/RadCliQv1/radcliq_bertscore.py ADDED
@@ -0,0 +1,10 @@
1
+ from nlg.bertscore.bertscore import BertScore
2
+
3
+ def radcliq_bertscore(refs, hyps):
4
+ bertscore_scorer = BertScore(model_type='distilroberta-base',
5
+ rescale_with_baseline=True,
6
+ idf=False,
7
+ num_layers=None)
8
+ print(bertscore_scorer)
9
+ avg, scores = bertscore_scorer(refs, hyps)
10
+ return scores
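A one-line usage sketch for the helper above (the report strings are illustrative):

scores = radcliq_bertscore(refs=["No evidence of pneumothorax."], hyps=["No pneumothorax detected."])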
factual/RadCliQv1/radcliq_radgraph.py ADDED
@@ -0,0 +1,80 @@
1
+ import json
2
+ from radgraph import RadGraph
3
+
4
+
5
+ def compute_f1(test, retrieved):
6
+ """Computes F1 between test/retrieved report's entities or relations."""
7
+ tp = len(test & retrieved)
8
+ fp = len(retrieved) - tp
9
+ fn = len(test) - tp
10
+ precision = tp / (tp + fp) if (tp + fp) else 0
11
+ recall = tp / (tp + fn) if (tp + fn) else 0
12
+ return 2 * precision * recall / (precision + recall) if (precision + recall) else 0
13
+
14
+
15
+ def extract_entities(output):
16
+ """Extracts set of (tokens, label) from a RadGraph output dict."""
17
+ return {(tuple(ent["tokens"]), ent["label"]) for ent in output.get("entities", {}).values()}
18
+
19
+
20
+ def extract_relations(output):
21
+ """Extracts set of (src, tgt, relation) from a RadGraph output dict."""
22
+ rels = set()
23
+ entities = output.get("entities", {})
24
+ for ent in entities.values():
25
+ src = (tuple(ent["tokens"]), ent["label"])
26
+ for rel_type, tgt_idx in ent.get("relations", []):
27
+ tgt_ent = entities.get(tgt_idx)
28
+ if tgt_ent:
29
+ tgt = (tuple(tgt_ent["tokens"]), tgt_ent["label"])
30
+ rels.add((src, tgt, rel_type))
31
+ return rels
32
+
33
+
34
+ def compute_radgraph_scores(refs, hyps, model_name='radgraph'):
35
+ """
36
+ Computes combined RadGraph F1 scores for each pair of reference and hypothesis reports.
37
+ Returns:
38
+ List of floats: (entity_f1 + relation_f1)/2 per report.
39
+ """
40
+ # Initialize RadGraph model
41
+ rad = RadGraph(model_type=model_name)
42
+
43
+ # Perform inference
44
+ gt_outputs = rad(refs)
45
+ pred_outputs = rad(hyps)
46
+
47
+ scores = []
48
+ for i in range(len(gt_outputs)):
49
+ gt_out = gt_outputs[str(i)]
50
+ pred_out = pred_outputs[str(i)]
51
+
52
+ gt_ents = extract_entities(gt_out)
53
+ pred_ents = extract_entities(pred_out)
54
+ gt_rels = extract_relations(gt_out)
55
+ pred_rels = extract_relations(pred_out)
56
+
57
+ ent_f1 = compute_f1(gt_ents, pred_ents)
58
+ rel_f1 = compute_f1(gt_rels, pred_rels)
59
+ scores.append((ent_f1 + rel_f1) / 2)
60
+
61
+ return scores
62
+
63
+
64
+ if __name__ == '__main__':
65
+ # Example usage
66
+ refs = [
67
+ "No evidence of pneumothorax following chest tube removal.",
68
+ "There is a left pleural effusion."
69
+ ]
70
+ hyps = [
71
+ "No pneumothorax detected.",
72
+ "Left pleural effusion is present."
73
+ ]
74
+
75
+ combined_scores = compute_radgraph_scores(refs, hyps)
76
+ print(combined_scores) # e.g., [1.0, 1.0]
77
+ from radgraph import F1RadGraph
78
+ f1_radgraph = F1RadGraph(model_type="radgraph", reward_level="simple")
79
+ f1_scores = f1_radgraph(refs, hyps,)
80
+ print(f1_scores)
factual/RadCliQv1/semb_score.py ADDED
@@ -0,0 +1,74 @@
1
+ from typing import Sequence, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from factual.f1chexbert import F1CheXbert
7
+
8
+
9
+ def semantic_embedding_scores(
10
+ refs: Sequence[str],
11
+ hyps: Sequence[str],
12
+ *,
13
+ device: Union[str, torch.device] = "cpu",
14
+ ) -> np.ndarray:
15
+ """Return per‑pair cosine similarities between `refs` and `hyps`.
16
+
17
+ All heavy math is vectorised; no Python loops.
18
+
19
+ Args:
20
+ refs: Iterable of ground‑truth report strings.
21
+ hyps: Iterable of predicted report strings (must match `refs` length).
22
+ device: Computation device (e.g. "cpu", "cuda", "cuda:0").
23
+
24
+ Returns
25
+ -------
26
+ np.ndarray
27
+ Shape ``(N,)`` – cosine similarity for each pair, where
28
+ ``N == len(refs) == len(hyps)``.
29
+
30
+ Raises
31
+ ------
32
+ ValueError
33
+ If `refs` and `hyps` are of different lengths.
34
+ """
35
+
36
+ if len(refs) != len(hyps):
37
+ raise ValueError(f"refs ({len(refs)}) and hyps ({len(hyps)}) differ in length")
38
+
39
+ labeler = F1CheXbert(device=device)
40
+
41
+ # Stack embeddings into (N, dim) matrices
42
+ gt_embeds = np.vstack(labeler.get_embeddings(refs)) # (N, dim)
43
+ pred_embeds = np.vstack(labeler.get_embeddings(hyps)) # (N, dim)
44
+
45
+ # Cosine similarity – fully vectorised
46
+ dot = np.einsum("nd,nd->n", gt_embeds, pred_embeds)
47
+ norms = np.linalg.norm(gt_embeds, axis=1) * np.linalg.norm(pred_embeds, axis=1)
48
+ with np.errstate(divide="ignore", invalid="ignore"):
49
+ sims = np.where(norms > 0, dot / norms, 0.0)
50
+
51
+ return sims
52
+
53
+
54
+ def mean_semantic_score(scores: np.ndarray) -> float:
55
+ """Convenience helper: mean of an array of scores."""
56
+ return float(scores.mean())
57
+
58
+
59
+ if __name__ == "__main__":
60
+ _refs = [
61
+ "No evidence of pneumothorax following chest tube removal.",
62
+ "There is a left pleural effusion.",
63
+ "No evidence of pneumothorax following chest tube removal.",
64
+
65
+ ]
66
+ _hyps = [
67
+ "No pneumothorax detected.",
68
+ "Left pleural effusion is present.",
69
+ "Left pleural effusion is present.",
70
+ ]
71
+
72
+ _scores = semantic_embedding_scores(_refs, _hyps, device="cpu")
73
+ print("Per‑pair cosine:", _scores)
74
+ print("Mean:", mean_semantic_score(_scores))
factual/SRRBert/leaves_mapping.json ADDED
@@ -0,0 +1,58 @@
1
+ {
2
+ "No Finding": 0,
3
+ "Lung Lesion": 1,
4
+ "Edema": 2,
5
+ "Pneumonia": 3,
6
+ "Atelectasis": 4,
7
+ "Aspiration": 5,
8
+ "Lung collapse": 6,
9
+ "Perihilar airspace opacity": 7,
10
+ "Air space opacity\u2013multifocal": 8,
11
+ "Mass/Solitary lung mass": 9,
12
+ "Nodule/Solitary lung nodule": 10,
13
+ "Cavitating mass with content": 11,
14
+ "Cavitating masses": 12,
15
+ "Emphysema": 13,
16
+ "Fibrosis": 14,
17
+ "Pulmonary congestion": 15,
18
+ "Hilar lymphadenopathy": 16,
19
+ "Bronchiectasis": 17,
20
+ "Simple pneumothorax": 18,
21
+ "Loculated pneumothorax": 19,
22
+ "Tension pneumothorax": 20,
23
+ "Simple pleural effusion": 21,
24
+ "Loculated pleural effusion": 22,
25
+ "Pleural scarring": 23,
26
+ "Hydropneumothorax": 24,
27
+ "Pleural Other": 25,
28
+ "Cardiomegaly": 26,
29
+ "Pericardial effusion": 27,
30
+ "Inferior mediastinal mass": 28,
31
+ "Superior mediastinal mass": 29,
32
+ "Tortuous Aorta": 30,
33
+ "Calcification of the Aorta": 31,
34
+ "Enlarged pulmonary artery": 32,
35
+ "Hernia": 33,
36
+ "Pneumomediastinum": 34,
37
+ "Tracheal deviation": 35,
38
+ "Acute humerus fracture": 36,
39
+ "Acute rib fracture": 37,
40
+ "Acute clavicle fracture": 38,
41
+ "Acute scapula fracture": 39,
42
+ "Compression fracture": 40,
43
+ "Shoulder dislocation": 41,
44
+ "Subcutaneous Emphysema": 42,
45
+ "Suboptimal central line": 43,
46
+ "Suboptimal endotracheal tube": 44,
47
+ "Suboptimal nasogastric tube": 45,
48
+ "Suboptimal pulmonary arterial catheter": 46,
49
+ "Pleural tube": 47,
50
+ "PICC line": 48,
51
+ "Port catheter": 49,
52
+ "Pacemaker": 50,
53
+ "Implantable defibrillator": 51,
54
+ "LVAD": 52,
55
+ "Intraaortic balloon pump": 53,
56
+ "Pneumoperitoneum": 54
57
+ }
58
+
factual/SRRBert/leaves_with_statuses_mapping.json ADDED
@@ -0,0 +1,165 @@
1
+ {
2
+ "Lung Lesion (Present)": 0,
3
+ "Edema (Present)": 1,
4
+ "Pneumonia (Present)": 2,
5
+ "Atelectasis (Present)": 3,
6
+ "Aspiration (Present)": 4,
7
+ "Lung collapse (Present)": 5,
8
+ "Perihilar airspace opacity (Present)": 6,
9
+ "Air space opacity\u2013multifocal (Present)": 7,
10
+ "Mass/Solitary lung mass (Present)": 8,
11
+ "Nodule/Solitary lung nodule (Present)": 9,
12
+ "Cavitating mass with content (Present)": 10,
13
+ "Cavitating masses (Present)": 11,
14
+ "Emphysema (Present)": 12,
15
+ "Fibrosis (Present)": 13,
16
+ "Pulmonary congestion (Present)": 14,
17
+ "Hilar lymphadenopathy (Present)": 15,
18
+ "Bronchiectasis (Present)": 16,
19
+ "Simple pneumothorax (Present)": 17,
20
+ "Loculated pneumothorax (Present)": 18,
21
+ "Tension pneumothorax (Present)": 19,
22
+ "Simple pleural effusion (Present)": 20,
23
+ "Loculated pleural effusion (Present)": 21,
24
+ "Pleural scarring (Present)": 22,
25
+ "Hydropneumothorax (Present)": 23,
26
+ "Pleural Other (Present)": 24,
27
+ "Cardiomegaly (Present)": 25,
28
+ "Pericardial effusion (Present)": 26,
29
+ "Inferior mediastinal mass (Present)": 27,
30
+ "Superior mediastinal mass (Present)": 28,
31
+ "Tortuous Aorta (Present)": 29,
32
+ "Calcification of the Aorta (Present)": 30,
33
+ "Enlarged pulmonary artery (Present)": 31,
34
+ "Hernia (Present)": 32,
35
+ "Pneumomediastinum (Present)": 33,
36
+ "Tracheal deviation (Present)": 34,
37
+ "Acute humerus fracture (Present)": 35,
38
+ "Acute rib fracture (Present)": 36,
39
+ "Acute clavicle fracture (Present)": 37,
40
+ "Acute scapula fracture (Present)": 38,
41
+ "Compression fracture (Present)": 39,
42
+ "Shoulder dislocation (Present)": 40,
43
+ "Subcutaneous Emphysema (Present)": 41,
44
+ "Suboptimal central line (Present)": 42,
45
+ "Suboptimal endotracheal tube (Present)": 43,
46
+ "Suboptimal nasogastric tube (Present)": 44,
47
+ "Suboptimal pulmonary arterial catheter (Present)": 45,
48
+ "Pleural tube (Present)": 46,
49
+ "PICC line (Present)": 47,
50
+ "Port catheter (Present)": 48,
51
+ "Pacemaker (Present)": 49,
52
+ "Implantable defibrillator (Present)": 50,
53
+ "LVAD (Present)": 51,
54
+ "Intraaortic balloon pump (Present)": 52,
55
+ "Pneumoperitoneum (Present)": 53,
56
+ "Lung Lesion (Uncertain)": 54,
57
+ "Edema (Uncertain)": 55,
58
+ "Pneumonia (Uncertain)": 56,
59
+ "Atelectasis (Uncertain)": 57,
60
+ "Aspiration (Uncertain)": 58,
61
+ "Lung collapse (Uncertain)": 59,
62
+ "Perihilar airspace opacity (Uncertain)": 60,
63
+ "Air space opacity\u2013multifocal (Uncertain)": 61,
64
+ "Mass/Solitary lung mass (Uncertain)": 62,
65
+ "Nodule/Solitary lung nodule (Uncertain)": 63,
66
+ "Cavitating mass with content (Uncertain)": 64,
67
+ "Cavitating masses (Uncertain)": 65,
68
+ "Emphysema (Uncertain)": 66,
69
+ "Fibrosis (Uncertain)": 67,
70
+ "Pulmonary congestion (Uncertain)": 68,
71
+ "Hilar lymphadenopathy (Uncertain)": 69,
72
+ "Bronchiectasis (Uncertain)": 70,
73
+ "Simple pneumothorax (Uncertain)": 71,
74
+ "Loculated pneumothorax (Uncertain)": 72,
75
+ "Tension pneumothorax (Uncertain)": 73,
76
+ "Simple pleural effusion (Uncertain)": 74,
77
+ "Loculated pleural effusion (Uncertain)": 75,
78
+ "Pleural scarring (Uncertain)": 76,
79
+ "Hydropneumothorax (Uncertain)": 77,
80
+ "Pleural Other (Uncertain)": 78,
81
+ "Cardiomegaly (Uncertain)": 79,
82
+ "Pericardial effusion (Uncertain)": 80,
83
+ "Inferior mediastinal mass (Uncertain)": 81,
84
+ "Superior mediastinal mass (Uncertain)": 82,
85
+ "Tortuous Aorta (Uncertain)": 83,
86
+ "Calcification of the Aorta (Uncertain)": 84,
87
+ "Enlarged pulmonary artery (Uncertain)": 85,
88
+ "Hernia (Uncertain)": 86,
89
+ "Pneumomediastinum (Uncertain)": 87,
90
+ "Tracheal deviation (Uncertain)": 88,
91
+ "Acute humerus fracture (Uncertain)": 89,
92
+ "Acute rib fracture (Uncertain)": 90,
93
+ "Acute clavicle fracture (Uncertain)": 91,
94
+ "Acute scapula fracture (Uncertain)": 92,
95
+ "Compression fracture (Uncertain)": 93,
96
+ "Shoulder dislocation (Uncertain)": 94,
97
+ "Subcutaneous Emphysema (Uncertain)": 95,
98
+ "Suboptimal central line (Uncertain)": 96,
99
+ "Suboptimal endotracheal tube (Uncertain)": 97,
100
+ "Suboptimal nasogastric tube (Uncertain)": 98,
101
+ "Suboptimal pulmonary arterial catheter (Uncertain)": 99,
102
+ "Pleural tube (Uncertain)": 100,
103
+ "PICC line (Uncertain)": 101,
104
+ "Port catheter (Uncertain)": 102,
105
+ "Pacemaker (Uncertain)": 103,
106
+ "Implantable defibrillator (Uncertain)": 104,
107
+ "LVAD (Uncertain)": 105,
108
+ "Intraaortic balloon pump (Uncertain)": 106,
109
+ "Pneumoperitoneum (Uncertain)": 107,
110
+ "Lung Lesion (Absent)": 108,
111
+ "Edema (Absent)": 109,
112
+ "Pneumonia (Absent)": 110,
113
+ "Atelectasis (Absent)": 111,
114
+ "Aspiration (Absent)": 112,
115
+ "Lung collapse (Absent)": 113,
116
+ "Perihilar airspace opacity (Absent)": 114,
117
+ "Air space opacity\u2013multifocal (Absent)": 115,
118
+ "Mass/Solitary lung mass (Absent)": 116,
119
+ "Nodule/Solitary lung nodule (Absent)": 117,
120
+ "Cavitating mass with content (Absent)": 118,
121
+ "Cavitating masses (Absent)": 119,
122
+ "Emphysema (Absent)": 120,
123
+ "Fibrosis (Absent)": 121,
124
+ "Pulmonary congestion (Absent)": 122,
125
+ "Hilar lymphadenopathy (Absent)": 123,
126
+ "Bronchiectasis (Absent)": 124,
127
+ "Simple pneumothorax (Absent)": 125,
128
+ "Loculated pneumothorax (Absent)": 126,
129
+ "Tension pneumothorax (Absent)": 127,
130
+ "Simple pleural effusion (Absent)": 128,
131
+ "Loculated pleural effusion (Absent)": 129,
132
+ "Pleural scarring (Absent)": 130,
133
+ "Hydropneumothorax (Absent)": 131,
134
+ "Pleural Other (Absent)": 132,
135
+ "Cardiomegaly (Absent)": 133,
136
+ "Pericardial effusion (Absent)": 134,
137
+ "Inferior mediastinal mass (Absent)": 135,
138
+ "Superior mediastinal mass (Absent)": 136,
139
+ "Tortuous Aorta (Absent)": 137,
140
+ "Calcification of the Aorta (Absent)": 138,
141
+ "Enlarged pulmonary artery (Absent)": 139,
142
+ "Hernia (Absent)": 140,
143
+ "Pneumomediastinum (Absent)": 141,
144
+ "Tracheal deviation (Absent)": 142,
145
+ "Acute humerus fracture (Absent)": 143,
146
+ "Acute rib fracture (Absent)": 144,
147
+ "Acute clavicle fracture (Absent)": 145,
148
+ "Acute scapula fracture (Absent)": 146,
149
+ "Compression fracture (Absent)": 147,
150
+ "Shoulder dislocation (Absent)": 148,
151
+ "Subcutaneous Emphysema (Absent)": 149,
152
+ "Suboptimal central line (Absent)": 150,
153
+ "Suboptimal endotracheal tube (Absent)": 151,
154
+ "Suboptimal nasogastric tube (Absent)": 152,
155
+ "Suboptimal pulmonary arterial catheter (Absent)": 153,
156
+ "Pleural tube (Absent)": 154,
157
+ "PICC line (Absent)": 155,
158
+ "Port catheter (Absent)": 156,
159
+ "Pacemaker (Absent)": 157,
160
+ "Implantable defibrillator (Absent)": 158,
161
+ "LVAD (Absent)": 159,
162
+ "Intraaortic balloon pump (Absent)": 160,
163
+ "Pneumoperitoneum (Absent)": 161,
164
+ "No Finding": 162
165
+ }
factual/SRRBert/srr_bert.py ADDED
@@ -0,0 +1,160 @@
1
+ import os
2
+ import json
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import BertForSequenceClassification, BertTokenizer
7
+ from tqdm import tqdm
8
+ import re
9
+ import nltk
10
+
11
+
12
+ def srr_bert_parse_sentences(text):
13
+ # Handle numbers followed by a dot, not followed by a digit (to avoid decimals like 3.5)
14
+
15
+ # Case 1: Number at beginning of text
16
+ text = re.sub(r'^\s*\d+\.(?!\d)\s*', '', text)
17
+
18
+ # Case 2: Number after a period, like "word.2."
19
+ text = re.sub(r'(\w)\.(\d+)\.(?!\d)\s*', r'\1. ', text)
20
+
21
+ # Case 3: Number attached to a word, like "word2."
22
+ text = re.sub(r'(\w)(\d+)\.(?!\d)\s*', r'\1. ', text)
23
+
24
+ # Case 4: Number after space following a word, like "word 2."
25
+ text = re.sub(r'(\w)\s+\d+\.(?!\d)\s*', r'\1. ', text)
26
+
27
+ # Case 5: Standalone number in the middle, like ". 2. word"
28
+ text = re.sub(r'([.!?])\s*\d+\.(?!\d)\s*', r'\1 ', text)
29
+
30
+ # Add space after periods followed immediately by uppercase letter (new sentence without space)
31
+ text = re.sub(r'\.([A-Z])', r'. \1', text)
32
+
33
+ # Make sure the text ends with a period
34
+ if not text.strip().endswith(('.', '!', '?')):
35
+ text = text.strip() + '.'
36
+
37
+ # Tokenize into sentences
38
+ sentences = nltk.sent_tokenize(text)
39
+
40
+ return sentences
41
+
42
+
43
+ class SRRBert(nn.Module):
44
+ # Supported model types and their configs
45
+ MODEL_CONFIGS = {
46
+ "leaves": {
47
+ "model_path": "StanfordAIMI/SRR-BERT-Leaves",
48
+ "mapping_file": "leaves_mapping.json"
49
+ },
50
+ "upper": {
51
+ "model_path": "StanfordAIMI/SRR-BERT-Upper",
52
+ "mapping_file": "upper_mapping.json"
53
+ },
54
+ "leaves_with_statuses": {
55
+ "model_path": "StanfordAIMI/SRR-BERT-Leaves-with-Statuses",
56
+ "mapping_file": "leaves_with_statuses_mapping.json"
57
+ },
58
+ "upper_with_statuses": {
59
+ "model_path": "StanfordAIMI/SRRG-BERT-Upper-with-Statuses",
60
+ "mapping_file": "upper_with_statuses_mapping.json"
61
+ },
62
+ }
63
+
64
+ def __init__(
65
+ self,
66
+ model_type: str = "leaves",
67
+ batch_size: int = 4,
68
+ tqdm_enable: bool = False
69
+ ):
70
+ super().__init__()
71
+ if model_type not in self.MODEL_CONFIGS:
72
+ raise ValueError(
73
+ f"model_type must be one of {list(self.MODEL_CONFIGS.keys())}"
74
+ )
75
+ config = self.MODEL_CONFIGS[model_type]
76
+
77
+ # Load mapping
78
+ mapping_path = os.path.join(
79
+ os.path.dirname(__file__),
80
+ config["mapping_file"]
81
+ )
82
+ with open(mapping_path, 'r') as f:
83
+ self.mapping = json.load(f)
84
+
85
+ # Device setup
86
+ self.device = torch.device(
87
+ 'cuda' if torch.cuda.is_available() else 'cpu'
88
+ )
89
+
90
+ # Load model
91
+ self.model = BertForSequenceClassification.from_pretrained(
92
+ config["model_path"],
93
+ num_labels=len(self.mapping)
94
+ )
95
+ self.model.to(self.device)
96
+ self.model.eval()
97
+
98
+ # Tokenizer
99
+ self.tokenizer = BertTokenizer.from_pretrained(
100
+ "microsoft/BiomedVLP-CXR-BERT-general"
101
+ )
102
+
103
+ # Settings
104
+ self.batch_size = batch_size
105
+ self.tqdm_enable = tqdm_enable
106
+
107
+ def map_predictions_to_labels(self, outputs):
108
+ inverted_mapping = {v: k for k, v in self.mapping.items()}
109
+ all_labels = []
110
+ for output in outputs:
111
+ labels = [inverted_mapping[i] for i, flag in enumerate(output) if flag == 1]
112
+ all_labels.append(labels)
113
+ return all_labels
114
+
115
+ def forward(self, sentences):
116
+ # Batch sentences
117
+ batches = [
118
+ sentences[i:i + self.batch_size]
119
+ for i in range(0, len(sentences), self.batch_size)
120
+ ]
121
+ outputs = []
122
+ with torch.no_grad():
123
+ for batch in tqdm(
124
+ batches, desc="Predicting", disable=not self.tqdm_enable
125
+ ):
126
+ inputs = self.tokenizer.batch_encode_plus(
127
+ batch,
128
+ add_special_tokens=True,
129
+ max_length=512,
130
+ padding="max_length",
131
+ truncation=True,
132
+ return_attention_mask=True,
133
+ return_tensors="pt",
134
+ )
135
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
136
+ logits = self.model(**inputs).logits
137
+ preds = (torch.sigmoid(logits) > 0.5).cpu().numpy().astype(int)
138
+ outputs.append(preds)
139
+
140
+ outputs = np.concatenate(outputs, axis=0)
141
+ return outputs, self.map_predictions_to_labels(outputs)
142
+
143
+
144
+ if __name__ == "__main__":
145
+ example_sentences = [
146
+ "Layering pleural effusions",
147
+ "Moderate pulmonary edema.",
148
+ "Chronic fracture and dislocation involving the left humeral surgical neck and glenoid.",
149
+ "Stable cardiomegaly.",
150
+ ]
151
+
152
+ # Initialize model (choose one of: leaves, upper, leaves_with_statuses, upper_with_statuses)
153
+ model = SRRBert(
154
+ model_type="leaves",
155
+ batch_size=4,
156
+ tqdm_enable=True
157
+ )
158
+ outputs, labels = model(example_sentences)
159
+ print("Raw outputs:", outputs)
160
+ print("Predicted labels:", labels)
factual/SRRBert/upper_mapping.json ADDED
@@ -0,0 +1,28 @@
1
+ {
2
+ "Pleural Effusion": 0,
3
+ "Upper abdominal finding": 1,
4
+ "Widened cardiac silhouette": 2,
5
+ "Lung Finding": 3,
6
+ "No Finding": 4,
7
+ "Widened aortic contour": 5,
8
+ "Pleural Thickening": 6,
9
+ "Vascular finding": 7,
10
+ "Consolidation": 8,
11
+ "Pneumothorax": 9,
12
+ "Subdiaphragmatic gas": 10,
13
+ "Masslike opacity": 11,
14
+ "Chest wall finding": 12,
15
+ "Focal air space opacity": 13,
16
+ "Segmental collapse": 14,
17
+ "Fracture": 15,
18
+ "Mediastinal mass": 16,
19
+ "Solitary masslike opacity": 17,
20
+ "Support Devices": 18,
21
+ "Mediastinal finding": 19,
22
+ "Pleural finding": 20,
23
+ "Air space opacity": 21,
24
+ "Diffuse air space opacity": 22,
25
+ "Multiple masslike opacities": 23,
26
+ "Musculoskeletal finding": 24
27
+ }
28
+
factual/SRRBert/upper_with_statuses_mapping.json ADDED
@@ -0,0 +1,76 @@
1
+ {
2
+ "Pleural Effusion (Present)": 0,
3
+ "Upper abdominal finding (Present)": 1,
4
+ "Widened cardiac silhouette (Present)": 2,
5
+ "Lung Finding (Present)": 3,
6
+ "Widened aortic contour (Present)": 4,
7
+ "Pleural Thickening (Present)": 5,
8
+ "Vascular finding (Present)": 6,
9
+ "Consolidation (Present)": 7,
10
+ "Pneumothorax (Present)": 8,
11
+ "Subdiaphragmatic gas (Present)": 9,
12
+ "Masslike opacity (Present)": 10,
13
+ "Chest wall finding (Present)": 11,
14
+ "Focal air space opacity (Present)": 12,
15
+ "Segmental collapse (Present)": 13,
16
+ "Fracture (Present)": 14,
17
+ "Mediastinal mass (Present)": 15,
18
+ "Solitary masslike opacity (Present)": 16,
19
+ "Support Devices (Present)": 17,
20
+ "Mediastinal finding (Present)": 18,
21
+ "Pleural finding (Present)": 19,
22
+ "Air space opacity (Present)": 20,
23
+ "Diffuse air space opacity (Present)": 21,
24
+ "Multiple masslike opacities (Present)": 22,
25
+ "Musculoskeletal finding (Present)": 23,
26
+ "Pleural Effusion (Uncertain)": 24,
27
+ "Upper abdominal finding (Uncertain)": 25,
28
+ "Widened cardiac silhouette (Uncertain)": 26,
29
+ "Lung Finding (Uncertain)": 27,
30
+ "Widened aortic contour (Uncertain)": 28,
31
+ "Pleural Thickening (Uncertain)": 29,
32
+ "Vascular finding (Uncertain)": 30,
33
+ "Consolidation (Uncertain)": 31,
34
+ "Pneumothorax (Uncertain)": 32,
35
+ "Subdiaphragmatic gas (Uncertain)": 33,
36
+ "Masslike opacity (Uncertain)": 34,
37
+ "Chest wall finding (Uncertain)": 35,
38
+ "Focal air space opacity (Uncertain)": 36,
39
+ "Segmental collapse (Uncertain)": 37,
40
+ "Fracture (Uncertain)": 38,
41
+ "Mediastinal mass (Uncertain)": 39,
42
+ "Solitary masslike opacity (Uncertain)": 40,
43
+ "Support Devices (Uncertain)": 41,
44
+ "Mediastinal finding (Uncertain)": 42,
45
+ "Pleural finding (Uncertain)": 43,
46
+ "Air space opacity (Uncertain)": 44,
47
+ "Diffuse air space opacity (Uncertain)": 45,
48
+ "Multiple masslike opacities (Uncertain)": 46,
49
+ "Musculoskeletal finding (Uncertain)": 47,
50
+ "Pleural Effusion (Absent)": 48,
51
+ "Upper abdominal finding (Absent)": 49,
52
+ "Widened cardiac silhouette (Absent)": 50,
53
+ "Lung Finding (Absent)": 51,
54
+ "Widened aortic contour (Absent)": 52,
55
+ "Pleural Thickening (Absent)": 53,
56
+ "Vascular finding (Absent)": 54,
57
+ "Consolidation (Absent)": 55,
58
+ "Pneumothorax (Absent)": 56,
59
+ "Subdiaphragmatic gas (Absent)": 57,
60
+ "Masslike opacity (Absent)": 58,
61
+ "Chest wall finding (Absent)": 59,
62
+ "Focal air space opacity (Absent)": 60,
63
+ "Segmental collapse (Absent)": 61,
64
+ "Fracture (Absent)": 62,
65
+ "Mediastinal mass (Absent)": 63,
66
+ "Solitary masslike opacity (Absent)": 64,
67
+ "Support Devices (Absent)": 65,
68
+ "Mediastinal finding (Absent)": 66,
69
+ "Pleural finding (Absent)": 67,
70
+ "Air space opacity (Absent)": 68,
71
+ "Diffuse air space opacity (Absent)": 69,
72
+ "Multiple masslike opacities (Absent)": 70,
73
+ "Musculoskeletal finding (Absent)": 71,
74
+ "No Finding": 72
75
+ }
76
+
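As a quick cross-check of the mapping above: the 24 upper-level findings from upper_mapping.json (its 25 entries minus "No Finding") are each expanded into Present/Uncertain/Absent variants, giving 24 x 3 = 72 labels at indices 0-71, and the status-free "No Finding" entry brings the total to 73 labels at indices 0-72, which is the output size the corresponding upper_with_statuses model head is loaded with.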
factual/__init__.py ADDED
File without changes
factual/f1chexbert.py ADDED
@@ -0,0 +1,254 @@
1
+ #!/usr/bin/env python
2
+ """CheXbert evaluation utilities – **device‑safe end‑to‑end**
3
+
4
+ This is a drop‑in replacement for your previous `f1chexbert.py` **and** for the helper
5
+ `SemanticEmbeddingScorer`. All tensors – model weights *and* inputs – are created on
6
+ exactly the same device so the ``Expected all tensors to be on the same device``
7
+ run‑time error disappears. The public API stays identical, so the rest of your
8
+ pipeline does not need to change.
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ import os
14
+ import warnings
15
+ import logging
16
+ from typing import List, Sequence, Tuple, Union
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import numpy as np
21
+ from transformers import (
22
+ AutoConfig,
23
+ BertModel,
24
+ BertTokenizer,
25
+ )
26
+ from sklearn.metrics import (
27
+ accuracy_score,
28
+ classification_report,
29
+ )
30
+ from sklearn.metrics._classification import _check_targets
31
+ from sklearn.utils.sparsefuncs import count_nonzero
32
+ from huggingface_hub import hf_hub_download
33
+ from appdirs import user_cache_dir
34
+
35
+ # -----------------------------------------------------------------------------
36
+ # GLOBALS & UTILITIES
37
+ # -----------------------------------------------------------------------------
38
+
39
+ CACHE_DIR = user_cache_dir("chexbert")
40
+ warnings.filterwarnings("ignore")
41
+ logging.getLogger("urllib3").setLevel(logging.ERROR)
42
+
43
+ # Helper ----------------------------------------------------------------------
44
+
45
+ def _generate_attention_masks(batch_ids: torch.LongTensor) -> torch.FloatTensor:
46
+ """Create a padding mask: 1 for real tokens, 0 for pads."""
47
+ # batch_ids shape: (B, L)
48
+ lengths = (batch_ids != 0).sum(dim=1) # (B,)
49
+ max_len = batch_ids.size(1)
50
+ idxs = torch.arange(max_len, device=batch_ids.device).unsqueeze(0) # (1, L)
51
+ return (idxs < lengths.unsqueeze(1)).float() # (B, L)
52
+
53
+ # -----------------------------------------------------------------------------
54
+ # MODEL COMPONENTS
55
+ # -----------------------------------------------------------------------------
56
+
57
+ class BertLabeler(nn.Module):
58
+ """BERT backbone + 14 small classification heads (CheXbert)."""
59
+
60
+ def __init__(self, *, device: Union[str, torch.device]):
61
+ super().__init__()
62
+
63
+ if isinstance(device, str):
64
+ self.device = torch.device(device)
65
+ else:
66
+ self.device = device
67
+
68
+ # 1) Backbone on *CPU* first – we'll move to correct device after weights load
69
+ config = AutoConfig.from_pretrained("bert-base-uncased")
70
+ self.bert = BertModel(config)
71
+
72
+ hidden = self.bert.config.hidden_size
73
+ # 13 heads with 4‑way logits, + 1 head with 2‑way logits
74
+ self.linear_heads = nn.ModuleList([nn.Linear(hidden, 4) for _ in range(13)])
75
+ self.linear_heads.append(nn.Linear(hidden, 2))
76
+
77
+ self.dropout = nn.Dropout(0.1)
78
+
79
+ # 2) Load checkpoint weights directly onto CPU first -------------------
80
+ ckpt_path = hf_hub_download(
81
+ repo_id="StanfordAIMI/RRG_scorers",
82
+ filename="chexbert.pth",
83
+ cache_dir=CACHE_DIR,
84
+ )
85
+ state = torch.load(ckpt_path, map_location="cpu")["model_state_dict"]
86
+ state = {k.replace("module.", ""): v for k, v in state.items()}
87
+ self.load_state_dict(state, strict=True)
88
+
89
+ # 3) NOW move the entire module (recursively) to `self.device` ----------
90
+ self.to(self.device)
91
+
92
+ # freeze ---------------------------------------------------------------
93
+ for p in self.parameters():
94
+ p.requires_grad = False
95
+
96
+ # ---------------------------------------------------------------------
97
+ # forward helpers
98
+ # ---------------------------------------------------------------------
99
+
100
+ @torch.no_grad()
101
+ def cls_logits(self, input_ids: torch.LongTensor) -> List[torch.Tensor]:
102
+ """Returns a list of logits for each head (no softmax)."""
103
+ attn = _generate_attention_masks(input_ids)
104
+ outputs = self.bert(input_ids=input_ids, attention_mask=attn)
105
+ cls_repr = self.dropout(outputs.last_hidden_state[:, 0])
106
+ return [head(cls_repr) for head in self.linear_heads]
107
+
108
+ @torch.no_grad()
109
+ def cls_embeddings(self, input_ids: torch.LongTensor) -> torch.Tensor:
110
+ """Returns pooled [CLS] representations (B, hidden_size)."""
111
+ attn = _generate_attention_masks(input_ids)
112
+ outputs = self.bert(input_ids=input_ids, attention_mask=attn)
113
+ return outputs.last_hidden_state[:, 0] # (B, hidden)
114
+
115
+ # -----------------------------------------------------------------------------
116
+ # F1‑CheXbert evaluator
117
+ # -----------------------------------------------------------------------------
118
+
119
+ class F1CheXbert(nn.Module):
120
+ """Generate CheXbert labels + handy evaluation utilities."""
121
+
122
+ CONDITION_NAMES = [
123
+ "Enlarged Cardiomediastinum",
124
+ "Cardiomegaly",
125
+ "Lung Opacity",
126
+ "Lung Lesion",
127
+ "Edema",
128
+ "Consolidation",
129
+ "Pneumonia",
130
+ "Atelectasis",
131
+ "Pneumothorax",
132
+ "Pleural Effusion",
133
+ "Pleural Other",
134
+ "Fracture",
135
+ "Support Devices",
136
+ ]
137
+ NO_FINDING = "No Finding"
138
+ TARGET_NAMES = CONDITION_NAMES + [NO_FINDING]
139
+
140
+ TOP5 = [
141
+ "Cardiomegaly",
142
+ "Edema",
143
+ "Consolidation",
144
+ "Atelectasis",
145
+ "Pleural Effusion",
146
+ ]
147
+
148
+ def __init__(
149
+ self,
150
+ *,
151
+ refs_filename: str | None = None,
152
+ hyps_filename: str | None = None,
153
+ device: Union[str, torch.device] = "cpu",
154
+ ):
155
+ super().__init__()
156
+
157
+ # Resolve device -------------------------------------------------------
158
+ if isinstance(device, str):
159
+ self.device = torch.device(device)
160
+ else:
161
+ self.device = device
162
+
163
+ self.refs_filename = refs_filename
164
+ self.hyps_filename = hyps_filename
165
+
166
+ # HuggingFace tokenizer (always CPU, we just move tensors later) -------
167
+ self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
168
+
169
+ # backbone + heads ------------------------------------------------------
170
+ self.model = BertLabeler(device=self.device).eval()
171
+
172
+ # indices for the TOP‑5 label subset -----------------------------------
173
+ self.top5_idx = [self.TARGET_NAMES.index(n) for n in self.TOP5]
174
+
175
+ # ---------------------------------------------------------------------
176
+ # Public helpers
177
+ # ---------------------------------------------------------------------
178
+
179
+ @torch.no_grad()
180
+ def get_embeddings(self, reports: Sequence[str]) -> List[np.ndarray]:
181
+ """Return list[np.ndarray] of pooled [CLS] vectors for each report."""
182
+ # Tokenise *as a batch* for efficiency
183
+ encoding = self.tokenizer(
184
+ reports,
185
+ padding=True,
186
+ truncation=True,
187
+ max_length=512,
188
+ return_tensors="pt",
189
+ )
190
+ input_ids = encoding.input_ids.to(self.device)
191
+ # (B, hidden)
192
+ cls = self.model.cls_embeddings(input_ids)
193
+ return [v.cpu().numpy() for v in cls]
194
+
195
+ @torch.no_grad()
196
+ def get_label(self, report: str, mode: str = "rrg") -> List[int]:
197
+ """Return 14‑dim binary vector for the given report."""
198
+ input_ids = self.tokenizer(report, truncation=True, max_length=512, return_tensors="pt").input_ids.to(self.device)
199
+ preds = [head.argmax(dim=1).item() for head in self.model.cls_logits(input_ids)]
200
+
201
+ binary = []
202
+ if mode == "rrg":
203
+ for c in preds:
204
+ binary.append(1 if c in {1, 3} else 0)
205
+ elif mode == "classification":
206
+ for c in preds:
207
+ if c == 1:
208
+ binary.append(1)
209
+ elif c == 2:
210
+ binary.append(0)
211
+ elif c == 3:
212
+ binary.append(-1)
213
+ else:
214
+ binary.append(0)
215
+ else:
216
+ raise ValueError(f"Unknown mode: {mode}")
217
+ return binary
218
+
219
+ # ---------------------------------------------------------------------
220
+ # Full evaluator – unchanged logic but simplified I/O
221
+ # ---------------------------------------------------------------------
222
+
223
+ def forward(self, hyps: List[str], refs: List[str]):
224
+ """Return (accuracy, per‑example‑accuracy, full classification reports)."""
225
+ # Reference labels -----------------------------------------------------
226
+ if self.refs_filename and os.path.exists(self.refs_filename):
227
+ with open(self.refs_filename) as f:
228
+ refs_chexbert = [eval(line) for line in f]
229
+ else:
230
+ refs_chexbert = [self.get_label(r) for r in refs]
231
+ if self.refs_filename:
232
+ with open(self.refs_filename, "w") as f:
233
+ f.write("\n".join(map(str, refs_chexbert)))
234
+
235
+ # Hypothesis labels ----------------------------------------------------
236
+ hyps_chexbert = [self.get_label(h) for h in hyps]
237
+ if self.hyps_filename:
238
+ with open(self.hyps_filename, "w") as f:
239
+ f.write("\n".join(map(str, hyps_chexbert)))
240
+
241
+ # TOP‑5 subset arrays --------------------------------------------------
242
+ refs5 = [np.array(r)[self.top5_idx] for r in refs_chexbert]
243
+ hyps5 = [np.array(h)[self.top5_idx] for h in hyps_chexbert]
244
+
245
+ # overall accuracy -----------------------------------------------------
246
+ accuracy = accuracy_score(refs5, hyps5)
247
+ _, y_true, y_pred = _check_targets(refs5, hyps5)
248
+ pe_accuracy = (count_nonzero(y_true - y_pred, axis=1) == 0).astype(float)
249
+
250
+ # full classification reports -----------------------------------------
251
+ cr = classification_report(refs_chexbert, hyps_chexbert, target_names=self.TARGET_NAMES, output_dict=True)
252
+ cr5 = classification_report(refs5, hyps5, target_names=self.TOP5, output_dict=True)
253
+
254
+ return accuracy, pe_accuracy, cr, cr5
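A minimal end-to-end sketch of the evaluator above (the CheXbert checkpoint is fetched from the StanfordAIMI/RRG_scorers hub repo on first use; the report strings and the printed dictionary keys are illustrative):

    from factual.f1chexbert import F1CheXbert

    scorer = F1CheXbert(device="cpu")
    accuracy, per_example_acc, full_report, top5_report = scorer(
        hyps=["Mild cardiomegaly with a small left pleural effusion."],
        refs=["Cardiomegaly and a small left-sided pleural effusion."],
    )
    # full_report / top5_report are sklearn classification_report dicts
    print(accuracy, top5_report["micro avg"]["f1-score"])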
factual/f1temporal.py ADDED
@@ -0,0 +1,167 @@
1
+ # Temporal Entity F1
2
+ # Adapted from https://github.com/X-iZhang/Libra/blob/main/libra/eval/temporal_f1.py
3
+
4
+ import re
5
+ import stanza
6
+ import argparse
7
+ from typing import List, Union
8
+
9
+
10
+
11
+ # Initialize the pipeline with the radiology NER model explicitly specified
12
+ nlp = stanza.Pipeline(
13
+ lang='en',
14
+ package='radiology',
15
+ processors={'tokenize': 'default', 'ner': 'radiology'},
16
+ logging_level='ERROR', # Only output warnings or more severe messages
17
+ verbose=False # Suppress additional information during pipeline initialization
18
+ )
19
+
20
+ # Keywords used for radiology-related entity extraction
21
+ # Reference: Learning to Exploit Temporal Structure for Biomedical Vision-Language Processing (CVPR2023)
22
+ # https://arxiv.org/pdf/2301.04558
23
+
24
+ KEYWORDS = {
25
+ "bigger", "change", "cleared", "constant", "decrease", "decreased", "decreasing", "elevated", "elevation",
26
+ "enlarged", "enlargement", "enlarging", "expanded", "greater", "growing", "improved", "improvement",
27
+ "improving", "increase", "increased", "increasing", "larger", "new", "persistence", "persistent",
28
+ "persisting", "progression", "progressive", "reduced", "removal", "resolution", "resolved", "resolving",
29
+ "smaller", "stability", "stable", "stably", "unchanged", "unfolded", "worse", "worsen", "worsened",
30
+ "worsening", "unaltered"
31
+ }
32
+
33
+ def clean_text(text: str) -> str:
34
+ """
35
+ Clean the input text by removing special characters and redundant spaces or newlines.
36
+
37
+ Args:
38
+ text (str): Input text.
39
+
40
+ Returns:
41
+ str: Cleaned text.
42
+ """
43
+ # Remove special characters and redundant newlines
44
+ text = re.sub(r'\n+', ' ', text) # Replace multiple newlines with a single space
45
+ text = re.sub(r'[_-]+', ' ', text) # Replace underscores and dashes with spaces
46
+ text = re.sub(r'\(___, __, __\)', '', text) # Remove irrelevant underscore patterns
47
+ text = re.sub(r'---, ---, ---', '', text) # Remove dashed patterns
48
+ text = re.sub(r'\(__, __, ___\)', '', text) # Remove similar underscore patterns
49
+ text = re.sub(r'[_-]+', ' ', text) # Replace underscores and dashes again (if any remain)
50
+ text = re.sub(r'[^\w\s.,:;()-]', '', text) # Remove non-alphanumeric characters except common punctuation
51
+
52
+ # Remove extra spaces
53
+ text = re.sub(r'\s{2,}', ' ', text).strip()
54
+ return text
55
+
56
+ def extract_entities(text: str, keywords: set) -> set:
57
+ """
58
+ Extract entities from the given text based on Stanza NER and provided keywords.
59
+
60
+ Args:
61
+ text (str): Input text.
62
+ keywords (set): Set of keywords to extract entities.
63
+
64
+ Returns:
65
+ set: Set of matched entities found in the text.
66
+ """
67
+ # Use Stanza NER to extract entities tagged as "OBSERVATION" or "OBSERVATION_MODIFIER"
68
+ doc = nlp(text)
69
+ stanza_entities = {ent.text.lower() for ent in doc.entities if ent.type in {"OBSERVATION", "OBSERVATION_MODIFIER"}}
70
+
71
+ # Filter Stanza entities to include only those present in keywords
72
+ matched_stanza_entities = {entity for entity in stanza_entities if entity in keywords}
73
+
74
+ # Clean the text before extracting entities
75
+ text = clean_text(text)
76
+
77
+ # Create a regex pattern that matches any of the keywords as whole words
78
+ pattern = r'\b(' + '|'.join(re.escape(word) for word in keywords) + r')\b'
79
+
80
+ # Find all matches using regex
81
+ keyword_matches = {match.group().lower() for match in re.finditer(pattern, text.lower())}
82
+
83
+ # Combine Stanza entities and regex matches
84
+ return matched_stanza_entities | keyword_matches
85
+
86
+ def calculate_tem_score(prediction_text: str, reference_text: Union[str, List[str]], epsilon: float = 1e-10) -> dict:
87
+ """
88
+ Calculate the Temporal Entity Matching (TEM) score (similar to F1-score).
89
+
90
+ Args:
91
+ reference_text (Union[str, List[str]]): Reference text or a list of reference texts.
92
+ prediction_text (str): Prediction text.
93
+ epsilon (float): Small value to avoid division by zero.
94
+
95
+ Returns:
96
+ dict: TEM score under "f1", plus the extracted prediction and reference entity sets.
97
+ """
98
+ if isinstance(reference_text, list):
99
+ reference_entities = set()
100
+ for ref in reference_text:
101
+ reference_entities.update(extract_entities(ref, KEYWORDS))
102
+ else:
103
+ reference_entities = extract_entities(reference_text, KEYWORDS)
104
+
105
+ prediction_entities = extract_entities(prediction_text, KEYWORDS)
106
+
107
+ if len(reference_entities) == 0:
108
+ if len(prediction_entities) == 0:
109
+ return {
110
+ "f1": 1.0,
111
+ "prediction_entities": prediction_entities,
112
+ "reference_entities": reference_entities
113
+ } # Perfect match when both are empty
114
+ else:
115
+ return {
116
+ "f1": epsilon,
117
+ "prediction_entities": prediction_entities,
118
+ "reference_entities": reference_entities
119
+ } # Minimal score when reference is empty but prediction is not
120
+
121
+ # Calculate intersection of entities
122
+ true_positives = len(prediction_entities & reference_entities)
123
+
124
+ # Calculate precision and recall with epsilon to avoid division by zero
125
+ precision = (true_positives + epsilon) / (len(prediction_entities) + epsilon)
126
+ recall = (true_positives + epsilon) / (len(reference_entities) + epsilon)
127
+
128
+ # Calculate TEM score (F1 score)
129
+ tem_score = (2 * precision * recall) / (precision + recall + epsilon)
130
+
131
+ return {
132
+ "f1": tem_score,
133
+ "prediction_entities": prediction_entities,
134
+ "reference_entities": reference_entities
135
+ }
136
+
137
+ def F1Temporal(predictions: List[str], references: List[Union[str, List[str]]], epsilon: float = 1e-10) -> dict:
138
+ """
139
+ Calculate the average TEM score over a list of reference and prediction texts.
140
+
141
+ Args:
142
+ references (List[Union[str, List[str]]]): List of reference texts or lists of reference texts.
143
+ predictions (List[str]): List of prediction texts.
144
+ epsilon (float): Small value to avoid division by zero.
145
+
146
+ Returns:
147
+ dict: Average TEM score under "f1", plus per-example prediction and reference entity sets.
148
+ """
149
+ assert len(references) == len(predictions), "Reference and prediction lists must have the same length."
150
+
151
+ tem_scores = []
152
+ prediction_entities = []
153
+ reference_entities = []
154
+
155
+ for pred, ref in zip(predictions, references):
156
+ result = calculate_tem_score(pred, ref, epsilon)
157
+ tem_scores.append(result["f1"])
158
+ prediction_entities.append(result["prediction_entities"])
159
+ reference_entities.append(result["reference_entities"])
160
+
161
+ average_f1 = sum(tem_scores) / len(tem_scores)
162
+
163
+ return {
164
+ "f1": average_f1,
165
+ "prediction_entities": prediction_entities,
166
+ "reference_entities": reference_entities
167
+ }
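A short sketch of the temporal-entity metric above (stanza downloads its radiology NER package on first use; the example strings are illustrative):

    from factual.f1temporal import F1Temporal

    out = F1Temporal(
        predictions=["The right pleural effusion has increased in size."],
        references=["Increased right pleural effusion compared to the prior study."],
    )
    # both sides contain the temporal keyword "increased", so the score is close to 1
    print(round(out["f1"], 3), out["prediction_entities"], out["reference_entities"])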
factual/green_score/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .green import GREEN
factual/green_score/green.py ADDED
@@ -0,0 +1,465 @@
1
+ import re
2
+ import torch
3
+ import torch.distributed as dist
4
+ import pandas as pd
5
+ from datasets import Dataset
6
+ from datasets.distributed import split_dataset_by_node
7
+ import os
8
+ from tqdm import tqdm
9
+ import numpy as np
10
+ import time
11
+ import sys
12
+ import warnings
13
+ import torch.nn as nn
14
+ from transformers import AutoModelForCausalLM, AutoTokenizer
15
+ from transformers.utils import logging
16
+
17
+ # Import necessary functions (ensure these are available in your environment)
18
+ from factual.green_score.utils import (
19
+ gather_processes,
20
+ make_prompt,
21
+ clean_responses,
22
+ compute_largest_cluster,
23
+ flatten_values_lists_of_list_dicts_to_dict,
24
+ )
25
+
26
+ # Set the logging level for the transformers library to ERROR to suppress benign warnings
27
+ logging.get_logger("transformers").setLevel(logging.ERROR)
28
+
29
+ def get_rank():
30
+ if not dist.is_initialized():
31
+ return 0
32
+ return dist.get_rank()
33
+
34
+
35
+ def is_main_process():
36
+ return get_rank() == 0
37
+
38
+
39
+ def tqdm_on_main(*args, **kwargs):
40
+ if is_main_process():
41
+ print("==== Beginning Inference ====")
42
+ return tqdm(*args, **kwargs)
43
+ else:
44
+ return kwargs.get("iterable", None)
45
+
46
+
47
+ class GREEN:
48
+ def __init__(self, model_name, output_dir=".", cpu=False):
49
+ super().__init__()
50
+ warnings.filterwarnings(
51
+ "ignore", message="A decoder-only architecture is being used*"
52
+ )
53
+ from sklearn.exceptions import ConvergenceWarning
54
+
55
+ warnings.filterwarnings(
56
+ "ignore",
57
+ category=ConvergenceWarning,
58
+ message="Number of distinct clusters.*",
59
+ )
60
+ warnings.filterwarnings(
61
+ "ignore",
62
+ category=FutureWarning,
63
+ module="transformers.tokenization_utils_base",
64
+ )
65
+ self.cpu = cpu
66
+ self.model_name = model_name.split("/")[-1]
67
+ self.output_dir = output_dir
68
+ self.batch_size = 4
69
+ self.max_length = 2048
70
+ self.categories = [
71
+ "Clinically Significant Errors",
72
+ "Clinically Insignificant Errors",
73
+ "Matched Findings",
74
+ ]
75
+ self.sub_categories = [
76
+ "(a) False report of a finding in the candidate",
77
+ "(b) Missing a finding present in the reference",
78
+ "(c) Misidentification of a finding's anatomic location/position",
79
+ "(d) Misassessment of the severity of a finding",
80
+ "(e) Mentioning a comparison that isn't in the reference",
81
+ "(f) Omitting a comparison detailing a change from a prior study",
82
+ ]
83
+ self.prompts = None
84
+ self.completions = None
85
+ self.green_scores = None
86
+ self.error_counts = None
87
+
88
+ if torch.cuda.is_available() and torch.cuda.device_count() > 1 and not self.cpu:
89
+ if not dist.is_initialized():
90
+ dist.init_process_group(
91
+ backend="nccl",
92
+ )
93
+ torch.cuda.set_device(dist.get_rank())
94
+ if dist.get_rank() == 0:
95
+ print(
96
+ "Distributed training with", torch.cuda.device_count(), "GPUs"
97
+ )
98
+
99
+ self.model = AutoModelForCausalLM.from_pretrained(
100
+ model_name,
101
+ trust_remote_code=False if "Phi" in model_name else True,
102
+ device_map=(
103
+ {"": "cuda:{}".format(torch.cuda.current_device())}
104
+ if not self.cpu
105
+ else {"": "cpu"}
106
+ ),
107
+ torch_dtype=torch.float16,
108
+ )
109
+ self.model.eval()
110
+
111
+ self.tokenizer = AutoTokenizer.from_pretrained(
112
+ model_name,
113
+ add_eos_token=True,
114
+ use_fast=True,
115
+ trust_remote_code=True,
116
+ padding_side="left",
117
+ )
118
+
119
+ # Set up chat template for chat-style prompts
120
+ chat_template = (
121
+ "{% for message in messages %}\n"
122
+ "{% if message['from'] == 'human' %}\n"
123
+ "{{ '<|user|>\n' + message['value'] + eos_token }}\n"
124
+ "{% elif message['from'] == 'system' %}\n"
125
+ "{{ '<|system|>\n' + message['value'] + eos_token }}\n"
126
+ "{% elif message['from'] == 'gpt' %}\n"
127
+ "{{ '<|assistant|>\n' + message['value'] + eos_token }}\n"
128
+ "{% endif %}\n"
129
+ "{% if loop.last and add_generation_prompt %}\n"
130
+ "{{ '<|assistant|>' }}\n"
131
+ "{% endif %}\n"
132
+ "{% endfor %}"
133
+ )
134
+
135
+ self.tokenizer.chat_template = chat_template
136
+ self.tokenizer.pad_token = self.tokenizer.eos_token
137
+ self.tokenizer.clean_up_tokenization_spaces = True
138
+ self.tokenizer.padding_side = "left"
139
+
140
+ def __call__(self, refs, hyps):
141
+ print("Processing data...making prompts")
142
+ dataset = Dataset.from_dict({"reference": refs, "prediction": hyps})
143
+
144
+ dataset = self.process_data(dataset)
145
+ print("Done.")
146
+
147
+ self.dataset = dataset
148
+
149
+ t = time.time()
150
+
151
+ mean, std, green_scores, summary, results_df = self.infer()
152
+
153
+ t = time.time() - t
154
+ print("Seconds per example: ", t / len(refs))
155
+
156
+ if not is_main_process():
157
+ print(f"Rank {dist.get_rank()} exiting.")
158
+ dist.destroy_process_group()
159
+ sys.exit()
160
+
161
+ return mean, std, green_scores, summary, results_df
162
+
163
+ def process_data(self, dataset):
164
+ def prompting(examples):
165
+ return {
166
+ "prompt": [
167
+ make_prompt(r, p)
168
+ for r, p in zip(examples["reference"], examples["prediction"])
169
+ ]
170
+ }
171
+
172
+ dataset = dataset.map(prompting, batched=True)
173
+ return dataset
174
+
175
+ @torch.inference_mode()
176
+ def infer(self):
177
+ if torch.cuda.is_available() and torch.cuda.device_count() > 1 and not self.cpu:
178
+ dataset_dist = split_dataset_by_node(
179
+ self.dataset,
180
+ rank=get_rank(),
181
+ world_size=int(os.environ["WORLD_SIZE"]),
182
+ )
183
+ print("Distributed dataset created on rank: ", int(os.environ["RANK"]))
184
+ else:
185
+ dataset_dist = self.dataset
186
+
187
+ local_completions = []
188
+ local_references = []
189
+
190
+ for batch in tqdm_on_main(
191
+ iterable=dataset_dist.iter(batch_size=self.batch_size),
192
+ total=len(dataset_dist) // self.batch_size,
193
+ ):
194
+ local_references.extend(batch["prompt"])
195
+ local_completions.extend(self.get_response(batch))
196
+
197
+ if torch.cuda.is_available() and torch.cuda.device_count() > 1 and not self.cpu:
198
+ self.completions, self.prompts = gather_processes(
199
+ local_completions, local_references
200
+ )
201
+ else:
202
+ self.completions = local_completions
203
+ self.prompts = local_references
204
+
205
+ if is_main_process():
206
+ print("==== End Inference ====")
207
+
208
+ if len(self.completions) != len(self.prompts):
209
+ print("Length of prompts and completions are not equal!")
210
+
211
+ return self.process_results()
212
+
213
+ def tokenize_batch_as_chat(self, batch):
214
+ local_rank = int(os.environ.get("LOCAL_RANK", 0)) if not self.cpu else "cpu"
215
+ batch = [
216
+ self.tokenizer.apply_chat_template(
217
+ i, tokenize=False, add_generation_prompt=True
218
+ )
219
+ for i in batch
220
+ ]
221
+
222
+ batch = self.tokenizer.batch_encode_plus(
223
+ batch,
224
+ return_tensors="pt",
225
+ padding=True,
226
+ truncation=True,
227
+ max_length=self.max_length,
228
+ ).to(local_rank)
229
+
230
+ return batch
231
+
232
+ def get_response(self, batch):
233
+ assert "prompt" in batch.keys(), "prompt is not in batch keys"
234
+
235
+ batch = [
236
+ [{"from": "human", "value": prompt}, {"from": "gpt", "value": ""}]
237
+ for prompt in batch["prompt"]
238
+ ]
239
+
240
+ batch = self.tokenize_batch_as_chat(batch)
241
+
242
+ outputs = self.model.generate(
243
+ input_ids=batch["input_ids"],
244
+ attention_mask=batch["attention_mask"],
245
+ eos_token_id=self.tokenizer.eos_token_id,
246
+ pad_token_id=self.tokenizer.pad_token_id,
247
+ max_length=2048,
248
+ do_sample=False,
249
+ temperature=None,
250
+ top_p=None,
251
+ )
252
+
253
+ responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
254
+
255
+ response_list = []
256
+ if isinstance(responses, list):
257
+ for response in responses:
258
+ response = clean_responses(response)
259
+ response_list.append(response)
260
+ else:
261
+ responses = clean_responses(responses)
262
+ response_list.append(responses)
263
+
264
+ return response_list
265
+
266
+ def process_results(self):
267
+ self.green_scores = [
268
+ self.compute_green(response) for response in self.completions
269
+ ]
270
+ self.error_counts = pd.DataFrame(
271
+ [self.compute_error_count(response) for response in self.completions],
272
+ columns=self.sub_categories + ["Matched Findings"],
273
+ )
274
+
275
+ results_df = pd.DataFrame(
276
+ {
277
+ "reference": self.dataset["reference"],
278
+ "predictions": self.dataset["prediction"],
279
+ "green_analysis": self.completions,
280
+ "green_score": self.green_scores,
281
+ **self.error_counts,
282
+ }
283
+ )
284
+
285
+ mean, std, summary = self.compute_summary()
286
+
287
+ return mean, std, self.green_scores, summary, results_df
288
+
289
+ def compute_error_count(self, response):
290
+ _, sig_errors = self.parse_error_counts(response, self.categories[0])
291
+ matched_findings, _ = self.parse_error_counts(response, self.categories[2])
292
+ return sig_errors + [matched_findings]
293
+
294
+ def compute_green(self, response):
295
+ sig_present, sig_errors = self.parse_error_counts(response, self.categories[0])
296
+ matched_findings, _ = self.parse_error_counts(response, self.categories[2])
297
+
298
+ if matched_findings == 0:
299
+ return 0
300
+
301
+ if sig_present is None or matched_findings is None:
302
+ return None
303
+
304
+ return matched_findings / (matched_findings + sum(sig_errors))
305
+
306
+ def parse_error_counts(self, text, category, for_reward=False):
307
+ if category not in self.categories:
308
+ raise ValueError(
309
+ f"Category {category} is not a valid category. Please choose from {self.categories}."
310
+ )
311
+
312
+ pattern = rf"\[{category}\]:\s*(.*?)(?:\n\s*\n|\Z)"
313
+ category_text = re.search(pattern, text, re.DOTALL)
314
+
315
+ sum_counts = 0
316
+ sub_counts = [0 for i in range(6)]
317
+
318
+ if not category_text:
319
+ if for_reward:
320
+ return None, None
321
+ return sum_counts, sub_counts
322
+ if category_text.group(1).startswith("No"):
323
+ return sum_counts, sub_counts
324
+
325
+ if category == "Matched Findings":
326
+ counts = re.findall(r"^\b\d+\b(?=\.)", category_text.group(1))
327
+ if len(counts) > 0:
328
+ sum_counts = int(counts[0])
329
+ return sum_counts, sub_counts
330
+ else:
331
+ sub_categories = [s.split(" ", 1)[0] + " " for s in self.sub_categories]
332
+ matches = sorted(re.findall(r"\([a-f]\) .*", category_text.group(1)))
333
+
334
+ if len(matches) == 0:
335
+ matches = sorted(re.findall(r"\([1-6]\) .*", category_text.group(1)))
336
+ sub_categories = [
337
+ f"({i})" + " " for i in range(1, len(self.sub_categories) + 1)
338
+ ]
339
+
340
+ for position, sub_category in enumerate(sub_categories):
341
+ for match in range(len(matches)):
342
+ if matches[match].startswith(sub_category):
343
+ count = re.findall(r"(?<=: )\b\d+\b(?=\.)", matches[match])
344
+ if len(count) > 0:
345
+ sub_counts[position] = int(count[0])
346
+ return sum(sub_counts), sub_counts
347
+
348
+ def parse_error_sentences(self, response, category):
349
+ if category not in self.categories:
350
+ raise ValueError(
351
+ f"Category {category} is not a valid category. Please choose from {self.categories}."
352
+ )
353
+ pattern = rf"\[{category}\]:\s*(.*?)(?:\n\s*\n|\Z)"
354
+ category_text = re.search(pattern, response, re.DOTALL)
355
+ sub_category_dict_sentences = {}
356
+ for sub_category in self.sub_categories:
357
+ sub_category_dict_sentences[sub_category] = []
358
+
359
+ if not category_text:
360
+ return sub_category_dict_sentences
361
+ if category_text.group(1).startswith("No"):
362
+ return sub_category_dict_sentences
363
+
364
+ if category == "Matched Findings":
365
+ return (
366
+ category_text.group(1).rsplit(":", 1)[-1].rsplit(".", 1)[-1].split(";")
367
+ )
368
+
369
+ matches = sorted(re.findall(r"\([a-f]\) .*", category_text.group(1)))
370
+
371
+ if len(matches) == 0:
372
+ matches = sorted(re.findall(r"\([1-6]\) .*", category_text.group(1)))
373
+ self.sub_categories = [
374
+ f"({i})" + " " for i in range(1, len(self.sub_categories) + 1)
375
+ ]
376
+
377
+ for position, sub_category in enumerate(self.sub_categories):
378
+ for match in range(len(matches)):
379
+ if matches[match].startswith(sub_category):
380
+ sentences_list = (
381
+ matches[match].rsplit(":", 1)[-1].split(".", 1)[-1].split(";")
382
+ )
383
+ sub_category_dict_sentences[self.sub_categories[position]] = (
384
+ sentences_list
385
+ )
386
+
387
+ return sub_category_dict_sentences
388
+
389
+ def compute_sentences(self, response):
390
+ return self.parse_error_sentences(response, self.categories[0])
391
+
392
+ def get_representative_sentences(self, responses):
393
+ list_sentences = []
394
+ for i in responses:
395
+ sentences = self.compute_sentences(i)
396
+ list_sentences.append(sentences)
397
+
398
+ dict_sentences = flatten_values_lists_of_list_dicts_to_dict(list_sentences)
399
+
400
+ result_sentences_dict = {}
401
+
402
+ for i in self.sub_categories:
403
+ sentences = dict_sentences[i]
404
+ sentences = [i for i in sentences if i.strip() != ""]
405
+ _, sentences_of_largest_cluster = compute_largest_cluster(sentences)
406
+ result_sentences_dict[i] = sentences_of_largest_cluster
407
+
408
+ return result_sentences_dict
409
+
410
+ def compute_accuracy(self, responses):
411
+ counts = []
412
+ for response in responses:
413
+ _, sig_errors = self.parse_error_counts(response, self.categories[0])
414
+ counts.append(sig_errors)
415
+
416
+ counts = np.array(counts)
417
+
418
+ dict_acc = {}
419
+ for i in range(len(self.sub_categories)):
420
+ error_counts = counts[:, i]
421
+ accuracy = np.mean(error_counts == 0)
422
+ dict_acc[self.sub_categories[i]] = accuracy
423
+
424
+ return dict_acc
425
+
426
+ def compute_summary(self):
427
+ print("Computing summary ...")
428
+ representative_sentences = self.get_representative_sentences(self.completions)
429
+ accuracies = self.compute_accuracy(self.completions)
430
+ mean = np.mean(self.green_scores)
431
+ std = np.std(self.green_scores)
432
+
433
+ summary = f"\n-------------{self.model_name}----------------\n [Summary]: Green average {mean} and standard deviation {std} \n [Clinically Significant Errors Analyses]: <accuracy>. <representative error>\n\n"
434
+ for idx, sub_category in enumerate(self.sub_categories):
435
+ accuracy = accuracies[sub_category]
436
+ sentences = representative_sentences[sub_category]
437
+ summary += f"{sub_category}: {accuracy}. \n {sentences} \n\n"
438
+ summary += "----------------------------------\n"
439
+
440
+ return mean, std, summary
441
+
442
+
443
+ if __name__ == "__main__":
444
+ refs = [
445
+ "Interstitial opacities without changes.",
446
+ "Interval development of segmental heterogeneous airspace opacities throughout the lungs . No significant pneumothorax or pleural effusion . Bilateral calcified pleural plaques are scattered throughout the lungs . The heart is not significantly enlarged .",
447
+ "Lung volumes are low, causing bronchovascular crowding. The cardiomediastinal silhouette is unremarkable. No focal consolidation, pleural effusion, or pneumothorax detected. Within the limitations of chest radiography, osseous structures are unremarkable.",
448
+ ]
449
+ hyps = [
450
+ "Interstitial opacities at bases without changes.",
451
+ "Interval development of segmental heterogeneous airspace opacities throughout the lungs . No significant pneumothorax or pleural effusion . Bilateral calcified pleural plaques are scattered throughout the lungs . The heart is not significantly enlarged .",
452
+ "Endotracheal and nasogastric tubes have been removed. Changes of median sternotomy, with continued leftward displacement of the fourth inferiomost sternal wire. There is continued moderate-to-severe enlargement of the cardiac silhouette. Pulmonary aeration is slightly improved, with residual left lower lobe atelectasis. Stable central venous congestion and interstitial pulmonary edema. Small bilateral pleural effusions are unchanged.",
453
+ ]
454
+
455
+ model_name = "StanfordAIMI/GREEN-radllama2-7b"
456
+
457
+ green_scorer = GREEN(model_name, output_dir=".")
458
+ mean, std, green_score_list, summary, result_df = green_scorer(refs, hyps)
459
+ print(green_score_list)
460
+ print(summary)
461
+ # for index, row in result_df.iterrows():
462
+ # print(f"Row {index}:\n")
463
+ # for col_name in result_df.columns:
464
+ # print(f"{col_name}: {row[col_name]}\n")
465
+ # print('-' * 80)
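For reference, compute_green above reduces each parsed completion to GREEN = matched_findings / (matched_findings + clinically_significant_errors): a completion with 3 matched findings and 1 significant error scores 3 / (3 + 1) = 0.75, while any completion with zero matched findings scores 0. compute_summary then reports the mean and standard deviation of these per-report scores.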
factual/green_score/utils.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch.distributed as dist
2
+ import os
3
+ import sys
4
+
5
+ from sklearn.metrics import silhouette_score
6
+ from sklearn.cluster import KMeans
7
+ from sklearn import preprocessing
8
+ from sentence_transformers import SentenceTransformer
9
+ from scipy.spatial import distance
10
+ import numpy as np
11
+
12
+ # A dictionary to store rewards for pairs of reference and hypothesis reports
13
+
14
+
15
+ def compute_largest_cluster(sentences):
16
+ """
17
+ Computes the largest cluster of sentences using K-means clustering, finds the sentences within the largest cluster, and orders them by their distance to the cluster center.
18
+
19
+ Args:
20
+ sentences (list): List of sentences to be clustered.
21
+
22
+ Returns:
23
+ tuple: A tuple containing:
24
+ - embeddings (ndarray): Normalized embeddings of the input sentences.
25
+ - sentences_of_largest_cluster (str): the single sentence closest to the center
26
+ of the largest cluster.
27
+ """
28
+ if len(sentences) == 0:
29
+ return None, None
30
+ embeddings, kmeans = compute_kmeans(sentences)
31
+ cluster_sizes = np.bincount(kmeans.labels_)
32
+ largest_cluster_idx = np.argmax(cluster_sizes)
33
+ cluster_member_ids = np.where(kmeans.labels_ == largest_cluster_idx)[0]
34
+ sentences_of_largest_cluster = [sentences[i] for i in cluster_member_ids]
35
+
36
+ largest_cluster_mean = kmeans.cluster_centers_[largest_cluster_idx]
37
+ embeddings_of_largest_cluster = [embeddings[i] for i in cluster_member_ids]
38
+ distances = distance.cdist(
39
+ embeddings_of_largest_cluster, [largest_cluster_mean], "cosine"
40
+ ).flatten()
41
+ closest_point_indices = np.argsort(distances)[0]
42
+
43
+ sentences_of_largest_cluster = sentences_of_largest_cluster[closest_point_indices]
44
+
45
+ return embeddings, sentences_of_largest_cluster
46
+
47
+
48
+ def compute_kmeans(sentences):
49
+ """
50
+ Computes K-means clustering for a list of sentences by generating their embeddings, normalizing the embeddings, and determining the optimal number of clusters using binary search.
51
+
52
+ Args:
53
+ sentences (list): List of sentences to be clustered.
54
+
55
+ Returns:
56
+ tuple: A tuple containing:
57
+ - embeddings (ndarray): Normalized embeddings of the input sentences.
58
+ - kmeans (KMeans): The KMeans object with the optimal number of clusters determined.
59
+ """
60
+ # sentence embeddings
61
+ model = SentenceTransformer("sentence-transformers/paraphrase-mpnet-base-v2")
62
+ embeddings = model.encode(sentences)
63
+ # normalize the embeddings for equivalent computation of the cosine distance
64
+ embeddings = preprocessing.normalize(embeddings)
65
+ # compute the number of clusters with binary search
66
+ kmeans = binary_search_optimal_kmeans(embeddings, min_k=0, max_k=len(sentences))
67
+ return embeddings, kmeans
68
+
69
+
70
+ def binary_search_optimal_kmeans(data, min_k, max_k):
71
+ """
72
+ Finds the optimal k for KMeans clustering using binary search on the silhouette score.
73
+
74
+ Args:
75
+ data (list): cluster data.
76
+ min_k: minimum k for binary search
77
+ max_k: maximum k for binary search
78
+
79
+ Returns:
80
+ KMeans: the fitted KMeans model with the best silhouette score (a single-cluster fit when the data is too small).
81
+ """
82
+ best_k = min_k
83
+ best_score = -1
84
+ best_kmeans = KMeans(n_clusters=1, random_state=42).fit(
85
+ data
86
+ ) # start with 1 cluster for len(data) < 2
87
+
88
+ while min_k <= max_k:
89
+ mid_k = (min_k + max_k) // 2
90
+ if mid_k < 2:
91
+ break
92
+
93
+ kmeans = KMeans(n_clusters=mid_k, random_state=42).fit(data)
94
+ labels = kmeans.labels_
95
+ score = silhouette_score(data, labels)
96
+
97
+ if score > best_score:
98
+ best_score = score
99
+ best_k = mid_k
100
+ best_kmeans = kmeans # Update the best KMeans model
101
+ min_k = mid_k + 1
102
+ else:
103
+ max_k = mid_k - 1
104
+
105
+ return best_kmeans
106
+
107
+
108
+ def flatten_values_lists_of_list_dicts_to_dict(item):
109
+ """
110
+ Flattens a list of dictionaries containing lists of values into a single dictionary.
111
+
112
+ Args:
113
+ item (list): List of dictionaries, where each dictionary's values are lists. If any element of the list is itself a list, the function will consider only the first dictionary in that sublist.
114
+
115
+ Returns:
116
+ dict: A dictionary where each key corresponds to the keys in the input dictionaries, and each value is a flattened list of all values associated with that key across all input dictionaries.
117
+ """
118
+
119
+ result = {}
120
+ for i in item:
121
+ if isinstance(i, list):
122
+ i = i[0]
123
+ for key, lists in i.items():
124
+ if key not in result:
125
+ result[key] = []
126
+ result[key].extend(lists)
127
+
128
+ return result
129
+
130
+
131
+ def gather_processes(local_candidates, local_references=None):
132
+ world_size = int(os.environ.get("WORLD_SIZE", "1"))
133
+ local_rank = int(os.environ.get("RANK", "0"))
134
+ global_candidates_list = None
135
+ global_references_list = None
136
+
137
+ if local_rank == 0:
138
+ # Initialize the gather list only on the root process
139
+ global_candidates_list = [None for _ in range(world_size)]
140
+ global_references_list = [None for _ in range(world_size)]
141
+ try:
142
+ dist.gather_object(local_candidates, global_candidates_list, dst=0)
143
+
144
+ if not local_references is None:
145
+ dist.gather_object(local_references, global_references_list, dst=0)
146
+
147
+ except Exception as e:
148
+ print(f"Error during result gathering: {e}")
149
+
150
+ if local_rank != 0:
151
+ # Exit the process
152
+ # print(f"Rank {dist.get_rank()} exiting.")
153
+ dist.destroy_process_group() # Clean up the distributed processing group
154
+ sys.exit() # Exit the process
155
+
156
+ # Flatten the gathered list
157
+ candidates_list = []
158
+ for i in global_candidates_list:
159
+ candidates_list.extend(i)
160
+
161
+ if not global_references_list[0] is None:
162
+ references_list = []
163
+ for i in global_references_list:
164
+ references_list.extend(i)
165
+ print(f"References list: {len(references_list)}")
166
+ return candidates_list, references_list
167
+
168
+ return candidates_list
169
+
170
+
171
+ def clean_responses(response):
172
+ if "[Explanation]:" in response:
173
+ if "<|assistant|>" in response:
174
+ response = response.split("<|assistant|>")[-1]
175
+ if (
176
+ "[Explanation]:\n <Explanation>\n" in response or "[Explanation]:\n<Explanation>" in response
177
+ ):
178
+ response = response.split("[Explanation]:")[1]
179
+ else:
180
+ response = response.split("[Explanation]:")[-1]
181
+ if "<|assistant|>" in response:
182
+ response = response.split("<|assistant|>")[-1]
183
+ return response.replace("</s>", "").replace("<unk>", "")
184
+
185
+
186
+ def make_prompt(text1, text2, max_len=300):
187
+ """
188
+ Creates a prompt for evaluating the accuracy of a candidate radiology report in comparison to a reference radiology report.
189
+
190
+ Args:
191
+ text1 (str): Reference radiology report.
192
+ text2 (str): Candidate radiology report.
193
+
194
+ Returns:
195
+ str: Formatted prompt string.
196
+ """
197
+ text1 = " ".join(text1.split()[:max_len])
198
+ text2 = " ".join(text2.split()[:max_len])
199
+ prompt = f"Objective: Evaluate the accuracy of a candidate radiology report in comparison to a reference radiology report composed by expert radiologists.\n\n Process Overview: You will be presented with:\n\n 1. The criteria for making a judgment.\n 2. The reference radiology report.\n 3. The candidate radiology report.\n 4. The desired format for your assessment.\n\n 1. Criteria for Judgment:\n\n For each candidate report, determine:\n\n The count of clinically significant errors.\n The count of clinically insignificant errors.\n\n Errors can fall into one of these categories:\n\n a) False report of a finding in the candidate.\n b) Missing a finding present in the reference.\n c) Misidentification of a finding's anatomic location/position.\n d) Misassessment of the severity of a finding.\n e) Mentioning a comparison that isn't in the reference.\n f) Omitting a comparison detailing a change from a prior study.\n Note: Concentrate on the clinical findings rather than the report's writing style. Evaluate only the findings that appear in both reports.\n\n 2. Reference Report:\n {text1}\n\n 3. Candidate Report:\n {text2}\n\n 4. Reporting Your Assessment:\n\n Follow this specific format for your output, even if no errors are found:\n ```\n [Explanation]:\n <Explanation>\n\n [Clinically Significant Errors]:\n (a) <Error Type>: <The number of errors>. <Error 1>; <Error 2>; ...; <Error n>\n ....\n (f) <Error Type>: <The number of errors>. <Error 1>; <Error 2>; ...; <Error n>\n\n [Clinically Insignificant Errors]:\n (a) <Error Type>: <The number of errors>. <Error 1>; <Error 2>; ...; <Error n>\n ....\n (f) <Error Type>: <The number of errors>. <Error 1>; <Error 2>; ...; <Error n>\n\n [Matched Findings]:\n <The number of matched findings>. <Finding 1>; <Finding 2>; ...; <Finding n>\n ```\n"
200
+ return prompt
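A small sketch of the response-cleaning helper above, using an illustrative raw completion in the GREEN chat format:

    from factual.green_score.utils import clean_responses

    raw = "<|user|>...<|assistant|>[Explanation]:\n<Explanation>\nThe candidate omits the left effusion.</s>"
    print(clean_responses(raw))
    # keeps only the analysis text after "[Explanation]:" and strips the "</s>" marker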
nlg/__init__.py ADDED
File without changes
nlg/bertscore/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'tylin'
nlg/bertscore/bertscore.py ADDED
@@ -0,0 +1,50 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from bert_score import BERTScorer
4
+
5
+
6
+ class BertScore(nn.Module):
7
+ def __init__(self,
8
+ model_type='distilbert-base-uncased',
9
+ num_layers=5,
10
+ rescale_with_baseline=True,
11
+ idf=False,
12
+ ):
13
+ super(BertScore, self).__init__()
14
+ with torch.no_grad():
15
+ self.bert_scorer = BERTScorer(model_type=model_type,
16
+ num_layers=num_layers,
17
+ batch_size=64,
18
+ nthreads=4,
19
+ all_layers=False,
20
+ idf=idf,
21
+ device=None,
22
+ lang='en',
23
+ rescale_with_baseline=rescale_with_baseline,
24
+ baseline_path=None)
25
+
26
+ def forward(self, refs, hyps):
27
+ p, r, f = self.bert_scorer.score(
28
+ cands=hyps,
29
+ refs=refs,
30
+ verbose=False,
31
+ batch_size=64,
32
+ )
33
+ return torch.mean(f).item(), f.tolist()
34
+
35
+
36
+ if __name__ == '__main__':
37
+ x, y = (BertScore()(
38
+ hyps=[
39
+ "nothing to do lol",
40
+ "nothing to do x",
41
+ 'there are moderate bilateral pleural effusions with overlying atelectasis, underlying consolidation not excluded. mild prominence of the interstitial markings suggests mild pulmonary edema. the cardiac silhouette is mildly enlarged. the mediastinal contours are unremarkable. there is no evidence of pneumothorax.'
42
+ ],
43
+ refs=[
44
+ 'heart size is moderately enlarged. the mediastinal and hilar contours are unchanged. there is no pulmonary edema. small left pleural effusion is present. patchy opacities in the lung bases likely reflect atelectasis. no pneumothorax is seen. there are no acute osseous abnormalities.',
45
+ 'heart size is mildly enlarged. the mediastinal and hilar contours are normal. there is mild pulmonary edema. moderate bilateral pleural effusions are present, left greater than right. bibasilar airspace opacities likely reflect atelectasis. no pneumothorax is seen. there are no acute osseous abnormalities.',
46
+ 'heart size is mildly enlarged. the mediastinal and hilar contours are normal. there is mild pulmonary edema. moderate bilateral pleural effusions are present, left greater than right. bibasilar airspace opacities likely reflect atelectasis. no pneumothorax is seen. there are no acute osseous abnormalities.'
47
+ ])
48
+ )
49
+ print(x)
50
+ print(y)
nlg/bleu/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'tylin'
nlg/bleu/bleu.py ADDED
@@ -0,0 +1,49 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # File Name : bleu.py
4
+ #
5
+ # Description : Wrapper for BLEU scorer.
6
+ #
7
+ # Creation Date : 06-01-2015
8
+ # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9
+ # Authors : Hao Fang <[email protected]> and Tsung-Yi Lin <[email protected]>
10
+
11
+ import torch.nn as nn
12
+ from .bleu_scorer import BleuScorer
13
+
14
+
15
+ class Bleu(nn.Module):
16
+ def __init__(self, n=4, **kwargs):
17
+ # by default, compute BLEU score up to n = 4
18
+ super().__init__()
19
+ self._n = n
20
+
21
+ def forward(self, gts, res):
22
+ return self.compute_score(gts, res)
23
+
24
+ def compute_score(self, gts, res):
25
+ res = {i: [v] for i, v in enumerate(res)}
26
+ gts = {i: [v] for i, v in enumerate(gts)}
27
+ bleu_scorer = BleuScorer(n=self._n)
28
+
29
+ for id in sorted(gts.keys()):
30
+ hypo = res[id]
31
+ ref = gts[id]
32
+
33
+ # Sanity check.
34
+ assert (type(hypo) is list)
35
+ assert (len(hypo) == 1)
36
+ assert (type(ref) is list)
37
+ assert (len(ref) >= 1)
38
+
39
+ bleu_scorer += (hypo[0], ref)
40
+
41
+ # score, scores = bleu_scorer.compute_score(option='shortest')
42
+ score, scores = bleu_scorer.compute_score(option='closest', verbose=0)
43
+ # score, scores = bleu_scorer.compute_score(option='average', verbose=1)
44
+
45
+ # return (bleu, bleu_info)
46
+ return score[self._n-1], scores[self._n-1]
47
+
48
+ def method(self):
49
+ return "Bleu"
nlg/bleu/bleu_scorer.py ADDED
@@ -0,0 +1,268 @@
+ #!/usr/bin/env python
+
+ # bleu_scorer.py
+ # David Chiang <[email protected]>
+
+ # Copyright (c) 2004-2006 University of Maryland. All rights
+ # reserved. Do not redistribute without permission from the
+ # author. Not for commercial use.
+
+ # Modified by:
+ # Hao Fang <[email protected]>
+ # Tsung-Yi Lin <[email protected]>
+
+ '''Provides:
+ cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
+ cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
+ '''
+
+ import copy
+ import sys, math, re
+ from collections import defaultdict
+
+ import six
+ from six.moves import xrange as range
+
+
+ def precook(s, n=4, out=False):
+     """Takes a string as input and returns an object that can be given to
+     either cook_refs or cook_test. This is optional: cook_refs and cook_test
+     can take string arguments as well."""
+     words = s.split()
+     counts = defaultdict(int)
+     for k in range(1,n+1):
+         for i in range(len(words)-k+1):
+             ngram = tuple(words[i:i+k])
+             counts[ngram] += 1
+     return (len(words), counts)
+
+ def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
+     '''Takes a list of reference sentences for a single segment
+     and returns an object that encapsulates everything that BLEU
+     needs to know about them.'''
+
+     reflen = []
+     maxcounts = {}
+     for ref in refs:
+         rl, counts = precook(ref, n)
+         reflen.append(rl)
+         for (ngram,count) in six.iteritems(counts):
+             maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
+
+     # Calculate effective reference sentence length.
+     if eff == "shortest":
+         reflen = min(reflen)
+     elif eff == "average":
+         reflen = float(sum(reflen))/len(reflen)
+
+     ## lhuang: N.B.: leave reflen computation to the very end!!
+
+     ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
+
+     return (reflen, maxcounts)
+
+ def cook_test(test, reflen_refmaxcounts, eff=None, n=4):
+     '''Takes a test sentence and returns an object that
+     encapsulates everything that BLEU needs to know about it.'''
+
+     reflen, refmaxcounts = reflen_refmaxcounts
+     testlen, counts = precook(test, n, True)
+
+     result = {}
+
+     # Calculate effective reference sentence length.
+
+     if eff == "closest":
+         result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
+     else: ## i.e., "average" or "shortest" or None
+         result["reflen"] = reflen
+
+     result["testlen"] = testlen
+
+     result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
+
+     result['correct'] = [0]*n
+     for (ngram, count) in six.iteritems(counts):
+         result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
+
+     return result
+
+ class BleuScorer(object):
+     """Bleu scorer.
+     """
+
+     __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
+     # special_reflen is used in oracle (proportional effective ref len for a node).
+
+     def copy(self):
+         ''' copy the refs.'''
+         new = BleuScorer(n=self.n)
+         new.ctest = copy.copy(self.ctest)
+         new.crefs = copy.copy(self.crefs)
+         new._score = None
+         return new
+
+     def __init__(self, test=None, refs=None, n=4, special_reflen=None):
+         ''' singular instance '''
+
+         self.n = n
+         self.crefs = []
+         self.ctest = []
+         self.cook_append(test, refs)
+         self.special_reflen = special_reflen
+
+     def cook_append(self, test, refs):
+         '''called by constructor and __iadd__ to avoid creating new instances.'''
+
+         if refs is not None:
+             self.crefs.append(cook_refs(refs))
+             if test is not None:
+                 cooked_test = cook_test(test, self.crefs[-1])
+                 self.ctest.append(cooked_test) ## N.B.: -1
+             else:
+                 self.ctest.append(None) # lens of crefs and ctest have to match
+
+         self._score = None ## need to recompute
+
+     def ratio(self, option=None):
+         self.compute_score(option=option)
+         return self._ratio
+
+     def score_ratio(self, option=None):
+         '''return (bleu, len_ratio) pair'''
+         return (self.fscore(option=option), self.ratio(option=option))
+
+     def score_ratio_str(self, option=None):
+         return "%.4f (%.2f)" % self.score_ratio(option)
+
+     def reflen(self, option=None):
+         self.compute_score(option=option)
+         return self._reflen
+
+     def testlen(self, option=None):
+         self.compute_score(option=option)
+         return self._testlen
+
+     def retest(self, new_test):
+         if type(new_test) is str:
+             new_test = [new_test]
+         assert len(new_test) == len(self.crefs), new_test
+         self.ctest = []
+         for t, rs in zip(new_test, self.crefs):
+             self.ctest.append(cook_test(t, rs))
+         self._score = None
+
+         return self
+
+     def rescore(self, new_test):
+         ''' replace test(s) with new test(s), and returns the new score.'''
+
+         return self.retest(new_test).compute_score()
+
+     def size(self):
+         assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
+         return len(self.crefs)
+
+     def __iadd__(self, other):
+         '''add an instance (e.g., from another sentence).'''
+
+         if type(other) is tuple:
+             ## avoid creating new BleuScorer instances
+             self.cook_append(other[0], other[1])
+         else:
+             assert self.compatible(other), "incompatible BLEUs."
+             self.ctest.extend(other.ctest)
+             self.crefs.extend(other.crefs)
+             self._score = None ## need to recompute
+
+         return self
+
+     def compatible(self, other):
+         return isinstance(other, BleuScorer) and self.n == other.n
+
+     def single_reflen(self, option="average"):
+         return self._single_reflen(self.crefs[0][0], option)
+
+     def _single_reflen(self, reflens, option=None, testlen=None):
+
+         if option == "shortest":
+             reflen = min(reflens)
+         elif option == "average":
+             reflen = float(sum(reflens))/len(reflens)
+         elif option == "closest":
+             reflen = min((abs(l-testlen), l) for l in reflens)[1]
+         else:
+             assert False, "unsupported reflen option %s" % option
+
+         return reflen
+
+     def recompute_score(self, option=None, verbose=0):
+         self._score = None
+         return self.compute_score(option, verbose)
+
+     def compute_score(self, option=None, verbose=0):
+         n = self.n
+         small = 1e-9
+         tiny = 1e-15 ## so that if guess is 0 still return 0
+         bleu_list = [[] for _ in range(n)]
+
+         if self._score is not None:
+             return self._score
+
+         if option is None:
+             option = "average" if len(self.crefs) == 1 else "closest"
+
+         self._testlen = 0
+         self._reflen = 0
+         totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
+
+         # for each sentence
+         for comps in self.ctest:
+             testlen = comps['testlen']
+             self._testlen += testlen
+
+             if self.special_reflen is None: ## need computation
+                 reflen = self._single_reflen(comps['reflen'], option, testlen)
+             else:
+                 reflen = self.special_reflen
+
+             self._reflen += reflen
+
+             for key in ['guess','correct']:
+                 for k in range(n):
+                     totalcomps[key][k] += comps[key][k]
+
+             # append per image bleu score
+             bleu = 1.
+             for k in range(n):
+                 bleu *= (float(comps['correct'][k]) + tiny) \
+                         /(float(comps['guess'][k]) + small)
+                 bleu_list[k].append(bleu ** (1./(k+1)))
+             ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
+             if ratio < 1:
+                 for k in range(n):
+                     bleu_list[k][-1] *= math.exp(1 - 1/ratio)
+
+             if verbose > 1:
+                 print(comps, reflen)
+
+         totalcomps['reflen'] = self._reflen
+         totalcomps['testlen'] = self._testlen
+
+         bleus = []
+         bleu = 1.
+         for k in range(n):
+             bleu *= float(totalcomps['correct'][k] + tiny) \
+                     / (totalcomps['guess'][k] + small)
+             bleus.append(bleu ** (1./(k+1)))
+         ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
+         if ratio < 1:
+             for k in range(n):
+                 bleus[k] *= math.exp(1 - 1/ratio)
+
+         if verbose > 0:
+             print(totalcomps)
+             print("ratio:", ratio)
+
+         self._score = bleus
+         return self._score, bleu_list
nlg/radevalbertscore.py ADDED
@@ -0,0 +1,53 @@
+ import torch
+ from bert_score import score
+
+ def _get_default_device():
+     return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ class RadEvalBERTScorer:
+     """
+     Wrapper around bert_score for radiology reports using a custom BERT model.
+     """
+     def __init__(self,
+                  model_type: str = "IAMJB/RadEvalModernBERT",
+                  num_layers: int = None,
+                  use_fast_tokenizer: bool = True,
+                  rescale_with_baseline: bool = False,
+                  device: torch.device = None):
+         self.model_type = model_type
+         self.num_layers = num_layers
+         self.use_fast_tokenizer = use_fast_tokenizer
+         self.rescale_with_baseline = rescale_with_baseline
+         self.device = device or _get_default_device()
+
+     def score(self, refs: list[str], hyps: list[str]) -> tuple:
+         """
+         Compute BERTScore F1 between reference and hypothesis texts.
+
+         Args:
+             refs: list of reference sentences.
+             hyps: list of hypothesis sentences (predictions).
+
+         Returns:
+             Tuple of (mean F1 score as a float, per-pair F1 tensor).
+         """
+         # bert_score expects cands (hypotheses) first, then refs
+         P, R, F1 = score(
+             cands=hyps,
+             refs=refs,
+             model_type=self.model_type,
+             num_layers=self.num_layers,
+             use_fast_tokenizer=self.use_fast_tokenizer,
+             rescale_with_baseline=self.rescale_with_baseline,
+             device=self.device
+         )
+         # Return the mean F1 over all pairs
+         return F1.mean().item(), F1
+
+ if __name__ == "__main__":
+     # Example usage
+     refs = ["Chronic mild to moderate cardiomegaly and pulmonary venous hypertension."]
+     hyps = ["Mild left basal atelectasis; no pneumonia."]
+     scorer = RadEvalBERTScorer(num_layers=23)
+     f1_mean, f1_scores = scorer.score(refs, hyps)
+     print(f"Mean F1 score: {f1_mean:.4f}")
nlg/rouge/rouge.py ADDED
@@ -0,0 +1,37 @@
+ import torch.nn as nn
+ from rouge_score import rouge_scorer
+ from six.moves import zip_longest
+ import numpy as np
+
+
+ class Rouge(nn.Module):
+     def __init__(self, rouges, **kwargs):
+         super().__init__()
+         rouges = [r.replace('rougel', 'rougeL') for r in rouges]
+         self.scorer = rouge_scorer.RougeScorer(rouges, use_stemmer=True)
+         self.rouges = rouges
+
+     def forward(self, refs, hyps):
+         scores = []
+         for target_rec, prediction_rec in zip_longest(refs, hyps):
+             if target_rec is None or prediction_rec is None:
+                 raise ValueError("Must have equal number of lines across target and "
+                                  "prediction.")
+             scores.append(self.scorer.score(target_rec, prediction_rec))
+         f1_rouge = [s[self.rouges[0]].fmeasure for s in scores]
+         return np.mean(f1_rouge), f1_rouge
+
+
+ class Rouge1(Rouge):
+     def __init__(self, **kwargs):
+         super(Rouge1, self).__init__(rouges=['rouge1'])
+
+
+ class Rouge2(Rouge):
+     def __init__(self, **kwargs):
+         super(Rouge2, self).__init__(rouges=['rouge2'])
+
+
+ class RougeL(Rouge):
+     def __init__(self, **kwargs):
+         super(RougeL, self).__init__(rouges=['rougeL'])
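As with the BLEU wrapper, a short illustrative call (each subclass scores one ROUGE variant and returns the mean F1 together with the per-pair list):

```python
# Illustrative usage of the ROUGE wrappers above; the reports are made-up examples.
from nlg.rouge.rouge import Rouge1, Rouge2, RougeL

refs = ["no acute cardiopulmonary process.", "small left pleural effusion."]
hyps = ["no acute cardiopulmonary abnormality.", "small left pleural effusion."]

for scorer in (Rouge1(), Rouge2(), RougeL()):
    mean_f1, per_pair = scorer(refs, hyps)
    print(type(scorer).__name__, round(mean_f1, 4), [round(s, 4) for s in per_pair])
```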
utils.py ADDED
@@ -0,0 +1,341 @@
+ # ---------------------------------------------------------------
+ # This file includes code adapted from:
+ # https://github.com/jbdel/RadEval/blob/null-hypothesis/utils.py
+ # Original author: Justin Xu
+ # ---------------------------------------------------------------
+ import re
+ import nltk
+ import random
+ from typing import List, Dict, Tuple, Callable, Optional
+ from collections import defaultdict
+
+ nltk.download("punkt_tab", quiet=True)
+
+ def clean_numbered_list(text):
+     """
+     Clean a report if it's a numbered list by:
+     1. Adding proper spacing between numbered items
+     2. Removing the numbered list markers
+     3. Adding spaces after periods between sentences
+     """
+     # First, separate numbered items that are stuck together without spaces
+     # Example: "text1.2. text2" -> "text1. 2. text2"
+     text = re.sub(r'\.(\d+\.)', r'. \1', text)
+
+     # Handle patterns where there's no period between numbered entries
+     # Example: "1. item1 2. item2" -> "1. item1. 2. item2"
+     text = re.sub(r'(\d+\.\s*[^.]+?)\s+(?=\d+\.)', r'\1. ', text)
+
+     # Then remove the numbered list markers
+     # But avoid removing decimal numbers in measurements like "3.5 cm"
+     text = re.sub(r'(?<!\d)\d+\.\s*', '', text)
+
+     # Add spaces after periods between sentences if missing
+     # Example: "sentence1.sentence2" -> "sentence1. sentence2"
+     # But don't split decimal numbers like "3.5 cm"
+     text = re.sub(r'\.([A-Za-z])', r'. \1', text)
+     return nltk.sent_tokenize(text)
+
+ class PairedTest:
+     """
+     Paired significance testing for comparing radiology report generation systems.
+
+     Supports paired approximate randomization (AR).
+     """
+
+     def __init__(self,
+                  systems: Dict[str, List[str]],
+                  metrics: Dict[str, Callable],
+                  references: Optional[List[str]],
+                  n_samples: int = 10000,
+                  n_jobs: int = 1,
+                  seed: int = 12345):
+         """
+         Args:
+             systems: Dictionary mapping system names to their generated reports
+             metrics: Dictionary mapping metric names to metric functions
+             references: List of reference reports
+             n_samples: Number of resampling trials (default: 10000)
+             n_jobs: Number of parallel jobs (default: 1)
+             seed: Random seed for reproducibility
+         """
+         self.systems = systems
+         self.metrics = metrics
+         self.references = references
+         self.n_samples = n_samples
+         self.n_jobs = n_jobs
+         self.seed = seed
+
+         random.seed(seed)
+
+         if not systems:
+             raise ValueError("At least one system is required")
+
+         system_lengths = [len(outputs) for outputs in systems.values()]
+         if len(set(system_lengths)) > 1:
+             raise ValueError("All systems must have the same number of outputs")
+
+         if references and len(references) != system_lengths[0]:
+             raise ValueError("References must have same length as system outputs")
+
+         self.n_instances = system_lengths[0]
+
+     def __call__(self) -> Tuple[Dict[str, str], Dict[str, Dict[str, float]]]:
+         """
+         Run the paired significance test.
+
+         Returns:
+             Tuple of (signatures, scores) where:
+             - signatures: Dict mapping metric names to signature strings
+             - scores: Dict mapping system names to metric scores and p-values
+         """
+         # Calculate baseline scores for all systems and metrics
+         baseline_scores = self._calculate_baseline_scores()
+
+         # Get baseline system (first system)
+         baseline_name = list(self.systems.keys())[0]
+
+         scores = {}
+         signatures = {}
+
+         # Calculate scores and p-values for each system
+         for system_name in self.systems.keys():
+             scores[system_name] = {}
+
+             for metric_name in self.metrics.keys():
+                 score = baseline_scores[system_name][metric_name]
+                 scores[system_name][metric_name] = score
+
+                 if system_name != baseline_name:
+                     p_value = self._calculate_p_value(
+                         baseline_name, system_name, metric_name, baseline_scores
+                     )
+                     scores[system_name][f'{metric_name}_pvalue'] = p_value
+
+         for metric_name in self.metrics.keys():
+             signatures[metric_name] = f"{metric_name}|{'ar'}:{self.n_samples}|seed:{self.seed}"
+
+         return signatures, scores
+
+     def _calculate_baseline_scores(self) -> Dict[str, Dict[str, float]]:
+         """Calculate baseline scores for all systems and metrics."""
+         scores = defaultdict(dict)
+
+         for system_name, outputs in self.systems.items():
+             for metric_name, metric_func in self.metrics.items():
+                 if self.references:
+                     score = metric_func(outputs, self.references)
+                 else:
+                     score = metric_func(outputs)
+
+                 if isinstance(score, dict):
+                     if 'score' in score:
+                         scores[system_name][metric_name] = score['score']
+                     else:
+                         scores[system_name][metric_name] = list(score.values())[0]
+                 elif isinstance(score, (tuple, list)):
+                     scores[system_name][metric_name] = score[0]
+                 else:
+                     scores[system_name][metric_name] = score
+
+         return scores
+
+     def _calculate_p_value(self,
+                            baseline_name: str,
+                            system_name: str,
+                            metric_name: str,
+                            baseline_scores: Dict[str, Dict[str, float]]) -> float:
+         """Calculate p-value using AR test"""
+
+         baseline_outputs = self.systems[baseline_name]
+         system_outputs = self.systems[system_name]
+         metric_func = self.metrics[metric_name]
+
+         baseline_score = baseline_scores[baseline_name][metric_name]
+         system_score = baseline_scores[system_name][metric_name]
+         original_delta = abs(system_score - baseline_score)
+
+         return self._approximate_randomization_test(
+             baseline_outputs, system_outputs, metric_func, original_delta
+         )
+
+     def _approximate_randomization_test(self,
+                                         baseline_outputs: List[str],
+                                         system_outputs: List[str],
+                                         metric_func: Callable,
+                                         original_delta: float) -> float:
+         """
+         Perform AR test.
+
+         For each trial, randomly swap outputs between systems and calculate
+         the score difference. P-value is the proportion of trials where
+         the randomized delta >= original delta.
+         """
+         count_greater = 0
+
+         for _ in range(self.n_samples):
+             randomized_baseline = []
+             randomized_system = []
+
+             for i in range(self.n_instances):
+                 if random.random() < 0.5:
+                     # Don't swap
+                     randomized_baseline.append(baseline_outputs[i])
+                     randomized_system.append(system_outputs[i])
+                 else:
+                     # Swap
+                     randomized_baseline.append(system_outputs[i])
+                     randomized_system.append(baseline_outputs[i])
+
+             if self.references:
+                 rand_baseline_score = metric_func(randomized_baseline, self.references)
+                 rand_system_score = metric_func(randomized_system, self.references)
+             else:
+                 rand_baseline_score = metric_func(randomized_baseline)
+                 rand_system_score = metric_func(randomized_system)
+
+             if isinstance(rand_baseline_score, dict):
+                 rand_baseline_score = rand_baseline_score.get('score', list(rand_baseline_score.values())[0])
+             elif isinstance(rand_baseline_score, (tuple, list)):
+                 rand_baseline_score = rand_baseline_score[0]
+
+             if isinstance(rand_system_score, dict):
+                 rand_system_score = rand_system_score.get('score', list(rand_system_score.values())[0])
+             elif isinstance(rand_system_score, (tuple, list)):
+                 rand_system_score = rand_system_score[0]
+
+             rand_delta = abs(rand_system_score - rand_baseline_score)
+
+             if rand_delta >= original_delta:
+                 count_greater += 1
+
+         return count_greater / self.n_samples
+
+
+ def print_significance_results(scores: Dict[str, Dict[str, float]],
+                                signatures: Dict[str, str],
+                                baseline_name: str,
+                                significance_level: float = 0.05):
+     """
+     Args:
+         scores: Dictionary of system scores and p-values
+         signatures: Dictionary of metric signatures
+         baseline_name: Name of the baseline system
+         significance_level: Significance threshold (default: 0.05)
+     """
+     assert baseline_name in scores, f"Baseline system '{baseline_name}' not found in scores."
+
+     metric_names = [name for name in signatures.keys()]
+     system_names = list(scores.keys())
+
+     print("=" * 80)
+     print("PAIRED SIGNIFICANCE TEST RESULTS")
+     print("=" * 80)
+
+     header = f"{'System':<40}"
+     for metric in metric_names:
+         header += f"{metric:>15}"
+     print(header)
+     print("-" * len(header))
+
+     baseline_row = f"Baseline: {baseline_name:<32}"
+     for metric in metric_names:
+         score = scores[baseline_name][metric]
+         baseline_row += f"{score:>12.4f} "
+     print(baseline_row)
+     print("-" * len(header))
+
+     for system_name in system_names:
+         if system_name == baseline_name:
+             continue
+
+         system_row = f"{system_name:<40}"
+         for metric in metric_names:
+             score = scores[system_name].get(metric, 0.0)
+             if isinstance(score, float):
+                 system_row += f"{score:>12.4f} "
+             else:
+                 system_row += f"{str(score):>12} "
+         print(system_row)
+
+         # P-value row
+         pvalue_row = " " * 40
+         for metric in metric_names:
+             pvalue_key = f"{metric}_pvalue"
+             if pvalue_key in scores[system_name]:
+                 p_val = scores[system_name][pvalue_key]
+                 significance_marker = "*" if p_val < significance_level else ""
+                 pvalue_row += f"(p={p_val:.4f}){significance_marker:<2}".rjust(15)
+             else:
+                 pvalue_row += " " * 15
+         print(pvalue_row)
+         print("-" * len(header))
+
+     # Footer
+     print(f"- Significance level: {significance_level}")
+     print("- '*' indicates significant difference (p < significance level)")
+     print("- Null hypothesis: systems are essentially the same")
+     print("- Significant results suggest systems are meaningfully different\n")
+
+     print("METRIC SIGNATURES:")
+     for metric, signature in signatures.items():
+         print(f"- {metric}: {signature}")
+
+
+ def compare_systems(systems: Dict[str, List[str]],
+                     metrics: Dict[str, Callable],
+                     references: Optional[List[str]] = None,
+                     n_samples: int = 10000,
+                     significance_level: float = 0.05,
+                     seed: int = 12345,
+                     print_results: bool = True) -> Tuple[Dict[str, str], Dict[str, Dict[str, float]]]:
+     """
+     Args:
+         systems: Dictionary mapping system names to their generated reports
+         metrics: Dictionary mapping metric names to metric functions
+         references: Optional list of reference reports
+         n_samples: Number of resampling trials
+         significance_level: Significance threshold for printing results
+         seed: Random seed for reproducibility
+         print_results: Whether to print formatted results
+
+     Returns:
+         Tuple of (signatures, scores)
+
+     Example:
+         ```python
+         systems = {
+             'baseline_model': baseline_reports,
+             'new_model': new_model_reports,
+             'other_model': other_model_reports
+         }
+
+         metrics = {
+             'bleu': lambda hyp, ref: bleu_score(hyp, ref),
+             'rouge': lambda hyp, ref: rouge_score(hyp, ref),
+             'bertscore': lambda hyp, ref: bert_score(hyp, ref),
+             'custom_metric': lambda hyp, ref: custom_metric(hyp, ref)
+         }
+
+         signatures, scores = compare_systems(
+             systems, metrics, references,
+             n_samples=10000
+         )
+         ```
+     """
+
+     paired_test = PairedTest(
+         systems=systems,
+         metrics=metrics,
+         references=references,
+         n_samples=n_samples,
+         seed=seed
+     )
+
+     signatures, scores = paired_test()
+
+     if print_results:
+         baseline_name = list(systems.keys())[0]
+         print_significance_results(scores, signatures, baseline_name, significance_level)
+
+     return signatures, scores
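A self-contained sketch of how `compare_systems` might be driven end to end, using a toy token-overlap metric so the example runs without any model downloads (the metric and the report strings are illustrative, not part of RadEval):

```python
# Illustrative end-to-end use of compare_systems with a toy corpus-level metric.
from utils import compare_systems

def token_overlap(hyps, refs):
    # Fraction of reference tokens recovered by each hypothesis, averaged over reports.
    scores = []
    for hyp, ref in zip(hyps, refs):
        hyp_tokens, ref_tokens = set(hyp.split()), set(ref.split())
        scores.append(len(hyp_tokens & ref_tokens) / max(len(ref_tokens), 1))
    return sum(scores) / len(scores)

references = ["no acute cardiopulmonary process .",
              "small left pleural effusion ."]
systems = {
    "baseline_model": ["no acute process .", "no pleural effusion ."],
    "new_model": ["no acute cardiopulmonary process .", "small left pleural effusion ."],
}
metrics = {"token_overlap": token_overlap}

signatures, scores = compare_systems(systems, metrics, references,
                                     n_samples=100, print_results=True)
```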