# app / app.py
# Author: uogoit
# Commit: "Upload app.py" (4748976, verified)
import gradio as gr
import torch
from transformers import (
AutoConfig,
AutoModelForSequenceClassification,
AutoTokenizer,
)
# ===== Basic configuration =====
MODEL_DIR = "my-bert-model"  # path/ID passed to every from_pretrained() call below
MAX_LENGTH = 512  # tokenizer truncation limit, in tokens

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# ===== Load model =====
# NOTE(review): assumes the fine-tuned checkpoint in MODEL_DIR has a 3-class
# classification head; confirm against the checkpoint's config.json.
NUM_LABELS = 3

config = AutoConfig.from_pretrained(
    MODEL_DIR,
    num_labels=NUM_LABELS,
    finetuning_task="text-classification",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_DIR,
    config=config,
).to(device)
# Switch to evaluation mode (disables dropout) since this app only does inference.
model.eval()

# If the config does not define id2label, generate generic names automatically.
# label2id is rebuilt alongside it so the two mappings stay mutually consistent.
if not getattr(model.config, "id2label", None):
    model.config.id2label = {i: f"LABEL_{i}" for i in range(model.config.num_labels)}
    model.config.label2id = {name: i for i, name in model.config.id2label.items()}
# ===== Inference function =====
def inference(input_text: str) -> str:
    """Classify ``input_text`` and return the predicted label name.

    Args:
        input_text: Raw text from the UI. ``None``, empty, and
            whitespace-only values are rejected without running the model.

    Returns:
        The label name from ``model.config.id2label`` for the arg-max class,
        the class id as a string if the label map has no entry for it, or
        ``"Empty input."`` when there is nothing to classify.
    """
    if not input_text or not input_text.strip():
        return "Empty input."

    # A single sequence needs no padding. Padding every request out to
    # MAX_LENGTH only adds masked-out positions the model still has to
    # compute over. Inputs longer than MAX_LENGTH tokens are truncated.
    inputs = tokenizer(
        input_text,
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    # No gradients are needed for a forward pass at inference time.
    with torch.no_grad():
        logits = model(**inputs).logits

    predicted_class_id = int(torch.argmax(logits, dim=-1).item())
    return model.config.id2label.get(predicted_class_id, str(predicted_class_id))
# ===== Gradio interface =====
def _build_demo() -> gr.Interface:
    """Assemble the text-in / label-out Gradio interface around ``inference``."""
    text_in = gr.Textbox(
        label="Input Text",
        placeholder="Enter text to classify...",
        lines=5,
    )
    label_out = gr.Textbox(label="Predicted Label")
    # Each example is a one-element row: one value per input component.
    sample_rows = [
        ["My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up."],
        ["Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!"],
    ]
    return gr.Interface(
        fn=inference,
        inputs=text_in,
        outputs=label_out,
        examples=sample_rows,
        title="BERT-based Text Classification",
        description="A text classification demo powered by a fine-tuned BERT model.",
    )


demo = _build_demo()
# ===== Launch =====
if __name__ == "__main__":
    # 0.0.0.0 binds every network interface, so the app is reachable from
    # outside the host/container; 7860 is the port Gradio serves on.
    launch_options = {
        "debug": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
    }
    demo.launch(**launch_options)