|
|
import gradio as gr |
|
|
import torch |
|
|
from transformers import ( |
|
|
AutoConfig, |
|
|
AutoModelForSequenceClassification, |
|
|
AutoTokenizer, |
|
|
) |
|
|
|
|
|
|
|
|
# Location handed to every ``from_pretrained`` call in this script.
MODEL_DIR: str = "my-bert-model"
# Maximum number of tokens fed to the model; longer inputs are truncated.
MAX_LENGTH: int = 512

# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
# Load the saved configuration, overriding the label count and task name.
# NOTE(review): num_labels=3 must match the classifier head stored in
# MODEL_DIR; confirm against the fine-tuning run before changing it.
config = AutoConfig.from_pretrained(
    MODEL_DIR,
    num_labels=3,
    finetuning_task="text-classification",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)

# Load weights with the overridden config, move them to the selected device,
# and switch the module to evaluation mode since it is only used for inference.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR, config=config)
model = model.to(device)
model.eval()
|
|
|
|
|
|
|
|
# ``inference`` looks predicted class ids up in ``id2label``; if the config has
# no mapping (or an empty one), synthesise generic "LABEL_<i>" names for every
# class the model reports.
if not getattr(model.config, "id2label", None):
    model.config.id2label = {
        class_idx: f"LABEL_{class_idx}"
        for class_idx in range(model.config.num_labels)
    }
|
|
|
|
|
|
|
|
def inference(input_text: str) -> str:
    """Classify ``input_text`` and return the predicted label name.

    Args:
        input_text: Raw text from the Gradio textbox. ``None``, empty, and
            whitespace-only values short-circuit without running the model.

    Returns:
        ``"Empty input."`` for blank input. Otherwise the entry of
        ``model.config.id2label`` for the highest-scoring class, or that
        class index as a string when the mapping has no entry for it.
    """
    if not input_text or not input_text.strip():
        return "Empty input."

    # One sequence per call, so no padding is needed. Padding to MAX_LENGTH would
    # only append positions that the attention mask excludes, while forcing every
    # forward pass to process the full 512-token window. Truncation still caps
    # long inputs at MAX_LENGTH tokens.
    encoded = tokenizer(
        input_text,
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # inference_mode is a stricter, cheaper form of no_grad. The outputs are only
    # read for argmax, never fed back into autograd.
    with torch.inference_mode():
        logits = model(**encoded).logits

    predicted_class_id = int(torch.argmax(logits, dim=-1).item())
    return model.config.id2label.get(predicted_class_id, str(predicted_class_id))
|
|
|
|
|
|
|
|
# Single-textbox UI: the user's text goes to ``inference`` and the predicted
# label is shown in a read-out textbox. The two examples pre-fill the input.
demo = gr.Interface(
    fn=inference,
    title="BERT-based Text Classification",
    description="A text classification demo powered by a fine-tuned BERT model.",
    inputs=gr.Textbox(
        lines=5,
        label="Input Text",
        placeholder="Enter text to classify...",
    ),
    outputs=gr.Textbox(label="Predicted Label"),
    examples=[
        ["My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up."],
        ["Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!"],
    ],
)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve on port 7860, listening on every network interface ("0.0.0.0"),
    # with Gradio's debug mode turned off.
    demo.launch(server_port=7860, server_name="0.0.0.0", debug=False)
|
|
|