fpopov1993 commited on
Commit
14f74d2
·
verified ·
1 Parent(s): 3e071a5

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +49 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
4
+
5
+
6
+ # Cache the model loading to speed up app restarts.
7
+ @st.cache_resource
8
+ def load_model_and_tokenizer():
9
+ model = DistilBertForSequenceClassification.from_pretrained(
10
+ "./results/checkpoint-1980"
11
+ )
12
+ tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-cased")
13
+ model.eval()
14
+ return model, tokenizer
15
+
16
+
17
+ model, tokenizer = load_model_and_tokenizer()
18
+
19
+
20
+ def classify_text(text: str) -> str:
21
+ """Tokenize the text and run inference."""
22
+ encoding = tokenizer(
23
+ text, return_tensors="pt", padding=True, truncation=True, max_length=128
24
+ )
25
+ with torch.no_grad():
26
+ outputs = model(**encoding)
27
+ logits = outputs.logits
28
+ predicted_class_id = torch.argmax(logits, dim=1).item()
29
+ id2label = model.config.id2label # Assumes id2label was set during training.
30
+ predicted_label = (
31
+ id2label[predicted_class_id] if id2label else str(predicted_class_id)
32
+ )
33
+ return predicted_label
34
+
35
+
36
+ # Build the Streamlit interface.
37
+ st.title("Text Classification with DistilBERT")
38
+
39
+ st.write("Enter text in the box below and click 'Classify' to see the predicted label.")
40
+
41
+ # Text input area.
42
+ user_text = st.text_area("Input Text", "")
43
+
44
+ if st.button("Classify"):
45
+ if user_text.strip() == "":
46
+ st.error("Please enter some text to classify.")
47
+ else:
48
+ predicted_label = classify_text(user_text)
49
+ st.success(f"Predicted label: **{predicted_label}**")
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ streamlit
2
+ transformers
3
+ torch