eric707 commited on
Commit
0ebba2e
·
verified ·
1 Parent(s): bd21fd3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -0
app.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ import torch
5
+
6
+ # Initialize FastAPI app
7
+ app = FastAPI()
8
+
9
+ # Load Hugging Face model and tokenizer
10
+ MODEL_NAME = "ealvaradob/bert-finetuned-phishing"
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
+ model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
13
+
14
+ # Define input structure
15
+ class TextInput(BaseModel):
16
+ text: str
17
+
18
+ @app.post("/predict")
19
+ def predict_spam(input_data: TextInput):
20
+ # Tokenize input text
21
+ inputs = tokenizer(input_data.text, return_tensors="pt", truncation=True, padding=True, max_length=512)
22
+
23
+ # Perform prediction
24
+ with torch.no_grad():
25
+ outputs = model(**inputs)
26
+
27
+ # Get classification result
28
+ prediction = torch.argmax(outputs.logits, dim=1).item()
29
+
30
+ # Return response
31
+ return {
32
+ "text": input_data.text,
33
+ "prediction": "Phishing Email" if prediction == 1 else "Not Phishing Email"
34
+ }
35
+
36
+ # Root Endpoint
37
+ @app.get("/")
38
+ def home():
39
+ return {"message": "Welcome to the Spam Classifier API!"}