from groq import Groq import pandas as pd import os from scipy.cluster.hierarchy import linkage import numpy as np from scipy.cluster import hierarchy from tqdm import tqdm from sentence_transformers import SentenceTransformer from scipy.spatial.distance import cosine def convert_labels(dataset): labels = dataset.label.unique() labels_mapping = {} for label in labels: pred = dataset.loc[dataset['label'] == label, 'label_y'].values[0] labels_mapping[label] = pred for label in labels_mapping: dataset.loc[dataset.pred == label, 'pred_label'] = labels_mapping[label] return dataset def add_labels(dataset, original_dataset): original_dataset = original_dataset.rename(columns={'text':'content'}) original_dataset = original_dataset[['content', 'label_y']] #content column to compare with original dataset dataset_content = dataset.content #retrieve all columns for each row in test set from original dataset subset_original_content = original_dataset.loc[original_dataset.content.isin(dataset_content)] #merge dataframes dataset = pd.merge(dataset, subset_original_content, on = 'content') dataset = convert_labels(dataset) return dataset def tree_depth(node): if node is None: return 0 else: left_depth = tree_depth(node.get_left()) right_depth = tree_depth(node.get_right()) return max(left_depth, right_depth) + 1 def reconstruct_tree(mergings, content): tree = {} for i, merge in enumerate(mergings): #these are the leaves, they'll have an index less than the number of examples if merge[0] <= len(mergings): a = content[int(merge[0]) - 1] else: #if here then that's a merged cluster a = tree[int(merge[0])] #these are the leaves, they'll have an index less than the number of examples if merge[1] <= len(mergings): b = content[int(merge[1]) - 1] else: #if here then that's a merged cluster b = tree[int(merge[1])] tree[1 + i + len(mergings)] = [a,b] return tree #remove nested lists in branches and put all nodes in a 1-D list def flatten(dict_list): flat_list = [] for item in dict_list: if isinstance(item, list): flat_list.extend(flatten(item)) else: flat_list.append(item) return flat_list #pass prompt to llm def get_answer(prompt, system_prompt): chat_completion = client.chat.completions.create( messages=[ { "role":"system", "content":f"{system_prompt}" }, { "role": "user", "content": f"{prompt}", } ], model="mixtral-8x7b-32768" ) return chat_completion.choices[0].message.content #pass two leaves to the llm and get a list of their similarities #leaf will have the content, predicted label, and actual label def merge_two_leaves(leaf_0, leaf_1): system_prompt =f'You are given two statements from an offensive language dataset that were misclassified by an offensive language detection system. Analyze the two statements thoroughly and provide a bullet list explanation of the similarities between the two statements. Your list should have the following format: * Error_Feature: where Error_Feature: is a two word discription of the feature and Explanation is a one sentence explanation of the feature. Make sure to stick to the format specified. Avoid making explicit references to the examples and use layman terms for the explanations.' prompt = f'Statement_0: {leaf_0.content}\npredicted_label_0: {leaf_0.predicted_label}\nactual_label_0: {leaf_0.actual_label}\n\nStatement_1: {leaf_1.content}\npredicted_label_1: {leaf_1.predicted_label}\nactual_label_1: {leaf_1.actual_label}\n\nList: ' #pass prompts to llm and get answer base_list = get_answer(prompt, system_prompt) return base_list #case 2: pass leaf (with content, predicted label, and actual label) and the list previously generated def applies(leaf, list, threshold = 2): #another way to do this would be to split the list and check every bullet points -> more api calls system_prompt = f"You are given a statement from an offensive language dataset that was misclassified by an offensive language detection system. In addition, you are given a list of features generated by an LLM for other statements that were misclassified. Perform a thorough analysis of the statement and the list. If at least {threshold} points apply to the statement return YES otherwise return NO." prompt = f"Statement: {leaf.content}\npredicted_label: {leaf.predicted_label}\nactual_label: {leaf.actual_label}\n\nList: {list}\n\nAnswer: " #convert answer to all lower case to avoid llm inconsistency check = get_answer(prompt, system_prompt).lower() #if yes return true otherwise return false return 'yes' in check def merge_leaf(leaf, list): system_prompt = f"You are given a statement from an offensive language dataset that was misclassified by an offensive language detection system. In addition, you are given a list of features generated by an LLM for other statements that were misclassified. Make the minimal changes so the list also applies to the given statement. Maintain the same format * Error_Feature: where Error_Feature: is a two word discription of the feature and Explanation is a one sentence explanation of the feature. Make sure to stick to the format specified. Avoid making explicit references to the examples and use layman terms for the explanations. " prompt = f"Statement: {leaf.content}\npredicted_label: {leaf.predicted_label}\nactual_label: {leaf.actual_label}\n\nList: {list}\n\nUpdated list: " if applies(leaf, list): return 'edited', get_answer(prompt, system_prompt) else: return 'not edited', list def get_bullet_points(list): #split on new line to get the individual bullet points return list.split('\n') def construct_bipartite_graph(bullet_list_0, bullet_list_1): bipartite_graph = [] for first in bullet_list_0: #o(n) for second in bullet_list_1: #o(m) if ((first, second) and (second, first)) not in bipartite_graph: #check pair is not already in list, order doesn't matter, o(k) bipartite_graph.append((first,second)) return bipartite_graph def sbert_embeddings(bipartite_graph): sbert_bipartite_graph = [] for pair in bipartite_graph: first = sbert_model.encode(pair[0]) second = sbert_model.encode(pair[1]) sbert_bipartite_graph.append((first, second)) return sbert_bipartite_graph def compute_cosine_similarity(sbert_bipartite_embeddings): cosine_similarity = [] for pair in sbert_bipartite_embeddings: similarity = 1 - cosine(pair[0], pair[1]) cosine_similarity.append(similarity) return cosine_similarity def combine(cosine_similarity, bipartite_graph, similarity_threshold): pairs_to_combine = [] for index in range(len(cosine_similarity)): if cosine_similarity[index] > similarity_threshold: #there needs to be a different threshold/criteria pairs_to_combine.append(bipartite_graph[index]) return pairs_to_combine #check the overlap between the two lists def overlap(list_0, list_1, overlap_threshold = 0.5, similarity_threshold = 0.75): #step 0: separate the list to individual bullet points bullet_list_0 = get_bullet_points(list_0) bullet_list_1 = get_bullet_points(list_1) #step 1: construct a bipartite graph bipartite_graph = construct_bipartite_graph(bullet_list_0, bullet_list_1) #step 2: compute the sbert embeddings sbert_bipartite_graph = sbert_embeddings(bipartite_graph) #step 3: calculate the cosine similarity cosine_similarity = compute_cosine_similarity(sbert_bipartite_graph) #step 4: if similarity above threshold -> combine otherwise leave as separate pairs_to_combine = combine(cosine_similarity, bipartite_graph, similarity_threshold) #step 5: increment overlap score overlap_score = len(pairs_to_combine) / len(bipartite_graph) #step 6: if score is more than overlap_threshold -> pair should be combined (save this pair) return overlap_score > overlap_threshold, bipartite_graph, pairs_to_combine def union(bipartite_graph, pairs_to_combine): #to get the union #step 0: remove the pairs_to_combine from bipartite_graph bipartite_graph = [pair for pair in bipartite_graph if pair not in pairs_to_combine] #step 1: remove all the pairs where one of the elements is also in pairs_to_combine #step 1.1: convert pair_to_combine to a set distinct_features = set() for pair in pairs_to_combine: distinct_features.add(pair[0]) distinct_features.add(pair[1]) #step 1.2: remove the any pairs that have elements in pairs_to_combine bipartite_graph = [pair for pair in bipartite_graph if pair[0] or pair[1] not in distinct_features] dont_combine = set() for pair in bipartite_graph: dont_combine.add(pair[0]) dont_combine.add(pair[1]) return dont_combine #take the union between the lists def list_union(bipartite_graph, pairs_to_combine): dont_combine = union(bipartite_graph, pairs_to_combine) union_list = '\n'.join(dont_combine) system_prompt = f"You are given two bullet points generated to explain similarities between statements. You are tasked to combine these two bullet points into one. Make sure to maintain the same format * Error_Feature: where Error_Feature: is a two word discription of the feature and Explanation is a one sentence explanation of the feature. Make sure to stick to the format specified." for pair in pairs_to_combine: prompt = f"First point: {pair[0]}\n\nSecond point: {pair[1]}\n\nNew point: " union_list += get_answer(prompt, system_prompt) + '\n' return union_list #read data dataset = pd.read_json("improved_english/clusters/baseline_sb.json") original_dataset = pd.read_csv('clusters/mhs_lhs_errors.csv') #this is using one of the sbert clustering but we just want to the embeddings and content (maybe labels) dataset.drop(['slice', 'centroid', 'cluster'], inplace=True, axis=1) dataset = add_labels(dataset, original_dataset) dataset = dataset.rename(columns={'label_y':'actual_label', 'pred_label':'predicted_label'}) #generate the hierarchical tree mergings = linkage(np.array(dataset.embedding.to_list()), method='complete', metric='cosine') #convert to a tree to traverse root, nodelist = hierarchy.to_tree(mergings, rd = True) #construct tree using examples tree = reconstruct_tree(mergings, dataset.content.to_list()) client = Groq( api_key=os.environ.get("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"), ) #load sbert model sbert_model = SentenceTransformer('all-distilroberta-v1') #store the intermediate steps intermediate_steps = [] end_summaries = [] for id, node in tqdm(enumerate(mergings)): #first case: if both are leaves -> send to llm if node[0] <= len(mergings) and node[1] <= len(mergings): #pass two leaves leaf_0 = dataset.iloc[[int(node[0])]] leaf_1 = dataset.iloc[[int(node[1])]] leaf_list = merge_two_leaves(leaf_0, leaf_1) current = {'id': int(id + len(mergings) + 1), #index in mergings DS + number of clusters idk if this correct 'examples': [[leaf_0.content, leaf_0.predicted_label, leaf_0.actual_label, int(node[0])], [leaf_1.content, leaf_1.predicted_label, leaf_1.actual_label, int(node[1])]], 'bullet_list': leaf_list, 'edited': 'both are leaves', 'previous_list': 'base list'} #second case: if you're merging a leaf to a merged cluster -> send to check if it applies elif (node[0] >= len(mergings)) ^ (node[1] >= len(mergings)): #use the cluster id of the merged list to get the previous list if node[0] <= len(mergings): leaf = dataset.iloc[[int(node[0])]] previous_list = int(node[1]) #this is the id leaf_id = int(node[0]) else: #I don't think this will ever be executed idk leaf = dataset.iloc[[int(node[1])]] previous_list = int(node[0]) #this is the id leaf_id = int(node[1]) previous_bullet_list = next(item for item in intermediate_steps if item['id'] == previous_list) previous_bullet_list = previous_bullet_list['bullet_list'] #pass the previous list edited, merged_leaf = merge_leaf(leaf, previous_bullet_list) #hyperparameter threshold (how many points apply to the example) #store the list and examples and verdict so they're easy to retrieve current = {'id':int(id + len(mergings) + 1), #index in mergings DS + number of clusters 'examples': [leaf.content, leaf.predicted_label, leaf.actual_label, leaf_id], 'bullet_list': merged_leaf, 'edited': edited, 'previous_list':previous_list } #third case: merging to clusters else: #get the list generated at each node list_0 = next(item for item in intermediate_steps if item['id'] == int(node[0])) list_0_id = list_0['bullet_list'] list_1 = next(item for item in intermediate_steps if item['id'] == int(node[1])) list_1_id = list_1['bullet_list'] #if there is "enough" overlap merge the cluster enough, bipartite_graph, pairs_to_combine = overlap(list_0_id,list_1_id) if enough: union_list = list_union(bipartite_graph, pairs_to_combine) current = {'id':int(id + len(mergings) + 1), #index in mergings DS + number of clusters 'examples': [list_0['id'], list_1['id']], 'bullet_list': union_list, 'edited': 'merging two clusters', 'previous_list':'enough overlap to merge' } #not enough overlap, separate into two clusters else: print('not merging') end_summaries.append(list_0) end_summaries.append(list_1) intermediate_steps.append(current) if id == 118: break intermediate_steps = pd.DataFrame(intermediate_steps) intermediate_steps.to_json('intermediate_steps.json', orient='records', indent=4) end_summaries = pd.DataFrame(end_summaries) end_summaries.to_json('end_summaries.json', orient='records', indent=4)