# handler.py
import logging
from typing import Any, Dict, List

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

logger = logging.getLogger(__name__)


class EndpointHandler:
    """Inference handler wrapping a sequence-classification model.

    The tokenizer and model are loaded once from ``path`` in ``__init__``;
    each call classifies ``data["inputs"]`` and echoes the caller-supplied
    ``segment_id`` / ``streamer_id`` back in the response.
    """

    def __init__(self, path: str = "") -> None:
        """Load the tokenizer and model from *path* and switch to eval mode.

        Args:
            path: Directory or model id passed to ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        try:
            # device_map="auto" lets transformers place the weights on the
            # available device(s) for GPU acceleration.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                path, device_map="auto"
            )
        except Exception:
            # Fallback: plain load, then move to CUDA if available, else CPU.
            # Catch Exception (not a bare except) so Ctrl-C / SystemExit still
            # propagate, and log why the preferred path failed.
            logger.warning(
                "Loading with device_map='auto' failed; falling back to manual "
                "device placement",
                exc_info=True,
            )
            self.model = AutoModelForSequenceClassification.from_pretrained(path)
            self.model.to(torch.device("cuda" if torch.cuda.is_available() else "cpu"))
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify ``data["inputs"]`` and return a one-element result list.

        Args:
            data: Request payload. ``"inputs"`` is passed to the tokenizer;
                optional ``"segment_id"`` and ``"streamer_id"`` are echoed back
                (defaulting to ``"unknown_segment"`` / ``"unknown_streamer"``).
                Only a single text is effectively supported: ``pred.item()``
                below raises for a batch larger than one, which lands in the
                error path.

        Returns:
            On success ``[{"segment_id", "streamer_id", "prediction",
            "confidence"}]``, where ``prediction`` is the ``id2label`` entry of
            the arg-max class and ``confidence`` its softmax probability
            rounded to 4 decimals. On any exception
            ``[{"segment_id", "streamer_id", "error"}]``; nothing is raised.
        """
        # Bind the defaults before the try block so the error path can always
        # reference them, even when `data` itself is malformed (e.g. not a
        # dict) and `data.get` raises.
        segment_id: Any = "unknown_segment"
        streamer_id: Any = "unknown_streamer"
        try:
            segment_id = data.get("segment_id", segment_id)
            streamer_id = data.get("streamer_id", streamer_id)
            text_input = data["inputs"]

            # NOTE(review): padding="max_length" always runs 512 tokens; for a
            # single text padding=True would avoid the padding work — confirm
            # outputs are unaffected for this model before switching.
            inputs = self.tokenizer(
                text_input,
                truncation=True,
                max_length=512,
                padding="max_length",
                return_tensors="pt",
            ).to(self.model.device)

            with torch.no_grad():
                outputs = self.model(**inputs)
                probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
                conf, pred = torch.max(probs, dim=1)

            return [{
                "segment_id": segment_id,
                "streamer_id": streamer_id,
                "prediction": self.model.config.id2label[pred.item()],
                "confidence": round(conf.item(), 4),
            }]
        except Exception as e:
            # Top-level request boundary: log with traceback and report the
            # failure in the response instead of raising.
            logger.exception("Error: %s", e)
            return [{
                "segment_id": segment_id,
                "streamer_id": streamer_id,
                "error": str(e),
            }]