Spaces:
Running
Running
| import torch | |
| from transformers import AutoTokenizer | |
| from torch.utils.data import Dataset | |
| import numpy as np | |
| from .bert_regressor import BertMultiHeadRegressor, BertBinaryClassifier | |
| ################################################################################### | |
| # Konstante Liste der acht Aromen-Kategorien für Whisky-Tasting-Notes. | |
| # Diese wird von Modellen und Evaluierungsfunktionen verwendet. | |
| TARGET_COLUMNS = [ | |
| "grainy", | |
| "grassy", | |
| "fragrant", | |
| "fruity", | |
| "peated", | |
| "woody", | |
| "winey", | |
| "off-notes" | |
| ] | |
| ################################################################################### | |
| COLORS = { | |
| "grainy": "#FFF3B0", | |
| "grassy": "#C4F0C5", | |
| "fragrant": "#F3C4FB", | |
| "fruity": "#FFD6B0", | |
| "peated": "#CFCFCF", | |
| "woody": "#EAD6C7", | |
| "winey": "#F7B7A3", | |
| "off-notes": "#D6E4F0", | |
| "quantifiers": "#ff8083" | |
| } | |
| ICONS = { | |
| "grainy": "🌾", | |
| "grassy": "🌿", | |
| "fragrant": "🌸", | |
| "fruity": "🍋", | |
| "peated": "🔥", | |
| "woody": "🌲", | |
| "winey": "🍷", | |
| "off-notes": "☠️" | |
| } | |
| ################################################################################### | |
| class WhiskyDataset(Dataset): | |
| def __init__(self, texts, targets, tokenizer, max_len): | |
| self.texts = texts | |
| self.targets = targets | |
| self.tokenizer = tokenizer | |
| self.max_len = max_len | |
| def __len__(self): | |
| return len(self.texts) | |
| def __getitem__(self, item): | |
| text = str(self.texts[item]) | |
| target = self.targets[item] | |
| # Einheitliche Tokenisierung über Hilfsfunktion | |
| encoding = tokenize_input(text, self.tokenizer) | |
| return { | |
| 'input_ids': encoding['input_ids'].squeeze(), | |
| 'attention_mask': encoding['attention_mask'].squeeze(), | |
| 'targets': torch.tensor(target, dtype=torch.float) | |
| } | |
| ################################################################################### | |
| def get_device(prefer_mps=True, verbose=True): | |
| """ | |
| Gibt das beste verfügbare Torch-Device zurück (MPS, CUDA oder CPU). | |
| Args: | |
| prefer_mps (bool): Ob bei Apple-Geräten 'mps' (Metal Performance Shaders) bevorzugt werden soll. | |
| verbose (bool): Ob das erkannte Device ausgegeben werden soll. | |
| Returns: | |
| torch.device: Das beste verfügbare Gerät für das Training. | |
| """ | |
| if prefer_mps and torch.backends.mps.is_available(): | |
| device = torch.device("mps") | |
| name = "Apple GPU (MPS)" | |
| elif torch.cuda.is_available(): | |
| device = torch.device("cuda") | |
| name = torch.cuda.get_device_name(device) | |
| else: | |
| device = torch.device("cpu") | |
| name = "CPU" | |
| if verbose: | |
| print(f"✅ Verwendetes Gerät: {name} ({device})") | |
| return device | |
| ################################################################################### | |
| def tokenize_input(texts, tokenizer, max_len=256): | |
| """ | |
| Einheitliche Tokenisierung für Training und Inferenz. | |
| Args: | |
| texts (str or List[str]): Eingabetext(e). | |
| tokenizer (PreTrainedTokenizer): z. B. BertTokenizer. | |
| Returns: | |
| dict: Dictionary mit PyTorch-Tensoren (input_ids, attention_mask). | |
| """ | |
| return tokenizer( | |
| texts, | |
| truncation=True, | |
| padding='max_length', | |
| max_length=max_len, | |
| return_tensors='pt' | |
| ) | |
| ################################################################################### | |
| def load_model_and_tokenizer(model_name, model_path, model_type="multihead"): | |
| """ | |
| Universelle Ladefunktion für BertMultiHeadRegressor oder BertBinaryClassifier. | |
| Args: | |
| model_name (str): Name des vortrainierten BERT-Modells (z. B. 'bert-base-uncased'). | |
| model_path (str): Pfad zur gespeicherten Modellzustandsdatei (.pt). | |
| model_type (str): 'multihead' oder 'binary'. Default: 'multihead'. | |
| Returns: | |
| model (nn.Module): Geladenes Modell im Eval-Modus. | |
| tokenizer (BertTokenizer): Passender Tokenizer. | |
| device (torch.device): Verwendetes Rechengerät (CPU oder GPU). | |
| """ | |
| # Gerät automatisch ermitteln (GPU/CPU) | |
| device = get_device() | |
| # Modellzustand und Konfiguration laden | |
| checkpoint = torch.load(model_path, map_location=device) | |
| config = checkpoint["model_config"] | |
| # Modell je nach Typ initialisieren | |
| if model_type == "multihead": | |
| model = BertMultiHeadRegressor( | |
| pretrained_model_name=config["pretrained_model_name"], | |
| n_heads=config["n_heads"], | |
| unfreeze_from=config["unfreeze_from"], | |
| dropout=config["dropout"] | |
| ) | |
| elif model_type == "binary": | |
| model = BertBinaryClassifier( | |
| pretrained_model_name=config["pretrained_model_name"], | |
| unfreeze_from=config["unfreeze_from"], | |
| dropout=config["dropout"] | |
| ) | |
| else: | |
| raise ValueError(f"Unbekannter model_type: {model_type}") | |
| # Gewichtungen laden und Modell auf Gerät verschieben | |
| model.to(device) | |
| model.load_state_dict(checkpoint["model_state_dict"]) | |
| model.eval() # Wechselt in den Inferenzmodus | |
| # Lädt den passenden Tokenizer | |
| tokenizer = AutoTokenizer.from_pretrained(model_name) | |
| return model, tokenizer, device | |
| ################################################################################### | |
| def predict_flavours(review_text, model, tokenizer, device, max_len=256): | |
| # Modell in den Evaluierungsmodus setzen (kein Dropout etc.) | |
| model.eval() | |
| # Eingabetext tokenisieren und als Tensoren zurückgeben | |
| encoding = tokenize_input( | |
| review_text, | |
| tokenizer | |
| ) | |
| # Tokens auf das richtige Device verschieben | |
| input_ids = encoding['input_ids'].to(device) | |
| attention_mask = encoding['attention_mask'].to(device) | |
| # Inferenz ohne Gradientenberechnung (Effizienz) | |
| with torch.no_grad(): | |
| outputs = model(input_ids=input_ids, attention_mask=attention_mask) # shape: [1, 8] | |
| prediction = outputs.cpu().numpy().flatten() # [8] – flach machen | |
| prediction = np.clip(prediction, 0.0, 4.0) | |
| # In ein Dictionary umwandeln (z. B. {"fruity": 2.1, "peated": 3.8, ...}) | |
| result = { | |
| flavour: round(float(score), 2) | |
| for flavour, score in zip(TARGET_COLUMNS, prediction) | |
| } | |
| return result | |
| ################################################################################### | |
| def predict_is_review(review_text, model, tokenizer, device, max_len=256, threshold=0.5): | |
| # Modell in den Evaluierungsmodus setzen (kein Dropout etc.) | |
| model.eval() | |
| # Eingabetext tokenisieren und als Tensoren zurückgeben | |
| encoding = tokenize_input( | |
| review_text, | |
| tokenizer | |
| ) | |
| # Tokens auf das richtige Device verschieben | |
| input_ids = encoding['input_ids'].to(device) | |
| attention_mask = encoding['attention_mask'].to(device) | |
| with torch.no_grad(): | |
| outputs = model(input_ids=input_ids, attention_mask=attention_mask) | |
| print(outputs.cpu().numpy()) # <--- Zeigt die rohen Logits | |
| probs = torch.sigmoid(outputs) # [1, 1] | |
| prob = float(probs.squeeze().cpu().numpy()) # Skalar | |
| return { | |
| "is_review": prob >= threshold, | |
| "probability": round(prob, 4) | |
| } | |
| ################################################################################### | |