# distilbert-base-uncased-measuring-hate-speech / hierarchical_summarization.py
from groq import Groq
import pandas as pd
import os
from scipy.cluster.hierarchy import linkage
import numpy as np
from scipy.cluster import hierarchy
from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from scipy.spatial.distance import cosine
def convert_labels(dataset):
labels = dataset.label.unique()
labels_mapping = {}
for label in labels:
pred = dataset.loc[dataset['label'] == label, 'label_y'].values[0]
labels_mapping[label] = pred
for label in labels_mapping:
dataset.loc[dataset.pred == label, 'pred_label'] = labels_mapping[label]
return dataset
def add_labels(dataset, original_dataset):
original_dataset = original_dataset.rename(columns={'text':'content'})
original_dataset = original_dataset[['content', 'label_y']]
#content column to compare with original dataset
dataset_content = dataset.content
    #retrieve the matching rows (content and gold label) for each test example from the original dataset
subset_original_content = original_dataset.loc[original_dataset.content.isin(dataset_content)]
#merge dataframes
dataset = pd.merge(dataset, subset_original_content, on = 'content')
dataset = convert_labels(dataset)
return dataset
def tree_depth(node):
if node is None:
return 0
else:
left_depth = tree_depth(node.get_left())
right_depth = tree_depth(node.get_right())
return max(left_depth, right_depth) + 1
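#illustrative sketch (not executed here): tree_depth is meant to be called on the
#ClusterNode returned by scipy's to_tree, e.g. the root built further below:
#    root, nodelist = hierarchy.to_tree(mergings, rd=True)
#    depth = tree_depth(root) #number of nodes on the longest root-to-leaf path of the dendrogram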
def reconstruct_tree(mergings, content):
tree = {}
for i, merge in enumerate(mergings):
#these are the leaves, they'll have an index less than the number of examples
if merge[0] <= len(mergings):
            a = content[int(merge[0])] #scipy leaf indices are 0-based, so no offset is needed
else:
#if here then that's a merged cluster
a = tree[int(merge[0])]
#these are the leaves, they'll have an index less than the number of examples
if merge[1] <= len(mergings):
            b = content[int(merge[1])] #scipy leaf indices are 0-based, so no offset is needed
else:
#if here then that's a merged cluster
b = tree[int(merge[1])]
tree[1 + i + len(mergings)] = [a,b]
return tree
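#rough sketch of the scipy indexing convention reconstruct_tree relies on, using a toy
#3-example input (toy_embeddings is hypothetical, not part of this script):
#    content = ['a', 'b', 'c']
#    mergings = linkage(toy_embeddings, method='complete', metric='cosine')
#    #row 0 might merge leaves 0 and 1 into cluster 3 (= 1 + 0 + len(mergings)),
#    #row 1 then merges cluster 3 with leaf 2 into cluster 4
#    tree = reconstruct_tree(mergings, content) #-> {3: ['a', 'b'], 4: [['a', 'b'], 'c']}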
#remove nested lists in branches and put all nodes in a 1-D list
def flatten(dict_list):
flat_list = []
for item in dict_list:
if isinstance(item, list):
flat_list.extend(flatten(item))
else:
flat_list.append(item)
return flat_list
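#example (not executed): flatten([['a', ['b', 'c']], 'd']) -> ['a', 'b', 'c', 'd'],
#e.g. to turn one of the nested branch lists built by reconstruct_tree into a flat list of leaves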
#pass prompt to llm
def get_answer(prompt, system_prompt):
chat_completion = client.chat.completions.create(
messages=[
{
"role":"system",
"content":f"{system_prompt}"
},
{
"role": "user",
"content": f"{prompt}",
}
],
model="mixtral-8x7b-32768"
)
return chat_completion.choices[0].message.content
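#note: get_answer uses the module-level Groq `client` created further below, so it can
#only be called after that point; the model is the Groq-hosted mixtral-8x7b-32768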
#pass two leaves to the llm and get a list of their similarities
#leaf will have the content, predicted label, and actual label
def merge_two_leaves(leaf_0, leaf_1):
    system_prompt = f'You are given two statements from an offensive language dataset that were misclassified by an offensive language detection system. Analyze the two statements thoroughly and provide a bullet list explanation of the similarities between the two statements. Your list should have the following format: * Error_Feature: <Explanation> where Error_Feature: is a two word description of the feature and Explanation is a one sentence explanation of the feature. Make sure to stick to the format specified. Avoid making explicit references to the examples and use layman terms for the explanations.'
prompt = f'Statement_0: {leaf_0.content}\npredicted_label_0: {leaf_0.predicted_label}\nactual_label_0: {leaf_0.actual_label}\n\nStatement_1: {leaf_1.content}\npredicted_label_1: {leaf_1.predicted_label}\nactual_label_1: {leaf_1.actual_label}\n\nList: '
#pass prompts to llm and get answer
base_list = get_answer(prompt, system_prompt)
return base_list
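#rough illustration of the bullet-list format the prompt asks for (hypothetical content,
#the actual wording depends entirely on the model):
#    * Sarcastic Tone: Both statements use sarcasm that can hide the intended offence.
#    * Implicit Insult: Neither statement contains an explicit slur, so the offence is indirect.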
#case 2: pass leaf (with content, predicted label, and actual label) and the list previously generated
def applies(leaf, list, threshold = 2): #another way to do this would be to split the list and check every bullet point -> more api calls
system_prompt = f"You are given a statement from an offensive language dataset that was misclassified by an offensive language detection system. In addition, you are given a list of features generated by an LLM for other statements that were misclassified. Perform a thorough analysis of the statement and the list. If at least {threshold} points apply to the statement return YES otherwise return NO."
prompt = f"Statement: {leaf.content}\npredicted_label: {leaf.predicted_label}\nactual_label: {leaf.actual_label}\n\nList: {list}\n\nAnswer: "
#convert answer to all lower case to avoid llm inconsistency
check = get_answer(prompt, system_prompt).lower()
#if yes return true otherwise return false
return 'yes' in check
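#note: applies does a loose substring check on the llm's free-text reply; e.g. an answer
#like 'Yes, at least 2 points apply.' lowercases to a string containing 'yes' and counts as a match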
def merge_leaf(leaf, list):
system_prompt = f"You are given a statement from an offensive language dataset that was misclassified by an offensive language detection system. In addition, you are given a list of features generated by an LLM for other statements that were misclassified. Make the minimal changes so the list also applies to the given statement. Maintain the same format * Error_Feature: <Explanation> where Error_Feature: is a two word discription of the feature and Explanation is a one sentence explanation of the feature. Make sure to stick to the format specified. Avoid making explicit references to the examples and use layman terms for the explanations. "
prompt = f"Statement: {leaf.content}\npredicted_label: {leaf.predicted_label}\nactual_label: {leaf.actual_label}\n\nList: {list}\n\nUpdated list: "
if applies(leaf, list):
return 'edited', get_answer(prompt, system_prompt)
else:
return 'not edited', list
def get_bullet_points(list):
#split on new line to get the individual bullet points
return list.split('\n')
def construct_bipartite_graph(bullet_list_0, bullet_list_1):
bipartite_graph = []
for first in bullet_list_0: #o(n)
for second in bullet_list_1: #o(m)
            if (first, second) not in bipartite_graph and (second, first) not in bipartite_graph: #check the pair is not already in the list in either order, o(k)
bipartite_graph.append((first,second))
return bipartite_graph
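#example (not executed): with bullet_list_0 = ['p0', 'p1'] and bullet_list_1 = ['q0', 'q1']
#the graph is [('p0', 'q0'), ('p0', 'q1'), ('p1', 'q0'), ('p1', 'q1')],
#i.e. every bullet point of one list paired with every bullet point of the other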
def sbert_embeddings(bipartite_graph):
sbert_bipartite_graph = []
for pair in bipartite_graph:
first = sbert_model.encode(pair[0])
second = sbert_model.encode(pair[1])
sbert_bipartite_graph.append((first, second))
return sbert_bipartite_graph
def compute_cosine_similarity(sbert_bipartite_embeddings):
cosine_similarity = []
for pair in sbert_bipartite_embeddings:
similarity = 1 - cosine(pair[0], pair[1])
cosine_similarity.append(similarity)
return cosine_similarity
def combine(cosine_similarity, bipartite_graph, similarity_threshold):
pairs_to_combine = []
for index in range(len(cosine_similarity)):
if cosine_similarity[index] > similarity_threshold: #there needs to be a different threshold/criteria
pairs_to_combine.append(bipartite_graph[index])
return pairs_to_combine
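#worked example (not executed, hypothetical numbers): if the four pairs above get cosine
#similarities [0.81, 0.42, 0.39, 0.77] and similarity_threshold = 0.75, combine keeps the
#first and last pair, so pairs_to_combine holds 2 of the 4 pairs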
#check the overlap between the two lists
def overlap(list_0, list_1, overlap_threshold = 0.5, similarity_threshold = 0.75):
#step 0: separate the list to individual bullet points
bullet_list_0 = get_bullet_points(list_0)
bullet_list_1 = get_bullet_points(list_1)
#step 1: construct a bipartite graph
bipartite_graph = construct_bipartite_graph(bullet_list_0, bullet_list_1)
#step 2: compute the sbert embeddings
sbert_bipartite_graph = sbert_embeddings(bipartite_graph)
#step 3: calculate the cosine similarity
cosine_similarity = compute_cosine_similarity(sbert_bipartite_graph)
#step 4: if similarity above threshold -> combine otherwise leave as separate
pairs_to_combine = combine(cosine_similarity, bipartite_graph, similarity_threshold)
#step 5: increment overlap score
overlap_score = len(pairs_to_combine) / len(bipartite_graph)
#step 6: if score is more than overlap_threshold -> pair should be combined (save this pair)
return overlap_score > overlap_threshold, bipartite_graph, pairs_to_combine
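#continuing the hypothetical numbers above: 2 of 4 pairs pass the similarity threshold,
#so overlap_score = 2 / 4 = 0.5, which is not strictly greater than the default
#overlap_threshold of 0.5, so the two lists would be kept separate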
def union(bipartite_graph, pairs_to_combine):
#to get the union
#step 0: remove the pairs_to_combine from bipartite_graph
bipartite_graph = [pair for pair in bipartite_graph if pair not in pairs_to_combine]
#step 1: remove all the pairs where one of the elements is also in pairs_to_combine
#step 1.1: convert pair_to_combine to a set
distinct_features = set()
for pair in pairs_to_combine:
distinct_features.add(pair[0])
distinct_features.add(pair[1])
    #step 1.2: remove any pairs that have an element in pairs_to_combine
    bipartite_graph = [pair for pair in bipartite_graph if pair[0] not in distinct_features and pair[1] not in distinct_features]
dont_combine = set()
for pair in bipartite_graph:
dont_combine.add(pair[0])
dont_combine.add(pair[1])
return dont_combine
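#sketch (not executed): with pairs_to_combine = [('p0', 'q0')], the elements 'p0' and 'q0'
#are dropped from the remaining pairs and dont_combine collects the leftover bullet points
#('p1' and 'q1'), which list_union then carries over verbatim into the merged list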
#take the union between the lists
def list_union(bipartite_graph, pairs_to_combine):
dont_combine = union(bipartite_graph, pairs_to_combine)
union_list = '\n'.join(dont_combine)
system_prompt = f"You are given two bullet points generated to explain similarities between statements. You are tasked to combine these two bullet points into one. Make sure to maintain the same format * Error_Feature: <Explanation> where Error_Feature: is a two word discription of the feature and Explanation is a one sentence explanation of the feature. Make sure to stick to the format specified."
for pair in pairs_to_combine:
prompt = f"First point: {pair[0]}\n\nSecond point: {pair[1]}\n\nNew point: "
union_list += get_answer(prompt, system_prompt) + '\n'
return union_list
#read data
dataset = pd.read_json("improved_english/clusters/baseline_sb.json")
original_dataset = pd.read_csv('clusters/mhs_lhs_errors.csv')
#this file comes from one of the sbert clustering runs, but only the embeddings, content, and labels are needed here
dataset.drop(['slice', 'centroid', 'cluster'], inplace=True, axis=1)
dataset = add_labels(dataset, original_dataset)
dataset = dataset.rename(columns={'label_y':'actual_label', 'pred_label':'predicted_label'})
#generate the hierarchical tree
mergings = linkage(np.array(dataset.embedding.to_list()), method='complete', metric='cosine')
#convert to a tree to traverse
root, nodelist = hierarchy.to_tree(mergings, rd = True)
#construct tree using examples
tree = reconstruct_tree(mergings, dataset.content.to_list())
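#shape note (assuming n misclassified examples): mergings is an (n - 1) x 4 array where each
#row is [cluster_i, cluster_j, distance, size]; indices below n are leaves (rows of dataset)
#and indices >= n refer to previously merged clusters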
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"), #read the api key from an environment variable instead of hard-coding it
)
#load sbert model
sbert_model = SentenceTransformer('all-distilroberta-v1')
#store the intermediate steps
intermediate_steps = []
end_summaries = []
for id, node in tqdm(enumerate(mergings)):
#first case: if both are leaves -> send to llm
if node[0] <= len(mergings) and node[1] <= len(mergings):
#pass two leaves
        leaf_0 = dataset.iloc[int(node[0])] #single row as a Series so .content, .predicted_label, .actual_label are scalars
        leaf_1 = dataset.iloc[int(node[1])]
leaf_list = merge_two_leaves(leaf_0, leaf_1)
        current = {'id': int(id + len(mergings) + 1), #new cluster id: row index in mergings + number of observations
'examples': [[leaf_0.content,
leaf_0.predicted_label,
leaf_0.actual_label,
int(node[0])],
[leaf_1.content,
leaf_1.predicted_label,
leaf_1.actual_label,
int(node[1])]],
'bullet_list': leaf_list,
'edited': 'both are leaves',
'previous_list': 'base list'}
#second case: if you're merging a leaf to a merged cluster -> send to check if it applies
    elif (node[0] > len(mergings)) ^ (node[1] > len(mergings)): #merged clusters have indices strictly greater than len(mergings)
#use the cluster id of the merged list to get the previous list
        if node[0] <= len(mergings):
            leaf = dataset.iloc[int(node[0])] #single row as a Series
            previous_list = int(node[1]) #this is the id of the merged cluster
            leaf_id = int(node[0])
        else: #the leaf can also be the second entry, the linkage output does not guarantee the order
            leaf = dataset.iloc[int(node[1])]
            previous_list = int(node[0]) #this is the id of the merged cluster
            leaf_id = int(node[1])
previous_bullet_list = next(item for item in intermediate_steps if item['id'] == previous_list)
previous_bullet_list = previous_bullet_list['bullet_list']
#pass the previous list
edited, merged_leaf = merge_leaf(leaf, previous_bullet_list) #hyperparameter threshold (how many points apply to the example)
#store the list and examples and verdict so they're easy to retrieve
        current = {'id': int(id + len(mergings) + 1), #new cluster id: row index in mergings + number of observations
'examples': [leaf.content,
leaf.predicted_label,
leaf.actual_label,
leaf_id],
'bullet_list': merged_leaf,
'edited': edited,
'previous_list':previous_list
}
    #third case: merging two clusters
else:
        #get the list generated at each node
        list_0 = next(item for item in intermediate_steps if item['id'] == int(node[0]))
        list_0_bullets = list_0['bullet_list']
        list_1 = next(item for item in intermediate_steps if item['id'] == int(node[1]))
        list_1_bullets = list_1['bullet_list']
        #if there is "enough" overlap merge the clusters
        enough, bipartite_graph, pairs_to_combine = overlap(list_0_bullets, list_1_bullets)
if enough:
union_list = list_union(bipartite_graph, pairs_to_combine)
            current = {'id': int(id + len(mergings) + 1), #new cluster id: row index in mergings + number of observations
'examples': [list_0['id'],
list_1['id']],
'bullet_list': union_list,
'edited': 'merging two clusters',
'previous_list':'enough overlap to merge'
}
#not enough overlap, separate into two clusters
        else:
            print('not merging')
            end_summaries.append(list_0)
            end_summaries.append(list_1)
            current = None #no new node is created when the clusters stay separate
    if current is not None: #avoid re-appending a stale entry from a previous iteration
        intermediate_steps.append(current)
    if id == 118: #manually chosen cut-off to stop the traversal early
        break
intermediate_steps = pd.DataFrame(intermediate_steps)
intermediate_steps.to_json('intermediate_steps.json', orient='records', indent=4)
end_summaries = pd.DataFrame(end_summaries)
end_summaries.to_json('end_summaries.json', orient='records', indent=4)