---
license: afl-3.0
language:
- en
metrics:
- accuracy
base_model:
- distilbert/distilbert-base-uncased
pipeline_tag: text-classification
tags:
- tarot
- question-detector
---

# DistilBERT Tarot Question Detector Model

This project provides a `DistilBERT`-based tarot question detector that classifies whether an input text is a question suitable for a tarot reading.

## 📂 Repository Contents

- model.safetensors: the trained model weights.
- config.json: the configuration file for the model architecture.
- tokenizer.json: the tokenizer configuration.
- special_tokens_map.json: the special tokens configuration.
- vocab.txt: the vocabulary file for the tokenizer.

---

## 🚀 Quick Start

### **1️⃣ Install Dependencies**

Make sure your environment has Python 3.8+, then install the required libraries:

```sh
pip install torch transformers fastapi uvicorn safetensors
```

### **2️⃣ Run Inference Directly**

To test the model locally, run inference.py:

```sh
python inference.py
```

Example code (inference.py):

```python
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

# 1. Load the model
model_path = "./distilbert-question-detector"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()

# 2. Run inference
text = "Is this a question?"
inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)

with torch.no_grad():
    outputs = model(**inputs)
    logits = outputs.logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_class = torch.argmax(probabilities, dim=-1).item()

print(f"Probabilities: {probabilities}")
print(f"Predicted class: {predicted_class}")  # 0 = valid tarot question, 1 = not (see "Interpreting Results" below)
```

### **3️⃣ Run the API**

You can also deploy an HTTP endpoint with FastAPI so that other applications can query the model over HTTP:

```sh
uvicorn app:app --host 0.0.0.0 --port 8000
```

Example API code (app.py). Note that the request body is parsed with a Pydantic model so the endpoint accepts the JSON payload shown in the curl example below:

```python
from fastapi import FastAPI
from pydantic import BaseModel
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

app = FastAPI()

# Load the model
model_path = "./distilbert-question-detector/checkpoint-5150"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()

class PredictRequest(BaseModel):
    text: str

@app.post("/predict/")
async def predict(request: PredictRequest):
    inputs = tokenizer(request.text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        probabilities = torch.nn.functional.softmax(logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
    return {"text": request.text, "probabilities": probabilities.tolist(), "predicted_class": predicted_class}
```

Once the API is running, you can test it with:

```sh
curl -X 'POST' \
  'http://127.0.0.1:8000/predict/' \
  -H 'Content-Type: application/json' \
  -d '{"text": "Is this a valid question?"}'
```

## 📌 Interpreting Results

- predicted_class: 0 means the input text is a valid tarot question
- predicted_class: 1 means the input text is not

Example result:

```json
{
  "text": "Is this a valid question?",
  "probabilities": [[0.9266, 0.0734]],
  "predicted_class": 0
}
```
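
If you prefer to call the endpoint from Python rather than curl, here is a minimal client sketch using the `requests` library (an extra dependency, installable with `pip install requests`). It assumes the server from app.py is running locally on port 8000 as shown above:

```python
# Minimal sketch: query the running FastAPI endpoint from Python.
# Assumes the app.py server is running locally on port 8000.
import requests

response = requests.post(
    "http://127.0.0.1:8000/predict/",
    json={"text": "Is this a valid question?"},
)
response.raise_for_status()

result = response.json()
print(result["predicted_class"])  # 0 = valid tarot question, 1 = not
print(result["probabilities"])    # softmax probabilities for both classes
```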
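
To classify many texts at once, you can pass a list to the tokenizer and run a single batched forward pass instead of looping one text at a time. A minimal sketch, assuming the same model directory as inference.py; the example texts are illustrative only:

```python
# Minimal batch-inference sketch, assuming the same model directory
# as inference.py ("./distilbert-question-detector").
import torch
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification

model_path = "./distilbert-question-detector"
tokenizer = DistilBertTokenizer.from_pretrained(model_path)
model = DistilBertForSequenceClassification.from_pretrained(model_path)
model.eval()

texts = [
    "Will I find a new job this year?",
    "The weather is nice today.",
]

# Tokenize all texts in one call; padding aligns them to a common length.
inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=128)

with torch.no_grad():
    logits = model(**inputs).logits
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    predicted_classes = torch.argmax(probabilities, dim=-1).tolist()

for text, cls in zip(texts, predicted_classes):
    print(f"{cls}  {text}")  # 0 = valid tarot question, 1 = not
```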