whisky-wheel / lib /bert_regressor_utils.py
ziem-io's picture
Scaffold: Include custom libs
bb3d05e
raw
history blame
7.36 kB
import torch
from transformers import AutoTokenizer
from torch.utils.data import Dataset
import numpy as np
from .bert_regressor import BertMultiHeadRegressor, BertBinaryClassifier
###################################################################################
# Konstante Liste der acht Aromen-Kategorien für Whisky-Tasting-Notes.
# Diese wird von Modellen und Evaluierungsfunktionen verwendet.
TARGET_COLUMNS = [
"grainy",
"grassy",
"fragrant",
"fruity",
"peated",
"woody",
"winey",
"off-notes"
]
###################################################################################
COLORS = {
"grainy": "#FFF3B0",
"grassy": "#C4F0C5",
"fragrant": "#F3C4FB",
"fruity": "#FFD6B0",
"peated": "#CFCFCF",
"woody": "#EAD6C7",
"winey": "#F7B7A3",
"off-notes": "#D6E4F0",
"quantifiers": "#ff8083"
}
ICONS = {
"grainy": "🌾",
"grassy": "🌿",
"fragrant": "🌸",
"fruity": "🍋",
"peated": "🔥",
"woody": "🌲",
"winey": "🍷",
"off-notes": "☠️"
}
###################################################################################
class WhiskyDataset(Dataset):
def __init__(self, texts, targets, tokenizer, max_len):
self.texts = texts
self.targets = targets
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.texts)
def __getitem__(self, item):
text = str(self.texts[item])
target = self.targets[item]
# Einheitliche Tokenisierung über Hilfsfunktion
encoding = tokenize_input(text, self.tokenizer)
return {
'input_ids': encoding['input_ids'].squeeze(),
'attention_mask': encoding['attention_mask'].squeeze(),
'targets': torch.tensor(target, dtype=torch.float)
}
###################################################################################
def get_device(prefer_mps=True, verbose=True):
"""
Gibt das beste verfügbare Torch-Device zurück (MPS, CUDA oder CPU).
Args:
prefer_mps (bool): Ob bei Apple-Geräten 'mps' (Metal Performance Shaders) bevorzugt werden soll.
verbose (bool): Ob das erkannte Device ausgegeben werden soll.
Returns:
torch.device: Das beste verfügbare Gerät für das Training.
"""
if prefer_mps and torch.backends.mps.is_available():
device = torch.device("mps")
name = "Apple GPU (MPS)"
elif torch.cuda.is_available():
device = torch.device("cuda")
name = torch.cuda.get_device_name(device)
else:
device = torch.device("cpu")
name = "CPU"
if verbose:
print(f"✅ Verwendetes Gerät: {name} ({device})")
return device
###################################################################################
def tokenize_input(texts, tokenizer, max_len=256):
"""
Einheitliche Tokenisierung für Training und Inferenz.
Args:
texts (str or List[str]): Eingabetext(e).
tokenizer (PreTrainedTokenizer): z. B. BertTokenizer.
Returns:
dict: Dictionary mit PyTorch-Tensoren (input_ids, attention_mask).
"""
return tokenizer(
texts,
truncation=True,
padding='max_length',
max_length=max_len,
return_tensors='pt'
)
###################################################################################
def load_model_and_tokenizer(model_name, model_path, model_type="multihead"):
"""
Universelle Ladefunktion für BertMultiHeadRegressor oder BertBinaryClassifier.
Args:
model_name (str): Name des vortrainierten BERT-Modells (z. B. 'bert-base-uncased').
model_path (str): Pfad zur gespeicherten Modellzustandsdatei (.pt).
model_type (str): 'multihead' oder 'binary'. Default: 'multihead'.
Returns:
model (nn.Module): Geladenes Modell im Eval-Modus.
tokenizer (BertTokenizer): Passender Tokenizer.
device (torch.device): Verwendetes Rechengerät (CPU oder GPU).
"""
# Gerät automatisch ermitteln (GPU/CPU)
device = get_device()
# Modellzustand und Konfiguration laden
checkpoint = torch.load(model_path, map_location=device)
config = checkpoint["model_config"]
# Modell je nach Typ initialisieren
if model_type == "multihead":
model = BertMultiHeadRegressor(
pretrained_model_name=config["pretrained_model_name"],
n_heads=config["n_heads"],
unfreeze_from=config["unfreeze_from"],
dropout=config["dropout"]
)
elif model_type == "binary":
model = BertBinaryClassifier(
pretrained_model_name=config["pretrained_model_name"],
unfreeze_from=config["unfreeze_from"],
dropout=config["dropout"]
)
else:
raise ValueError(f"Unbekannter model_type: {model_type}")
# Gewichtungen laden und Modell auf Gerät verschieben
model.to(device)
model.load_state_dict(checkpoint["model_state_dict"])
model.eval() # Wechselt in den Inferenzmodus
# Lädt den passenden Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
return model, tokenizer, device
###################################################################################
def predict_flavours(review_text, model, tokenizer, device, max_len=256):
# Modell in den Evaluierungsmodus setzen (kein Dropout etc.)
model.eval()
# Eingabetext tokenisieren und als Tensoren zurückgeben
encoding = tokenize_input(
review_text,
tokenizer
)
# Tokens auf das richtige Device verschieben
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
# Inferenz ohne Gradientenberechnung (Effizienz)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask) # shape: [1, 8]
prediction = outputs.cpu().numpy().flatten() # [8] – flach machen
prediction = np.clip(prediction, 0.0, 4.0)
# In ein Dictionary umwandeln (z. B. {"fruity": 2.1, "peated": 3.8, ...})
result = {
flavour: round(float(score), 2)
for flavour, score in zip(TARGET_COLUMNS, prediction)
}
return result
###################################################################################
def predict_is_review(review_text, model, tokenizer, device, max_len=256, threshold=0.5):
# Modell in den Evaluierungsmodus setzen (kein Dropout etc.)
model.eval()
# Eingabetext tokenisieren und als Tensoren zurückgeben
encoding = tokenize_input(
review_text,
tokenizer
)
# Tokens auf das richtige Device verschieben
input_ids = encoding['input_ids'].to(device)
attention_mask = encoding['attention_mask'].to(device)
with torch.no_grad():
outputs = model(input_ids=input_ids, attention_mask=attention_mask)
print(outputs.cpu().numpy()) # <--- Zeigt die rohen Logits
probs = torch.sigmoid(outputs) # [1, 1]
prob = float(probs.squeeze().cpu().numpy()) # Skalar
return {
"is_review": prob >= threshold,
"probability": round(prob, 4)
}
###################################################################################