"""Gradio demo that classifies input text with a fine-tuned sequence-classification model."""

import gradio as gr
import torch
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
)

# ===== Basic configuration =====
MODEL_DIR = "my-bert-model"
MAX_LENGTH = 512
# NOTE(review): presumably matches the classification head stored in MODEL_DIR — confirm.
NUM_LABELS = 3

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ===== Load the model =====
config = AutoConfig.from_pretrained(
    MODEL_DIR,
    num_labels=NUM_LABELS,
    finetuning_task="text-classification",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_DIR, config=config
).to(device)
model.eval()

# If id2label is missing or empty, generate generic LABEL_<i> names.
if not hasattr(model.config, "id2label") or not model.config.id2label:
    model.config.id2label = {
        i: f"LABEL_{i}" for i in range(model.config.num_labels)
    }


# ===== Inference function =====
def inference(input_text: str) -> str:
    """Return the predicted label for *input_text*.

    Args:
        input_text: Raw text to classify.

    Returns:
        ``"Empty input."`` when the text is empty or whitespace-only;
        otherwise the ``id2label`` name of the highest-scoring class, falling
        back to the class id rendered as a string when no name is registered.
    """
    if not input_text or not input_text.strip():
        return "Empty input."

    # Only one sequence is encoded per call, so no padding is required.
    # Padding to MAX_LENGTH would make every request pay for a full
    # 512-token forward pass regardless of how short the text is.
    inputs = tokenizer(
        input_text,
        max_length=MAX_LENGTH,
        truncation=True,
        return_tensors="pt",
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Gradients are not needed for prediction.
    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits
    predicted_class_id = torch.argmax(logits, dim=-1).item()
    return model.config.id2label.get(predicted_class_id, str(predicted_class_id))


# ===== Gradio interface =====
demo = gr.Interface(
    fn=inference,
    inputs=gr.Textbox(
        label="Input Text",
        placeholder="Enter text to classify...",
        lines=5,
    ),
    outputs=gr.Textbox(label="Predicted Label"),
    examples=[
        ["My last two weather pics from the storm on August 2nd. People packed up real fast after the temp dropped and winds picked up."],
        ["Lying Clinton sinking! Donald Trump singing: Let's Make America Great Again!"],
    ],
    title="BERT-based Text Classification",
    description="A text classification demo powered by a fine-tuned BERT model.",
)

# ===== Launch =====
if __name__ == "__main__":
    demo.launch(
        debug=False,
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
    )