---
language: zh
tags:
- bert
- multilabel-classification
- chinese
- intent-classification
- time-lbs
base_model:
- google-bert/bert-base-chinese
---
|
|
|
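# Chinese Multi-Label Intent Classification Model (BERT)

This is a multi-label classification model fine-tuned from `bert-base-chinese`. Given a Chinese query, it performs the following tasks:

- Multi-class: intent recognition (chat / simple question / complex question)
- Binary: whether the query is time-related, and whether it is location-related (LBS, location-based services)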
|
|
|
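## Model Architecture

- Base model: [`bert-base-chinese`](https://huggingface.co/bert-base-chinese)
- Output layer: a 5-dimensional multi-label output vector with sigmoid activation, in the following order:
  - `[intent-chat, intent-simple, intent-complex, time-related, LBS-related]`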
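
As a sketch of how the 5-dimensional output becomes one intent plus two yes/no flags, the function below mirrors the decoding logic of the `predict` function in the usage example. It runs without loading the model. The names `decode_outputs` and `sigmoid`, the English dictionary keys, and the `threshold` parameter are hypothetical and chosen for illustration. The input is assumed to be the raw logits for a single query, in the order listed above.

```python
import math
from typing import Dict, List, Sequence

# Same intent order as the first three output slots (hypothetical constant name)
INTENT_LABELS = ["chat", "simple question", "complex question"]


def sigmoid(x: float) -> float:
    """Logistic function: maps a raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-x))


def decode_outputs(logits: Sequence[float], threshold: float = 0.5) -> Dict[str, object]:
    """Split a 5-dim logit vector into one intent and two binary flags.

    Assumed layout: [chat, simple, complex, time-related, LBS-related].
    """
    if len(logits) != 5:
        raise ValueError(f"expected 5 logits, got {len(logits)}")
    probs: List[float] = [sigmoid(x) for x in logits]

    # Slots 0-2 act as one multi-class head: pick the highest-probability intent
    intent_index = max(range(3), key=lambda i: probs[i])

    return {
        "intent": INTENT_LABELS[intent_index],
        # Slots 3 and 4 are independent binary heads, each thresholded on its own
        "is_time_related": probs[3] > threshold,
        "is_lbs_related": probs[4] > threshold,
        "probs": probs,
    }
```

For example, `decode_outputs([-2.0, 1.5, 0.3, 2.2, -0.7])` yields the intent `"simple question"`, with `is_time_related=True` and `is_lbs_related=False`. The sigmoid is monotonic, so the argmax over the first three probabilities equals the argmax over the first three logits. For the same reason, `probs[i] > 0.5` is equivalent to `logits[i] > 0`.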
|
|
|
## Usage
|
|
|
```python
import torch
from transformers import BertTokenizer

# Custom model class from the local file bert_classifier_3.py (not part of transformers)
from bert_classifier_3 import BertMultiLabelClassifier

# Load the tokenizer and model
bert_base = "bert-base-chinese"
model_id = "Xiaoxi2333/bert_multilabel_chinese"

tokenizer = BertTokenizer.from_pretrained(model_id)
model = BertMultiLabelClassifier(pretrained_model_path=bert_base, num_labels=5)

# Download the fine-tuned weights from the Hugging Face Hub and load them on CPU
state_dict = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{model_id}/resolve/main/pytorch_model.bin",
    map_location="cpu"
)
model.load_state_dict(state_dict)
model.eval()

# Label names
intent_labels = ["chat", "simple question", "complex question"]
yesno_labels = ["否", "是"]  # "no", "yes"

# Prediction function
def predict(query):
    enc = tokenizer(
        query,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])
    # Sigmoid turns each of the 5 logits into an independent probability
    probs = torch.sigmoid(logits).squeeze(0)

    # Indices 0-2: pick the most likely intent
    intent_index = torch.argmax(probs[:3]).item()
    # Indices 3-4: binary flags with a 0.5 threshold
    is_time = int(probs[3] > 0.5)
    is_lbs = int(probs[4] > 0.5)

    return {
        "query": query,
        "意图": intent_labels[intent_index],    # intent
        "是否时间相关": yesno_labels[is_time],  # time-related?
        "是否lbs相关": yesno_labels[is_lbs],    # LBS-related?
        "原始概率": probs.tolist()              # raw probabilities
    }

# Example query ("What's the weather like in Beijing tomorrow?")
result = predict("明天北京天气怎么样?")
print(result)
```