SCANSKY committed
Commit 10789df · verified · 1 Parent(s): 4179781

Update handler.py

Files changed (1):
  1. handler.py (+11, -14)
handler.py CHANGED
@@ -1,4 +1,5 @@
 from transformers import pipeline
+from sklearn.preprocessing import LabelEncoder
 import joblib
 import torch
 import os
@@ -8,7 +9,7 @@ print("Current working directory:", os.getcwd())
 print("Contents of the directory:", os.listdir())
 
 # Load the label encoder
-label_encoder = joblib.load('label_encoder.pkl')  # Ensure the file is in the correct path
+label_encoder = joblib.load('/repository/label_encoder.pkl')  # Use absolute path
 print("Label encoder loaded successfully.")
 
 # Load the model and tokenizer from Hugging Face
@@ -39,16 +40,14 @@ def get_average_sentiment(positive_count, negative_count, neutral_count):
     return "neutral"
 
 class EndpointHandler:
-    def __init__(self):
-        # No need to load the model here since it's loaded globally
+    def __init__(self, model_dir=None):
+        # Model and tokenizer are loaded globally, so no need to reinitialize here
+        # The `model_dir` argument is required by Hugging Face's inference toolkit
         pass
 
     def preprocess(self, data):
         # Extract the input text from the request
-        if isinstance(data, dict):
-            text = data.get("inputs", "")
-        else:
-            text = data  # Fallback if data is not a dictionary
+        text = data.get("inputs", "")
         return text
 
     def inference(self, text):
@@ -116,16 +115,14 @@ class EndpointHandler:
 
     def postprocess(self, output):
         if "error" in output:
-            return {"error": output["error"]}
+            return [{"error": output["error"]}]
 
-        # Return the full output
-        return output
+        # Return only the line-level results as a list
+        return output["line_results"]
+
 
     def __call__(self, data):
         # Main method to handle the request
         text = self.preprocess(data)
         output = self.inference(text)
-        return self.postprocess(output)
-
-    # Create an instance of the handler
-    handler = EndpointHandler()
+        return self.postprocess(output)
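
For anyone smoke-testing the change locally, below is a minimal sketch of how the updated handler could be exercised. It assumes handler.py is importable, that /repository/label_encoder.pkl and the Hugging Face model are available at import time (both load at module level), and that inference() fills a line_results list, since its body is unchanged and not shown in this diff. The example text is illustrative only.

# Minimal local check of the updated EndpointHandler (hypothetical usage, not part of the commit).
from handler import EndpointHandler  # importing also runs the global model / label-encoder loads

# model_dir is accepted (and ignored) purely to match the inference toolkit's constructor call
endpoint = EndpointHandler(model_dir="/repository")

# preprocess() now reads the "inputs" key from the request payload
payload = {"inputs": "The onboarding was smooth.\nThe dashboard keeps crashing."}

result = endpoint(payload)

# On success: the list taken from output["line_results"] in postprocess()
# On failure: [{"error": "..."}]
print(result)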