import os
from collections import defaultdict

import jsonlines
import numpy as np
from sklearn.metrics import classification_report


def get_links(sample_string, sample_index):
    """
    Parse a discourse-structure string into a list of attachment tuples,
    a list of (relation, x, y) triples, and a [good, bad] count of
    well-formed vs. malformed relation instances.
    """
    # Minecraft corpus relation labels
    labels = ['COM', 'CONTR', 'CORR', 'QAP', 'ACK', 'ELAB', 'CLARIFQ', 'COND', 'CONTIN',
              'RES', 'EXPL', 'QELAB', 'ALT', 'NARR', 'CONFQ', 'SEQ']
    split_list = [st.strip() for st in sample_string.split(' ')]
    rel_list = []
    attach_list = []
    bad = 0
    good = 0
    for a in split_list:
        s_tuple = None
        rel = None
        try:
            s = a.split('(')[1].split(')')[0].split(',')
            r = a.split('(')[0].strip()
        except IndexError:
            print('split error at ', sample_index)
        else:
            try:
                s_tuple = (int(s[0]), int(s[1]))
            except IndexError:
                print('split error at ', sample_index)
            except ValueError:
                print('value error at ', sample_index)
            if r in labels:  # make sure the label is well-formed
                rel = r
        if rel is not None and s_tuple is not None and (s_tuple[1] - s_tuple[0]) <= 15:  # if using a DISTANCE cutoff
        # if rel is not None and s_tuple is not None:  # if not using a DISTANCE cutoff
            attach_list.append(s_tuple)
            rel_list.append(rel)
            good += 1
        else:
            bad += 1
    # re-construct the full list: a list of tuples (rel, x, y),
    # but don't allow duplicate endpoint pairs!
    full_list = []
    endpoints = []
    for i, r in enumerate(attach_list):
        if r not in endpoints:
            endpoints.append(r)
            full_list.append((rel_list[i], r[0], r[1]))
    return endpoints, full_list, [good, bad]


current_folder = os.getcwd()
gold_path = '/path/to/jsonl'
pred_path = '/path/to/llamipa_output.txt'
save_results = '/path/to/eval_.txt'  # to create

# get predicted
with open(pred_path, 'r') as txt:
    text = txt.read().split('\n')
pred_outputs = []
for t in text:
    if t.startswith(' ### DS:'):
        sample = t.split('### DS:')[1].strip()
        pred_outputs.append(sample)
print(len(pred_outputs))

# get gold
gold_outputs = []
with jsonlines.open(gold_path) as reader:
    for obj in reader:
        if not obj['sample'].startswith('NEW DIALOGUE'):  # make sure to ignore incremental formatting
            gold_outputs.append(obj['PS'])

att_f1_l = []
att_prec_l = []
att_rec_l = []
total_attach_tp = 0
total_attach_fp = 0
total_attach_fn = 0
type_f1_l = []
type_prec_l = []
type_rec_l = []
total_TP = []
matrix_list = []
bad_output = 0
good_output = 0

for i, s in enumerate(pred_outputs):
    pred_att, pred_all, malform = get_links(s, i)
    gold_att, gold_all, _ = get_links(gold_outputs[i], i)
    bad_output += malform[1]
    good_output += malform[0]
    # calculate the number of nulls there should be -- used to check the null count below
    common = len(set(pred_att).intersection(set(gold_att)))
    expected_nulls = (len(pred_att) - common) + (len(gold_att) - common)
    # calculate the precision, recall, and f1 for the sample FOR ATTACHMENTS
    if len(gold_att) > 0 and len(pred_att) > 0:
        prec = len([e for e in pred_att if e in gold_att]) / len(pred_att)
        rec = len([e for e in pred_att if e in gold_att]) / len(gold_att)
        total_attach_tp += len([e for e in pred_att if e in gold_att])
        total_attach_fp += len([e for e in pred_att if e not in gold_att])
        total_attach_fn += len([e for e in gold_att if e not in pred_att])
    else:
        prec = 0
        rec = 0
    att_prec_l.append(prec)
    att_rec_l.append(rec)
    if prec + rec == 0:
        att_f1_l.append(0)
    else:
        att_f1_l.append(2 * prec * rec / (prec + rec))
    # calculate the precision, recall, and f1 for the sample FOR ATTACHMENTS + RELATION TYPE
    if len(gold_all) > 0 and len(pred_all) > 0:
        prec = len([e for e in pred_all if e in gold_all]) / len(pred_all)
        rec = len([e for e in pred_all if e in gold_all]) / len(gold_all)
    else:
        prec = 0
        rec = 0
    type_prec_l.append(prec)
    type_rec_l.append(rec)
    if prec + rec == 0:
        type_f1_l.append(0)
    else:
        type_f1_l.append(2 * prec * rec / (prec + rec))
    # create the relation comparisons by type
    TP = [e for e in pred_all if e in gold_all]
    leftover_pred = [p for p in pred_all if p not in TP]
    leftover_gold = [p for p in gold_all if p not in TP]
    # then process the TP, FP, FN for the confusion matrix
    total_TP.extend(TP)
    rem_dict = defaultdict(list)
    for x in TP:
        matrix_list.append([x[0], x[0]])
    for x in leftover_pred:
        rem_dict[(x[1], x[2])].append(('p', x[0]))
    for x in leftover_gold:
        rem_dict[(x[1], x[2])].append(('g', x[0]))
    p_count = 0
    g_count = 0
    null_count = 0
    for k in rem_dict.keys():
        p = 'NULL'
        t = 'NULL'
        for entry in rem_dict[k]:
            if entry[0] == 'p':
                p = entry[1]
                p_count += 1
            elif entry[0] == 'g':
                t = entry[1]
                g_count += 1
        matrix_list.append([t, p])
        if 'NULL' in [t, p]:
            null_count += 1
    assert len(TP) + p_count == len(pred_all)
    assert len(TP) + g_count == len(gold_all)
    assert null_count == expected_nulls

# compute the label set over gold and pred
gold = [m[0] for m in matrix_list]
pred = [m[1] for m in matrix_list]
gold.extend(pred)
labels = list(set(gold))

microf1 = total_attach_tp / (total_attach_tp + 0.5 * (total_attach_fp + total_attach_fn))

gold_list = [labels.index(m[0]) for m in matrix_list]
pred_list = [labels.index(m[1]) for m in matrix_list]

f = open(save_results, "w")

print("Attachment F1:", np.mean(att_f1_l), len(att_f1_l), file=f)
print("Attachment Average Precision:", np.mean(att_prec_l), file=f)
print("Attachment Average Recall:", np.mean(att_rec_l), file=f)
print('Micro F1: ', microf1, file=f)
print('--------------------------------', file=f)
print("Attachment + Rel F1:", np.mean(type_f1_l), len(type_f1_l))
print("Attachment + Rel Average Precision:", np.mean(type_prec_l))
print("Attachment + Rel Average Recall:", np.mean(type_rec_l))
print('---------------------------------------')
print(classification_report(gold_list, pred_list, target_names=labels), file=f)

# The F1-scores for the relation types displayed in the above table are correct.
# That is, while calculating F1 for label l, all the ["NULL", l] entries count towards false positives for label l,
# and all the [l, "NULL"] entries count towards false negatives for label l.
# So the "NULL" type affects the precision/recall/F1 for label l (as it should).
# For the overall weighted average precision/recall/F1-score, however,
# we want the average taken over the actual relation label set (i.e. excluding the "NULL" class).
# To do that:
d = classification_report(gold_list, pred_list, target_names=labels, output_dict=True)
prec = 0
rec = 0
f1 = 0
count = 0
for label in labels:
    if label != "NULL":
        prec += d[label]["precision"] * d[label]["support"]
        rec += d[label]["recall"] * d[label]["support"]
        f1 += d[label]["f1-score"] * d[label]["support"]
        count += d[label]["support"]
        # check that support is the same as the number of ground-truth instances for the label
        # assert d[label]["support"] == Counter(g_label_l)[label]
print('--------------------------------', file=f)
print("Weighted Average Precision:", prec / count, file=f)
print("Weighted Average Recall:", rec / count, file=f)
print("Weighted Average F1 score:", f1 / count, file=f)
f.close()
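
# A minimal, illustrative sanity check for get_links (a sketch only: the relation
# string below is hypothetical, but it follows the LABEL(x,y) format parsed above):
#
#   endpoints, full_list, counts = get_links('ACK(0,1) ELAB(1,2)', 0)
#   # endpoints -> [(0, 1), (1, 2)]
#   # full_list -> [('ACK', 0, 1), ('ELAB', 1, 2)]
#   # counts    -> [2, 0], i.e. two well-formed relations, zero malformed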