|
import html
import json

import torch
from transformers import BertTokenizerFast, BertForTokenClassification
import gradio as gr
|
|
|
|
|
# The tokenizer and the classifier are loaded once, at import time, and are
# then reused by every call to the prediction function.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# Inference-only model: placed on the GPU when CUDA is available (CPU
# otherwise) and switched to evaluation mode.
model = (
    BertForTokenClassification
    .from_pretrained('maximuspowers/bias-detection-ner')
    .to('cuda' if torch.cuda.is_available() else 'cpu')
    .eval()
)
|
|
|
|
|
# Maps a classifier output index to its BIO tag. Three entity types
# (STEREO, GEN, UNFAIR), each with a B-/I- variant, plus the outside tag 'O'.
# NOTE(review): presumably this must mirror the model's own label order —
# confirm against the checkpoint's config.
id2label = dict(enumerate([
    'O',
    'B-STEREO', 'I-STEREO',
    'B-GEN', 'I-GEN',
    'B-UNFAIR', 'I-UNFAIR',
]))
|
|
|
|
|
# Translucent background colour used to highlight each entity type in the
# HTML matrix, keyed by the tag without its B-/I- prefix.
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.2)",  # red
    GEN="rgba(0, 0, 255, 0.2)",     # blue
    UNFAIR="rgba(0, 255, 0, 0.2)",  # green
)
|
|
|
|
|
def post_process_entities(result): |
|
prev_entity_type = None |
|
for token_data in result: |
|
labels = token_data["labels"] |
|
|
|
|
|
new_labels = [] |
|
for label_data in labels: |
|
label = label_data['label'] |
|
if label.startswith("B-") and prev_entity_type == label[2:]: |
|
new_labels.append({"label": f"I-{label[2:]}", "confidence": label_data["confidence"]}) |
|
elif label.startswith("I-") and prev_entity_type != label[2:]: |
|
new_labels.append({"label": f"B-{label[2:]}", "confidence": label_data["confidence"]}) |
|
else: |
|
new_labels.append(label_data) |
|
prev_entity_type = label[2:] |
|
token_data["labels"] = new_labels |
|
return result |
|
|
|
|
|
def predict_ner_tags_with_json(sentence, threshold=0.52):
    """Tag biased entities in *sentence* and render the result as HTML.

    The sentence is tokenized with the module-level ``tokenizer`` (truncated to
    128 word pieces) and scored by ``model``. Every label whose sigmoid
    probability is strictly greater than *threshold* is kept, so one token may
    carry several labels. BIO prefixes are then repaired by
    ``post_process_entities``.

    Args:
        sentence: Raw text typed by the user.
        threshold: Per-label probability cut-off; defaults to 0.52, the value
            the demo has always used.

    Returns:
        An HTML string: a token-by-category table followed by the per-token
        predictions as pretty-printed JSON inside ``<pre>``. Empty or
        whitespace-only input yields an empty string.
    """
    if not sentence or not sentence.strip():
        return ""

    result = post_process_entities(_tag_tokens(sentence, threshold))
    matrix_html = _render_matrix(result)
    # The JSON contains user-derived tokens, so escape it before embedding in HTML.
    json_result = html.escape(json.dumps(result, indent=4))
    return f"{matrix_html}<br><pre>{json_result}</pre>"


def _tag_tokens(sentence, threshold):
    """Run the model; return ``[{"token", "labels"}]`` for each non-special word piece."""
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
    # An independent sigmoid per label (not a softmax), so several labels can pass the threshold.
    probabilities = torch.sigmoid(logits)[0]

    special_tokens = set(tokenizer.all_special_tokens)  # hoisted: checked once per token
    result = []
    for i, token in enumerate(tokenizer.convert_ids_to_tokens(input_ids[0])):
        if token in special_tokens:
            continue
        token_probs = probabilities[i]
        label_indices = (token_probs > threshold).nonzero(as_tuple=False).squeeze(-1)
        labels = [
            {
                "label": id2label[idx.item()],
                "confidence": round(token_probs[idx].item() * 100, 2),
            }
            for idx in label_indices
        ]
        # Strip the WordPiece continuation marker for display.
        result.append({"token": token.replace("##", ""), "labels": labels})
    return result


def _entity_cell(labels, entity_type):
    """Return the highlighted cell for *entity_type* labels of one token, or a blank."""
    matches = [
        f"{label_data['label'][2:]} ({label_data['confidence']}%)"
        for label_data in labels
        if label_data["label"][2:] == entity_type
    ]
    if not matches:
        return " "
    return (
        f"<span style='background:{label_colors[entity_type]}; border-radius:6px; padding:2px 5px;'>"
        f"{', '.join(matches)}</span>"
    )


def _render_matrix(result):
    """Render *result* as an HTML table: one column per token, one row per category."""
    # (row heading, entity type) in display order.
    category_rows = (
        ("Generalizations", "GEN"),
        ("Unfairness", "UNFAIR"),
        ("Stereotypes", "STEREO"),
    )

    def row(heading, cells):
        tds = "".join(f"<td>{cell}</td>" for cell in cells)
        return f"<tr><td><strong>{heading}</strong></td>{tds}</tr>"

    # Tokens come from user input, so escape them before embedding in HTML.
    words = [
        f"<span style='font-weight:bold;'>{html.escape(token_data['token'])}</span>"
        for token_data in result
    ]
    rows = [row("Text Sequence", words)]
    for heading, entity_type in category_rows:
        rows.append(row(heading, [_entity_cell(d["labels"], entity_type) for d in result]))

    return (
        "<table style='border-collapse:collapse; width:100%; font-family:monospace; text-align:left;'>"
        + "".join(rows)
        + "</table>"
    )
|
|
|
|
|
# Gradio front end: a markdown header, an input textbox, and an HTML pane that
# predict_ner_tags_with_json re-renders whenever the textbox value changes.
iface = gr.Blocks()

with iface:
    with gr.Row():
        gr.Markdown(
            """
            # GUS-Net 🕵

            [GUS-Net](https://huggingface.co/ethical-spectacle/social-bias-ner) is a `BertForTokenClassification` based model, trained on the [GUS dataset](https://huggingface.co/datasets/ethical-spectacle/gus-dataset-v1). It performs multi-label named-entity recognition of socially biased entities, intended to reveal the underlying structure of bias rather than a one-size-fits-all definition.

            You can find the full collection of resources introduced in our paper [here](https://huggingface.co/collections/ethical-spectacle/gus-net-66edfe93801ea45d7a26a10f).

            This [blog post](https://huggingface.co/blog/maximuspowers/bias-entity-recognition) walks through the training and architecture of the model.

            Enter a sentence for named-entity recognition of biased entities:

            - **Generalizations (GEN)**
            - **Unfairness (UNFAIR)**
            - **Stereotypes (STEREO)**

            Labels follow the BIO format. Try it out:
            """
        )
    with gr.Row():
        input_box = gr.Textbox(label="Input Sentence")
    with gr.Row():
        output_box = gr.HTML(label="Entity Matrix and JSON Output")

    # NOTE(review): `.change` runs the model on every edit of the textbox;
    # `.submit` (run on Enter) is an option if inference load becomes a concern.
    input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])

# Start the server only when executed as a script, so importing this module
# (tests, reload tooling) does not launch it as a side effect.
if __name__ == "__main__":
    iface.launch(share=True)
|
|