import html
import json

import gradio as gr
import torch
from transformers import BertTokenizerFast, BertForTokenClassification

# Load the tokenizer and the fine-tuned token-classification model once, at
# import time; predict_ner_tags_with_json reads both as module globals.
# NOTE(review): the tokenizer is taken from the base 'bert-base-uncased'
# checkpoint rather than the fine-tuned repo — confirm they share a vocabulary.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('maximuspowers/bias-detection-ner')
model.eval()  # inference mode (disables dropout)

# Run on the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    model.to('cuda')
else:
    model.to('cpu')

# Classifier output index -> BIO tag: "O" (outside any entity) followed by a
# B-/I- (begin/inside) pair for each bias entity type.
id2label = dict(enumerate([
    'O',
    'B-STEREO', 'I-STEREO',
    'B-GEN', 'I-GEN',
    'B-UNFAIR', 'I-UNFAIR',
]))

# Translucent highlight colour for each entity type's cells in the HTML matrix.
label_colors = dict(
    STEREO="rgba(255, 0, 0, 0.2)",  # light red
    GEN="rgba(0, 0, 255, 0.2)",  # light blue
    UNFAIR="rgba(0, 255, 0, 0.2)",  # light green
)

# Post-process entity tags
def post_process_entities(result):
    """Repair BIO prefixes so each entity span starts with B- and continues with I-.

    Args:
        result: list of per-token dicts shaped like
            ``{"token": str, "labels": [{"label": str, "confidence": float}, ...]}``
            (the structure built by ``predict_ner_tags_with_json``). A token may
            carry several labels at once because the tagger is multi-label.

    Continuity is judged per entity type against the complete label set of the
    immediately preceding token:

    * ``B-X`` right after a token that carries ``X`` becomes ``I-X``.
    * ``I-X`` not preceded by a token carrying ``X`` becomes ``B-X``.
    * Any other label (including ``O``) is kept unchanged.

    A token with no B-/I- labels closes every open span.

    Returns:
        ``result`` itself; each token dict's ``"labels"`` list is replaced in place.
    """
    prev_types = set()  # entity types carried by the previous token
    for token_data in result:
        new_labels = []
        for label_data in token_data["labels"]:
            label = label_data["label"]
            entity_type = label[2:]
            if label.startswith("B-") and entity_type in prev_types:
                label_data = {"label": f"I-{entity_type}", "confidence": label_data["confidence"]}
            elif label.startswith("I-") and entity_type not in prev_types:
                label_data = {"label": f"B-{entity_type}", "confidence": label_data["confidence"]}
            new_labels.append(label_data)
        token_data["labels"] = new_labels
        # Track *all* of this token's entity types (not just the last label
        # processed) so sibling labels on a multi-label token don't interfere,
        # and an unlabeled token resets the state.
        prev_types = {
            d["label"][2:] for d in new_labels if d["label"].startswith(("B-", "I-"))
        }
    return result

# Generate HTML matrix and JSON results with probabilities
def _entity_cell(labels, entity_type):
    """Render one matrix cell: the token's ``entity_type`` labels, or a blank.

    Each matching label is shown as "TYPE (confidence%)"; several matches are
    comma-separated inside one highlight pill coloured from ``label_colors``.
    Returns ``"&nbsp;"`` when the token has no label of that type.
    """
    entries = [
        f"{label_data['label'][2:]} ({label_data['confidence']}%)"
        for label_data in labels
        if label_data["label"][2:] == entity_type
    ]
    if not entries:
        return "&nbsp;"
    return (
        f"<span style='background:{label_colors[entity_type]}; border-radius:6px; padding:2px 5px;'>"
        f"{', '.join(entries)}</span>"
    )


