| import torch | |
| import json | |
| from huggingface_hub import hf_hub_download | |
| import re | |
| import emoji | |
| from transformers import BertForSequenceClassification, BertTokenizer | |
def preprocess_text(text):
    """Preprocess the input text to match training conditions.

    Steps, applied in this order:
      1. ``u/<name>`` is replaced by ``[USER]``.
      2. ``r/<name>`` is replaced by ``[SUBREDDIT]``.
      3. ``http://`` / ``https://`` links are replaced by ``[URL]``.
      4. ``emoji.demojize`` is applied with a single space as both delimiters.
      5. The whole string is lowercased (this also lowercases the placeholder
         tokens inserted by the earlier steps).

    Args:
        text (str): Raw input text.

    Returns:
        str: The preprocessed text.
    """
    # Kept as an ordered sequence: the mention patterns must run before the
    # URL pattern so the result is the same as applying them one by one.
    substitutions = (
        (r'u/\w+', '[USER]'),
        (r'r/\w+', '[SUBREDDIT]'),
        (r'http[s]?://\S+', '[URL]'),
    )
    cleaned = text
    for pattern, placeholder in substitutions:
        cleaned = re.sub(pattern, placeholder, cleaned)

    cleaned = emoji.demojize(cleaned, delimiters=(" ", " "))
    return cleaned.lower()
def load_model_and_resources():
    """Load the model, tokenizer, emotion labels, and thresholds from Hugging Face.

    The model and tokenizer are loaded from the ``logasanjeev/goemotions-bert``
    repository. The per-label decision thresholds and the label names are read
    from ``optimized_thresholds.json`` in the same repository. The model is
    moved to CUDA when available (CPU otherwise) and switched to eval mode.

    Returns:
        tuple: ``(model, tokenizer, emotion_labels, thresholds, device)``.

    Raises:
        RuntimeError: If the model/tokenizer cannot be loaded, or if the
            thresholds file cannot be downloaded, parsed, or validated. The
            underlying exception is attached as ``__cause__``.
    """
    repo_id = "logasanjeev/goemotions-bert"

    try:
        model = BertForSequenceClassification.from_pretrained(repo_id)
        tokenizer = BertTokenizer.from_pretrained(repo_id)
    except Exception as e:
        # Chain explicitly so the original traceback is reported as the cause.
        raise RuntimeError(f"Error loading model/tokenizer: {str(e)}") from e

    try:
        thresholds_file = hf_hub_download(repo_id=repo_id, filename="optimized_thresholds.json")
        # JSON is UTF-8 by specification; do not rely on the platform default.
        with open(thresholds_file, "r", encoding="utf-8") as f:
            thresholds_data = json.load(f)

        if not (isinstance(thresholds_data, dict) and "emotion_labels" in thresholds_data and "thresholds" in thresholds_data):
            raise ValueError("Unexpected format in optimized_thresholds.json. Expected a dictionary with keys 'emotion_labels' and 'thresholds'.")

        emotion_labels = thresholds_data["emotion_labels"]
        thresholds = thresholds_data["thresholds"]

        # Labels and thresholds are consumed pairwise by index at prediction
        # time; fail here rather than silently truncating or hitting an
        # IndexError later.
        if len(emotion_labels) != len(thresholds):
            raise ValueError(
                "Mismatched lengths in optimized_thresholds.json: "
                f"{len(emotion_labels)} emotion labels vs {len(thresholds)} thresholds."
            )
    except Exception as e:
        raise RuntimeError(f"Error loading thresholds: {str(e)}") from e

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    return model, tokenizer, emotion_labels, thresholds, device
# Lazily populated cache of the loaded resources; every slot starts out as
# None and is filled on first use by predict_emotions().
MODEL = TOKENIZER = EMOTION_LABELS = THRESHOLDS = DEVICE = None
def predict_emotions(text):
    """Predict emotions for the given text using the GoEmotions BERT model.

    The model and its resources are loaded on the first call and cached in the
    module-level globals for subsequent calls.

    Args:
        text (str): The input text to analyze.

    Returns:
        tuple: (predictions, processed_text)
            - predictions (str): One ``"<emotion>: <score>"`` line per emotion
              whose sigmoid score reaches its threshold, highest score first,
              or ``"No emotions predicted."`` when none does.
            - processed_text (str): The preprocessed input text.
    """
    global MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE

    # First call: load everything once and keep it in the module globals.
    if MODEL is None:
        MODEL, TOKENIZER, EMOTION_LABELS, THRESHOLDS, DEVICE = load_model_and_resources()

    processed_text = preprocess_text(text)

    batch = TOKENIZER(
        processed_text,
        padding='max_length',
        truncation=True,
        max_length=128,
        return_tensors='pt'
    )
    token_ids = batch['input_ids'].to(DEVICE)
    mask = batch['attention_mask'].to(DEVICE)

    with torch.no_grad():
        model_output = MODEL(token_ids, attention_mask=mask)
        # Independent per-label scores (sigmoid, not softmax); take the single
        # row of the batch.
        probabilities = torch.sigmoid(model_output.logits).cpu().numpy()[0]

    # Keep only the labels whose score reaches that label's own threshold.
    selected = [
        (EMOTION_LABELS[idx], round(score, 4))
        for idx, (score, cutoff) in enumerate(zip(probabilities, THRESHOLDS))
        if score >= cutoff
    ]
    selected.sort(key=lambda pair: pair[1], reverse=True)

    formatted = "\n".join(f"{emotion}: {confidence:.4f}" for emotion, confidence in selected)
    if not formatted:
        formatted = "No emotions predicted."

    return formatted, processed_text
if __name__ == "__main__":
    import argparse

    # Command-line entry point: a single positional argument with the text.
    cli = argparse.ArgumentParser(description="Predict emotions using the GoEmotions BERT model.")
    cli.add_argument("text", type=str, help="The input text to analyze for emotions.")
    user_text = cli.parse_args().text

    result, processed = predict_emotions(user_text)

    # Report the raw input, the preprocessed form, and the predictions.
    report_lines = (
        f"Input: {user_text}",
        f"Processed: {processed}",
        "Predicted Emotions:",
        result,
    )
    for line in report_lines:
        print(line)