# handler.py
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from typing import Dict, Any, List
import logging
logger = logging.getLogger(__name__)

class EndpointHandler:
    def __init__(self, path: str = ""):
        # Load tokenizer and model from the repository path
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        try:
            # Prefer device_map="auto" so the model is placed on GPU when available
            self.model = AutoModelForSequenceClassification.from_pretrained(
                path,
                device_map="auto",
            )
        except Exception:
            # Fall back to manual device placement if device_map is unsupported
            self.model = AutoModelForSequenceClassification.from_pretrained(path)
            self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Default identifiers so the error path below always has them defined
        segment_id = "unknown_segment"
        streamer_id = "unknown_streamer"
        try:
            segment_id = data.get("segment_id", segment_id)
            streamer_id = data.get("streamer_id", streamer_id)
            text_input = data["inputs"]

            # Tokenize input and move tensors to the model's device
            inputs = self.tokenizer(
                text_input,
                truncation=True,
                max_length=512,
                padding="max_length",
                return_tensors="pt",
            ).to(self.model.device)

            # Inference
            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
                conf, pred = torch.max(probs, dim=1)

            return [{
                "segment_id": segment_id,
                "streamer_id": streamer_id,
                "prediction": self.model.config.id2label[pred.item()],
                "confidence": round(conf.item(), 4),
            }]
        except Exception as e:
            logger.error(f"Error: {str(e)}")
            return [{
                "segment_id": segment_id,
                "streamer_id": streamer_id,
                "error": str(e),
            }]
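

# A minimal local smoke test, assuming the Inference Endpoints runtime
# instantiates EndpointHandler with the model path and calls it with a JSON
# payload of this shape. The path and payload values here are illustrative
# placeholders, not part of the original handler.
if __name__ == "__main__":
    handler = EndpointHandler(path=".")  # hypothetical: model files in the current directory
    example = {
        "segment_id": "segment_001",
        "streamer_id": "streamer_abc",
        "inputs": "example text to classify",
    }
    print(handler(example))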