Xiaoxi2333 committed on
Commit
1760b74
·
verified ·
1 Parent(s): 096f062

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +70 -69
README.md CHANGED
@@ -1,69 +1,70 @@
1
- ---
2
- language: zh
3
- tags:
4
- - bert
5
- - multilabel-classification
6
- - chinese
7
- - intent-classification
8
- - time-lbs
9
- license: mit
10
- ---
11
-
12
- # 中文多标签意图识别模型(BERT)
13
-
14
- 这是一个基于 `bert-base-chinese` 微调的多标签分类模型,支持以下任务:
15
-
16
- 对中文query进行分类
17
- - 多分类:意图识别(chat / simple question / complex question)
18
- - 二分类:是否时间相关、是否位置(LBS)相关
19
-
20
- ## 模型结构
21
-
22
- - 基础模型:[`bert-base-chinese`](https://huggingface.co/bert-base-chinese)
23
- - 输出层:一个 5 维的 sigmoid 多标签输出向量
24
- - `[意图-chat, 意图-simple, 意图-complex, 是否时间相关, 是否LBS相关]`
25
-
26
- ## 使用方法
27
-
28
- ```python
29
- import torch
30
- from transformers import BertTokenizer
31
- from bert_classifier_3 import BertMultiLabelClassifier
32
-
33
- # 加载 tokenizer 和模型
34
- tokenizer = BertTokenizer.from_pretrained("your-username/bert-multilabel-chinese")
35
- model = BertMultiLabelClassifier(pretrained_model_path="your-username/bert-multilabel-chinese")
36
- model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
37
- model.eval()
38
-
39
- # 定义标签
40
- intent_labels = ["chat", "simple question", "complex question"]
41
- yesno_labels = ["", ""]
42
-
43
- # 定义预测函数
44
- def predict(query):
45
- enc = tokenizer(
46
- query,
47
- truncation=True,
48
- padding="max_length",
49
- max_length=128,
50
- return_tensors="pt"
51
- )
52
- with torch.no_grad():
53
- logits = model(enc["input_ids"], enc["attention_mask"])
54
- probs = torch.sigmoid(logits).squeeze(0)
55
- intent_index = torch.argmax(probs[:3]).item()
56
- is_time = int(probs[3] > 0.5)
57
- is_lbs = int(probs[4] > 0.5)
58
-
59
- return {
60
- "query": query,
61
- "意图": intent_labels[intent_index],
62
- "是否时间相关": yesno_labels[is_time],
63
- "是否lbs相关": yesno_labels[is_lbs],
64
- "原始概率": probs.tolist()
65
- }
66
-
67
- # 示例查询
68
- result = predict("明天北京天气怎么样?")
69
- print(result)
 
 
1
+ ---
2
+ language: zh
3
+ tags:
4
+ - bert
5
+ - multilabel-classification
6
+ - chinese
7
+ - intent-classification
8
+ - time-lbs
9
+ base_model:
10
+ - google-bert/bert-base-chinese
11
+ ---
12
+
13
+ # 中文多标签意图识别模型(BERT)
14
+
15
+ 这是一个基于 `bert-base-chinese` 微调的多标签分类模型,支持以下任务:
16
+
17
+ 对中文 query 进行分类:
18
+ - 多分类:意图识别(chat / simple question / complex question)
19
+ - 二分类:是否时间相关、是否位置(LBS)相关
20
+
21
+ ## 模型结构
22
+
23
+ - 基础模型:[`bert-base-chinese`](https://huggingface.co/google-bert/bert-base-chinese)
24
+ - 输出层:一个 5 维的 sigmoid 多标签输出向量
25
+ - `[意图-chat, 意图-simple, 意图-complex, 是否时间相关, 是否LBS相关]`
26
+
27
+ ## 使用方法
28
+
29
+ ```python
30
+ import torch
31
+ from transformers import BertTokenizer
32
+ from bert_classifier_3 import BertMultiLabelClassifier
33
+
34
+ # 加载 tokenizer 和模型
35
+ tokenizer = BertTokenizer.from_pretrained("Xiaoxi2333/bert-multilabel-chinese")
36
+ model = BertMultiLabelClassifier(pretrained_model_path="Xiaoxi2333/bert-multilabel-chinese")
37
+ model.load_state_dict(torch.load("pytorch_model.bin", map_location="cpu"))
38
+ model.eval()
39
+
40
+ # 定义标签
41
+ intent_labels = ["chat", "simple question", "complex question"]
42
+ yesno_labels = ["否", "是"]
43
+
44
+ # 定义预测函数
45
+ def predict(query):
46
+ enc = tokenizer(
47
+ query,
48
+ truncation=True,
49
+ padding="max_length",
50
+ max_length=128,
51
+ return_tensors="pt"
52
+ )
53
+ with torch.no_grad():
54
+ logits = model(enc["input_ids"], enc["attention_mask"])
55
+ probs = torch.sigmoid(logits).squeeze(0)
56
+ intent_index = torch.argmax(probs[:3]).item()
57
+ is_time = int(probs[3] > 0.5)
58
+ is_lbs = int(probs[4] > 0.5)
59
+
60
+ return {
61
+ "query": query,
62
+ "意图": intent_labels[intent_index],
63
+ "是否时间相关": yesno_labels[is_time],
64
+ "是否lbs相关": yesno_labels[is_lbs],
65
+ "原始概率": probs.tolist()
66
+ }
67
+
68
+ # 示例查询
69
+ result = predict("明天北京天气怎么样?")
70
+ print(result)