Update app.py
Browse files
app.py
CHANGED
|
@@ -22,10 +22,13 @@ def func(text):
|
|
| 22 |
with torch.no_grad():
|
| 23 |
output = bert_sc(**encoding)
|
| 24 |
scores = output.logits.argmax(-1)
|
|
|
|
|
|
|
| 25 |
|
| 26 |
label = "ネガティブ" if scores.item()==0 else "ポジティブ"
|
|
|
|
| 27 |
|
| 28 |
-
return label
|
| 29 |
|
| 30 |
-
app = gr.Interface(fn=func, inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"), outputs="label", title="ビジネス文書のネガポジ分析", description=descriptions)
|
| 31 |
app.launch()
|
|
|
|
| 22 |
with torch.no_grad():
|
| 23 |
output = bert_sc(**encoding)
|
| 24 |
scores = output.logits.argmax(-1)
|
| 25 |
+
neg = torch.softmax(output.logits, dim=1).tolist()[0][0]
|
| 26 |
+
pos = torch.softmax(output.logits, dim=1).tolist()[0][1]
|
| 27 |
|
| 28 |
label = "ネガティブ" if scores.item()==0 else "ポジティブ"
|
| 29 |
+
cos = f"信頼度:{neg*100:.1f}%" if scores.item()==0 else f"信頼度:{pos*100:.1f}%"
|
| 30 |
|
| 31 |
+
return label,cos
|
| 32 |
|
| 33 |
+
app = gr.Interface(fn=func, inputs=gr.Textbox(lines=3, placeholder="文章を入力してください"), outputs=["label","label"], title="ビジネス文書のネガポジ分析", description=descriptions)
|
| 34 |
app.launch()
|