def predict_ner_tags_with_json(sentence, *, threshold=0.52):
    """Tag ``sentence`` with bias entities and render the result as HTML.

    Args:
        sentence: input text; tokenized with truncation to at most 128 positions.
        threshold: a label is kept for a token when its sigmoid probability is
            strictly greater than this value.

    Returns:
        An HTML string: a token-by-entity-type table followed by the
        per-token labels and confidences as indented JSON inside ``<pre>``.
        Token text is HTML-escaped because it is derived from user input.
    """
    inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
    input_ids = inputs['input_ids'].to(model.device)
    attention_mask = inputs['attention_mask'].to(model.device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        # Multi-label head: every class is scored independently with a sigmoid
        # (not a softmax), so one token can carry several labels.
        probabilities = torch.sigmoid(logits)

    # Build the special-token lookup once rather than per token.
    special_tokens = set(tokenizer.all_special_tokens)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])
    token_probs = probabilities[0]

    result = []
    for i, token in enumerate(tokens):
        if token in special_tokens:
            continue
        label_indices = (token_probs[i] > threshold).nonzero(as_tuple=False).squeeze(-1)
        labels = [
            {
                "label": id2label[idx.item()],
                "confidence": round(token_probs[i][idx].item() * 100, 2),
            }
            for idx in label_indices
        ]
        # WordPiece continuation pieces are displayed without their "##" marker.
        result.append({"token": token.replace("##", ""), "labels": labels})

    result = post_process_entities(result)

    # Table rows, one cell per token. Token text is escaped before entering HTML.
    word_row = [
        f"<span style='font-weight:bold;'>{html.escape(token_data['token'])}</span>"
        for token_data in result
    ]
    gen_row = [_entity_cell(token_data["labels"], "GEN") for token_data in result]
    unfair_row = [_entity_cell(token_data["labels"], "UNFAIR") for token_data in result]
    stereo_row = [_entity_cell(token_data["labels"], "STEREO") for token_data in result]

    matrix_html = f"""
    <table style='border-collapse:collapse; width:100%; font-family:monospace; text-align:left;'>
        <tr>
            <td><strong>Text Sequence</strong></td>
            {''.join(f"<td>{word}</td>" for word in word_row)}
        </tr>
        <tr>
            <td><strong>Generalizations</strong></td>
            {''.join(f"<td>{cell}</td>" for cell in gen_row)}
        </tr>
        <tr>
            <td><strong>Unfairness</strong></td>
            {''.join(f"<td>{cell}</td>" for cell in unfair_row)}
        </tr>
        <tr>
            <td><strong>Stereotypes</strong></td>
            {''.join(f"<td>{cell}</td>" for cell in stereo_row)}
        </tr>
    </table>
    """

    # The JSON dump repeats the raw tokens, so it must be escaped as well.
    json_result = json.dumps(result, indent=4)

    return f"{matrix_html}<br><pre>{html.escape(json_result)}</pre>"

# Gradio Interface: a description, a sentence input, and an HTML output that is
# re-rendered by predict_ner_tags_with_json whenever the input text changes.
iface = gr.Blocks()

with iface:
    with gr.Row():
        # NOTE(review): this text links ethical-spectacle/social-bias-ner while the
        # code loads maximuspowers/bias-detection-ner — confirm they are the same model.
        gr.Markdown(
            """
            # GUS-Net 🕵
            [GUS-Net](https://huggingface.co/ethical-spectacle/social-bias-ner) is a `BertForTokenClassification` based model, trained on the [GUS dataset](https://huggingface.co/datasets/ethical-spectacle/gus-dataset-v1). It performs multi-label named-entity recognition of socially biased entities, intended to reveal the underlying structure of bias rather than a one-size-fits-all definition.

            You can find the full collection of resources introduced in our paper [here](https://huggingface.co/collections/ethical-spectacle/gus-net-66edfe93801ea45d7a26a10f).

            This [blog post](https://huggingface.co/blog/maximuspowers/bias-entity-recognition) walks through the training and architecture of the model.

            Enter a sentence for named-entity recognition of biased entities:
            - **Generalizations (GEN)**
            - **Unfairness (UNFAIR)**
            - **Stereotypes (STEREO)**

            Labels follow the BIO format. Try it out:
            """
        )
    with gr.Row():
        input_box = gr.Textbox(label="Input Sentence")
    with gr.Row():
        output_box = gr.HTML(label="Entity Matrix and JSON Output")

    # Fires on every edit of the textbox, so inference runs as the user types.
    input_box.change(predict_ner_tags_with_json, inputs=[input_box], outputs=[output_box])

iface.launch(share=True)