Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -5,6 +5,8 @@ from transformers import pipeline
|
|
| 5 |
from joblib import load
|
| 6 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 7 |
import torch.nn.functional as F
|
|
|
|
|
|
|
| 8 |
|
| 9 |
# global variables to load models
|
| 10 |
lr_model = load("lr_model.joblib")
|
|
@@ -35,6 +37,14 @@ def predict_sentiment(text, model):
|
|
| 35 |
elif model == "custom BERT":
|
| 36 |
pred = F.softmax(bert_model(**bert_tokenizer(text, return_tensors="pt")).logits[0], dim=0).tolist()
|
| 37 |
return {"neg": pred[0], "pos": pred[1]}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
|
| 40 |
demo = gr.Blocks()
|
|
@@ -72,5 +82,12 @@ with demo:
|
|
| 72 |
["Sad frown", "custom BERT"],
|
| 73 |
]
|
| 74 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
demo.launch()
|
|
|
|
| 5 |
from joblib import load
|
| 6 |
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
| 7 |
import torch.nn.functional as F
|
| 8 |
+
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
|
| 11 |
# global variables to load models
|
| 12 |
lr_model = load("lr_model.joblib")
|
|
|
|
| 37 |
elif model == "custom BERT":
|
| 38 |
pred = F.softmax(bert_model(**bert_tokenizer(text, return_tensors="pt")).logits[0], dim=0).tolist()
|
| 39 |
return {"neg": pred[0], "pos": pred[1]}
|
| 40 |
+
|
| 41 |
+
def plot():
|
| 42 |
+
actual = ["pos", "pos", "neg", "neg", "pos"]
|
| 43 |
+
pred = ["pos", "neg", "pos", "neg", "pos"]
|
| 44 |
+
cm = confusion_matrix(y_test, predictions)
|
| 45 |
+
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
|
| 46 |
+
disp.plot()
|
| 47 |
+
return plt.gca()
|
| 48 |
|
| 49 |
|
| 50 |
demo = gr.Blocks()
|
|
|
|
| 82 |
["Sad frown", "custom BERT"],
|
| 83 |
]
|
| 84 |
)
|
| 85 |
+
with gr.TabItem("Multiple Inputs"):
|
| 86 |
+
gr.Markdown("A more complex demo showing a plot and two outputs")
|
| 87 |
+
interface = gr.Interface(
|
| 88 |
+
plot,
|
| 89 |
+
[],
|
| 90 |
+
"image"
|
| 91 |
+
)
|
| 92 |
|
| 93 |
demo.launch()
|