xiaogantongxue commited on
Commit
129d2da
·
verified ·
1 Parent(s): bdc25d3

Upload 4 files

Browse files
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ from transformers import BertTokenizer, BertModel
5
+ import gradio as gr
6
+
7
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
+ id2label = {0: "负向", 1: "中性", 2: "正向"}
9
+
10
+ class BertCNNClassifier(nn.Module):
11
+ def __init__(self, bert_model, num_classes=3, dropout=0.3):
12
+ super().__init__()
13
+ self.bert = bert_model
14
+ self.conv1 = nn.Conv2d(1, 100, (3, bert_model.config.hidden_size))
15
+ self.conv2 = nn.Conv2d(1, 100, (4, bert_model.config.hidden_size))
16
+ self.conv3 = nn.Conv2d(1, 100, (5, bert_model.config.hidden_size))
17
+ self.dropout = nn.Dropout(dropout)
18
+ self.fc = nn.Linear(300, num_classes)
19
+
20
+ def forward(self, input_ids, attention_mask, token_type_ids=None):
21
+ with torch.no_grad():
22
+ outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
23
+ x = outputs.last_hidden_state.unsqueeze(1)
24
+ x1 = torch.relu(self.conv1(x)).squeeze(3)
25
+ x2 = torch.relu(self.conv2(x)).squeeze(3)
26
+ x3 = torch.relu(self.conv3(x)).squeeze(3)
27
+ x1 = torch.max_pool1d(x1, x1.size(2)).squeeze(2)
28
+ x2 = torch.max_pool1d(x2, x2.size(2)).squeeze(2)
29
+ x3 = torch.max_pool1d(x3, x3.size(2)).squeeze(2)
30
+ x = torch.cat((x1, x2, x3), dim=1)
31
+ x = self.dropout(x)
32
+ return self.fc(x)
33
+
34
+ model_name = "hfl/chinese-macbert-base"
35
+ tokenizer = BertTokenizer.from_pretrained("bert_cnn_tokenizer")
36
+ bert_model = BertModel.from_pretrained(model_name)
37
+ model = BertCNNClassifier(bert_model).to(device)
38
+ model.load_state_dict(torch.load("bert_cnn_sentiment.pth", map_location=device))
39
+ model.eval()
40
+
41
+ def predict(text):
42
+ inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128).to(device)
43
+ with torch.no_grad():
44
+ outputs = model(input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"])
45
+ pred = torch.argmax(outputs, dim=1).item()
46
+ return id2label[pred]
47
+
48
+ interface = gr.Interface(
49
+ fn=predict,
50
+ inputs=gr.Textbox(lines=3, placeholder="请输入朋友圈文案..."),
51
+ outputs="text",
52
+ title="朋友圈情绪识别",
53
+ description="输入一段朋友圈内容,判断情绪:负向 / 中性 / 正向"
54
+ )
55
+
56
+ interface.launch()
bert_cnn_sentiment.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a38a58db10a6d6b331a124ca39ca2ce2e6e7550afe3fd0b039fe8bfad0c505dc
3
+ size 412847589
bert_cnn_sentiment_project.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:853b4c78d966db5ed56344ae06e825aa7c53d46c314f2646eca187416af4a358
3
+ size 384000747
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ gradio
4
+ scikit-learn
5
+ pandas
6
+ openpyxl
7
+ tqdm