import os

import numpy as np
import pandas as pd
from groq import Groq
from scipy.cluster import hierarchy
from scipy.cluster.hierarchy import linkage
from scipy.spatial.distance import cosine
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
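
# Pipeline overview (as implemented below): statements that an offensive
# language detector misclassified are embedded and hierarchically clustered
# (complete linkage, cosine distance). The merge tree is then walked bottom-up:
# at each step an LLM (queried through the Groq API) writes or updates a bullet
# list of shared "error features"; two existing lists are only fused when their
# SBERT-based overlap score is high enough, otherwise they are written out as
# final error summaries.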


def convert_labels(dataset):
    # Build a mapping from each gold label value to its 'label_y' counterpart,
    # then apply the same mapping to the predictions in 'pred' to obtain
    # 'pred_label'.
    labels = dataset.label.unique()
    labels_mapping = {}

    for label in labels:
        pred = dataset.loc[dataset['label'] == label, 'label_y'].values[0]
        labels_mapping[label] = pred

    for label in labels_mapping:
        dataset.loc[dataset.pred == label, 'pred_label'] = labels_mapping[label]

    return dataset


def add_labels(dataset, original_dataset):
    # Attach the labels from the original error file to the clustered dataset
    # by matching on the statement text.
    original_dataset = original_dataset.rename(columns={'text': 'content'})
    original_dataset = original_dataset[['content', 'label_y']]

    dataset_content = dataset.content

    subset_original_content = original_dataset.loc[original_dataset.content.isin(dataset_content)]

    dataset = pd.merge(dataset, subset_original_content, on='content')
    dataset = convert_labels(dataset)

    return dataset


def tree_depth(node):
    # Depth of a scipy ClusterNode tree (as returned by hierarchy.to_tree).
    if node is None:
        return 0
    else:
        left_depth = tree_depth(node.get_left())
        right_depth = tree_depth(node.get_right())
        return max(left_depth, right_depth) + 1


def reconstruct_tree(mergings, content):
    # Rebuild the dendrogram as nested lists of statement texts. In a scipy
    # linkage matrix over n observations, indices below n refer to original
    # observations and index n + i refers to the cluster formed at step i.
    tree = {}
    for i, merge in enumerate(mergings):

        if merge[0] <= len(mergings):
            a = content[int(merge[0])]
        else:
            a = tree[int(merge[0])]

        if merge[1] <= len(mergings):
            b = content[int(merge[1])]
        else:
            b = tree[int(merge[1])]

        tree[1 + i + len(mergings)] = [a, b]
    return tree
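
# Illustration of the linkage-matrix indexing used above (scipy's documented
# convention, with a hypothetical merge order, not tied to this dataset):
# with n = 4 statements [s0, s1, s2, s3], `mergings` has n - 1 = 3 rows, e.g.
#   row 0: merges observations 0 and 1        -> new cluster index 4
#   row 1: merges observation 2 and cluster 4 -> new cluster index 5
#   row 2: merges observation 3 and cluster 5 -> new cluster index 6
# reconstruct_tree would then return
#   {4: [s0, s1], 5: [s2, [s0, s1]], 6: [s3, [s2, [s0, s1]]]}
# i.e. indices <= len(mergings) are leaves and larger indices look up the
# already-built subtree.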


def flatten(dict_list):
    flat_list = []
    for item in dict_list:
        if isinstance(item, list):
            flat_list.extend(flatten(item))
        else:
            flat_list.append(item)
    return flat_list


def get_answer(prompt, system_prompt):
    # Query the Groq chat completions API (uses the module-level `client`).
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "system",
                "content": system_prompt
            },
            {
                "role": "user",
                "content": prompt
            }
        ],
        model="mixtral-8x7b-32768"
    )
    return chat_completion.choices[0].message.content


def merge_two_leaves(leaf_0, leaf_1):
    # Ask the LLM for an initial bullet list describing what two misclassified
    # statements have in common.
    system_prompt = 'You are given two statements from an offensive language dataset that were misclassified by an offensive language detection system. Analyze the two statements thoroughly and provide a bullet list explanation of the similarities between the two statements. Your list should have the following format: * Error_Feature: <Explanation> where Error_Feature: is a two word description of the feature and Explanation is a one sentence explanation of the feature. Make sure to stick to the format specified. Avoid making explicit references to the examples and use layman terms for the explanations.'

    prompt = f'Statement_0: {leaf_0.content}\npredicted_label_0: {leaf_0.predicted_label}\nactual_label_0: {leaf_0.actual_label}\n\nStatement_1: {leaf_1.content}\npredicted_label_1: {leaf_1.predicted_label}\nactual_label_1: {leaf_1.actual_label}\n\nList: '

    base_list = get_answer(prompt, system_prompt)

    return base_list


def applies(leaf, bullet_list, threshold=2):
    # Ask the LLM whether at least `threshold` points of an existing bullet
    # list also apply to the given misclassified statement.
    system_prompt = f"You are given a statement from an offensive language dataset that was misclassified by an offensive language detection system. In addition, you are given a list of features generated by an LLM for other statements that were misclassified. Perform a thorough analysis of the statement and the list. If at least {threshold} points apply to the statement return YES otherwise return NO."

    prompt = f"Statement: {leaf.content}\npredicted_label: {leaf.predicted_label}\nactual_label: {leaf.actual_label}\n\nList: {bullet_list}\n\nAnswer: "

    check = get_answer(prompt, system_prompt).lower()

    return 'yes' in check


def merge_leaf(leaf, bullet_list):
    # If the existing bullet list applies to the new leaf, have the LLM make
    # the minimal edits needed to cover it; otherwise keep the list unchanged.
    system_prompt = "You are given a statement from an offensive language dataset that was misclassified by an offensive language detection system. In addition, you are given a list of features generated by an LLM for other statements that were misclassified. Make the minimal changes so the list also applies to the given statement. Maintain the same format * Error_Feature: <Explanation> where Error_Feature: is a two word description of the feature and Explanation is a one sentence explanation of the feature. Make sure to stick to the format specified. Avoid making explicit references to the examples and use layman terms for the explanations."

    prompt = f"Statement: {leaf.content}\npredicted_label: {leaf.predicted_label}\nactual_label: {leaf.actual_label}\n\nList: {bullet_list}\n\nUpdated list: "

    if applies(leaf, bullet_list):
        return 'edited', get_answer(prompt, system_prompt)
    else:
        return 'not edited', bullet_list


def get_bullet_points(bullet_list):
    return bullet_list.split('\n')


def construct_bipartite_graph(bullet_list_0, bullet_list_1):
    # Pair every bullet point of the first list with every bullet point of the
    # second list, skipping pairs already present in either order.
    bipartite_graph = []
    for first in bullet_list_0:
        for second in bullet_list_1:
            if (first, second) not in bipartite_graph and (second, first) not in bipartite_graph:
                bipartite_graph.append((first, second))
    return bipartite_graph


def sbert_embeddings(bipartite_graph):
    # Encode both sides of every pair with the module-level SBERT model.
    sbert_bipartite_graph = []
    for pair in bipartite_graph:
        first = sbert_model.encode(pair[0])
        second = sbert_model.encode(pair[1])
        sbert_bipartite_graph.append((first, second))
    return sbert_bipartite_graph


def compute_cosine_similarity(sbert_bipartite_embeddings):
    # Cosine similarity of each embedding pair (scipy's `cosine` is a distance).
    cosine_similarity = []
    for pair in sbert_bipartite_embeddings:
        similarity = 1 - cosine(pair[0], pair[1])
        cosine_similarity.append(similarity)
    return cosine_similarity


def combine(cosine_similarity, bipartite_graph, similarity_threshold):
    pairs_to_combine = []
    for index in range(len(cosine_similarity)):
        if cosine_similarity[index] > similarity_threshold:
            pairs_to_combine.append(bipartite_graph[index])
    return pairs_to_combine


def overlap(list_0, list_1, overlap_threshold=0.5, similarity_threshold=0.75):
    # Decide whether two bullet lists are similar enough to be merged: pair up
    # their bullet points, embed the pairs, and count how many pairs exceed the
    # similarity threshold relative to all possible pairs.
    bullet_list_0 = get_bullet_points(list_0)
    bullet_list_1 = get_bullet_points(list_1)

    bipartite_graph = construct_bipartite_graph(bullet_list_0, bullet_list_1)

    sbert_bipartite_graph = sbert_embeddings(bipartite_graph)

    cosine_similarity = compute_cosine_similarity(sbert_bipartite_graph)

    pairs_to_combine = combine(cosine_similarity, bipartite_graph, similarity_threshold)

    overlap_score = len(pairs_to_combine) / len(bipartite_graph)

    return overlap_score > overlap_threshold, bipartite_graph, pairs_to_combine
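
# Worked example of the overlap score (hypothetical numbers, for illustration
# only): if list_0 and list_1 each contain 3 distinct bullet points, the
# bipartite graph has 3 * 3 = 9 pairs. If 5 of those pairs have SBERT cosine
# similarity above similarity_threshold = 0.75, then overlap_score = 5 / 9
# ≈ 0.56, which exceeds overlap_threshold = 0.5, so the two lists are merged.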


def union(bipartite_graph, pairs_to_combine):
    # Collect the bullet points that do not take part in any pair scheduled for
    # merging; these are carried over to the combined list unchanged.
    bipartite_graph = [pair for pair in bipartite_graph if pair not in pairs_to_combine]

    distinct_features = set()
    for pair in pairs_to_combine:
        distinct_features.add(pair[0])
        distinct_features.add(pair[1])

    # Keep only pairs where neither side is involved in a merge.
    bipartite_graph = [pair for pair in bipartite_graph
                       if pair[0] not in distinct_features and pair[1] not in distinct_features]

    dont_combine = set()
    for pair in bipartite_graph:
        dont_combine.add(pair[0])
        dont_combine.add(pair[1])

    return dont_combine


def list_union(bipartite_graph, pairs_to_combine):
    # Build the merged bullet list: keep the untouched bullet points as they
    # are and ask the LLM to fuse each pair of similar points into one.
    dont_combine = union(bipartite_graph, pairs_to_combine)

    system_prompt = "You are given two bullet points generated to explain similarities between statements. You are tasked to combine these two bullet points into one. Make sure to maintain the same format * Error_Feature: <Explanation> where Error_Feature: is a two word description of the feature and Explanation is a one sentence explanation of the feature. Make sure to stick to the format specified."

    merged_points = list(dont_combine)
    for pair in pairs_to_combine:
        prompt = f"First point: {pair[0]}\n\nSecond point: {pair[1]}\n\nNew point: "
        merged_points.append(get_answer(prompt, system_prompt))

    return '\n'.join(merged_points)
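
# Expected inputs (as used below): the clustered-error file provides at least
# 'content', 'embedding', 'label' and 'pred' columns (plus 'slice', 'centroid'
# and 'cluster', which are dropped), while the original error CSV provides the
# statement text ('text') and the gold label column 'label_y'.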


# Load the clustered errors (with embeddings) and the original error file that
# still carries the gold labels.
dataset = pd.read_json("improved_english/clusters/baseline_sb.json")
original_dataset = pd.read_csv('clusters/mhs_lhs_errors.csv')

dataset.drop(['slice', 'centroid', 'cluster'], inplace=True, axis=1)
dataset = add_labels(dataset, original_dataset)
dataset = dataset.rename(columns={'label_y': 'actual_label', 'pred_label': 'predicted_label'})

# Hierarchical clustering over the statement embeddings.
mergings = linkage(np.array(dataset.embedding.to_list()), method='complete', metric='cosine')

# ClusterNode representation of the same dendrogram (rd=True also returns the node list).
root, nodelist = hierarchy.to_tree(mergings, rd=True)

# Nested-list version of the dendrogram over the statement texts.
tree = reconstruct_tree(mergings, dataset.content.to_list())

# Read the Groq API key from the environment (e.g. export GROQ_API_KEY=...).
client = Groq(
    api_key=os.environ.get("GROQ_API_KEY"),
)

sbert_model = SentenceTransformer('all-distilroberta-v1')

intermediate_steps = []  # bullet lists produced at each processed merge step
end_summaries = []       # final summaries for clusters that were not merged further
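
# Walk the merge tree bottom-up. Each linkage row merges two nodes, giving
# three cases:
#   1. leaf + leaf       -> ask the LLM for a fresh bullet list,
#   2. leaf + cluster    -> ask the LLM to minimally edit the cluster's list,
#   3. cluster + cluster -> fuse the two lists if their overlap is high enough,
#                           otherwise emit both as final summaries.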

for id, node in tqdm(enumerate(mergings), total=len(mergings)):

    # Case 1: both children are original observations (leaves).
    if node[0] <= len(mergings) and node[1] <= len(mergings):

        # Select each row as a Series so .content / .predicted_label /
        # .actual_label are plain values rather than one-row frames.
        leaf_0 = dataset.iloc[int(node[0])]
        leaf_1 = dataset.iloc[int(node[1])]
        leaf_list = merge_two_leaves(leaf_0, leaf_1)
        current = {'id': int(id + len(mergings) + 1),
                   'examples': [[leaf_0.content,
                                 leaf_0.predicted_label,
                                 leaf_0.actual_label,
                                 int(node[0])],
                                [leaf_1.content,
                                 leaf_1.predicted_label,
                                 leaf_1.actual_label,
                                 int(node[1])]],
                   'bullet_list': leaf_list,
                   'edited': 'both are leaves',
                   'previous_list': 'base list'}

    # Case 2: exactly one child is a leaf, the other an already-summarized cluster.
    elif (node[0] > len(mergings)) ^ (node[1] > len(mergings)):

        if node[0] <= len(mergings):
            leaf = dataset.iloc[int(node[0])]
            previous_list = int(node[1])
            leaf_id = int(node[0])
        else:
            leaf = dataset.iloc[int(node[1])]
            previous_list = int(node[0])
            leaf_id = int(node[1])

        previous_bullet_list = next(item for item in intermediate_steps if item['id'] == previous_list)
        previous_bullet_list = previous_bullet_list['bullet_list']

        edited, merged_leaf = merge_leaf(leaf, previous_bullet_list)

        current = {'id': int(id + len(mergings) + 1),
                   'examples': [leaf.content,
                                leaf.predicted_label,
                                leaf.actual_label,
                                leaf_id],
                   'bullet_list': merged_leaf,
                   'edited': edited,
                   'previous_list': previous_list
                   }

    # Case 3: both children are clusters with existing bullet lists.
    else:

        list_0 = next(item for item in intermediate_steps if item['id'] == int(node[0]))
        list_0_id = list_0['bullet_list']

        list_1 = next(item for item in intermediate_steps if item['id'] == int(node[1]))
        list_1_id = list_1['bullet_list']

        enough, bipartite_graph, pairs_to_combine = overlap(list_0_id, list_1_id)

        if enough:
            union_list = list_union(bipartite_graph, pairs_to_combine)
            current = {'id': int(id + len(mergings) + 1),
                       'examples': [list_0['id'],
                                    list_1['id']],
                       'bullet_list': union_list,
                       'edited': 'merging two clusters',
                       'previous_list': 'enough overlap to merge'
                       }

        else:
            # Not enough overlap: emit both child lists as final summaries and
            # record no merged entry for this node.
            print('not merging')
            end_summaries.append(list_0)
            end_summaries.append(list_1)
            current = None

    if current is not None:
        intermediate_steps.append(current)
    # Early stop after a fixed number of merge steps.
    if id == 118:
        break

intermediate_steps = pd.DataFrame(intermediate_steps)
intermediate_steps.to_json('intermediate_steps.json', orient='records', indent=4)

end_summaries = pd.DataFrame(end_summaries)
end_summaries.to_json('end_summaries.json', orient='records', indent=4)