Update handler.py
Browse files- handler.py +11 -14
handler.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from transformers import pipeline
|
|
|
2 |
import joblib
|
3 |
import torch
|
4 |
import os
|
@@ -8,7 +9,7 @@ print("Current working directory:", os.getcwd())
|
|
8 |
print("Contents of the directory:", os.listdir())
|
9 |
|
10 |
# Load the label encoder
|
11 |
-
label_encoder = joblib.load('label_encoder.pkl') #
|
12 |
print("Label encoder loaded successfully.")
|
13 |
|
14 |
# Load the model and tokenizer from Hugging Face
|
@@ -39,16 +40,14 @@ def get_average_sentiment(positive_count, negative_count, neutral_count):
|
|
39 |
return "neutral"
|
40 |
|
41 |
class EndpointHandler:
|
42 |
-
def __init__(self):
|
43 |
-
#
|
|
|
44 |
pass
|
45 |
|
46 |
def preprocess(self, data):
|
47 |
# Extract the input text from the request
|
48 |
-
|
49 |
-
text = data.get("inputs", "")
|
50 |
-
else:
|
51 |
-
text = data # Fallback if data is not a dictionary
|
52 |
return text
|
53 |
|
54 |
def inference(self, text):
|
@@ -116,16 +115,14 @@ class EndpointHandler:
|
|
116 |
|
117 |
def postprocess(self, output):
|
118 |
if "error" in output:
|
119 |
-
return {"error": output["error"]}
|
120 |
|
121 |
-
# Return the
|
122 |
-
return output
|
|
|
123 |
|
124 |
def __call__(self, data):
|
125 |
# Main method to handle the request
|
126 |
text = self.preprocess(data)
|
127 |
output = self.inference(text)
|
128 |
-
return self.postprocess(output)
|
129 |
-
|
130 |
-
# Create an instance of the handler
|
131 |
-
handler = EndpointHandler()
|
|
|
1 |
from transformers import pipeline
|
2 |
+
from sklearn.preprocessing import LabelEncoder
|
3 |
import joblib
|
4 |
import torch
|
5 |
import os
|
|
|
9 |
print("Contents of the directory:", os.listdir())
|
10 |
|
11 |
# Load the label encoder
|
12 |
+
label_encoder = joblib.load('/repository/label_encoder.pkl') # Use absolute path
|
13 |
print("Label encoder loaded successfully.")
|
14 |
|
15 |
# Load the model and tokenizer from Hugging Face
|
|
|
40 |
return "neutral"
|
41 |
|
42 |
class EndpointHandler:
|
43 |
+
def __init__(self, model_dir=None):
|
44 |
+
# Model and tokenizer are loaded globally, so no need to reinitialize here
|
45 |
+
# The `model_dir` argument is required by Hugging Face's inference toolkit
|
46 |
pass
|
47 |
|
48 |
def preprocess(self, data):
|
49 |
# Extract the input text from the request
|
50 |
+
text = data.get("inputs", "")
|
|
|
|
|
|
|
51 |
return text
|
52 |
|
53 |
def inference(self, text):
|
|
|
115 |
|
116 |
def postprocess(self, output):
|
117 |
if "error" in output:
|
118 |
+
return [{"error": output["error"]}]
|
119 |
|
120 |
+
# Return only the line-level results as a list
|
121 |
+
return output["line_results"]
|
122 |
+
|
123 |
|
124 |
def __call__(self, data):
|
125 |
# Main method to handle the request
|
126 |
text = self.preprocess(data)
|
127 |
output = self.inference(text)
|
128 |
+
return self.postprocess(output)
|
|
|
|
|
|