"""Gradio demo: multi-axis regression of whisky tasting notes.

Loads a fine-tuned DeBERTa-v3 multi-head regressor from the Hugging Face Hub,
checks that the submitted review is English (fastText lid.176) and renders the
predicted flavour intensities as an SVG wheel plus a JSON payload.
"""

# Standard library
import os  # environment variables (e.g. HF_TOKEN)
import types  # instance-level monkeypatch (fastText .predict)
import html  # HTML escaping for output/Gradio
import numpy as np  # numerics (e.g. probabilities)
import time
import random

# Machine learning / NLP
import torch  # PyTorch (model, tensor, device)
import fasttext  # language identification (lid.176)

# These two are rarely called directly but must be installed so that the
# Hugging Face tokenizers work correctly (SentencePiece backends, converters).
import sentencepiece  # required for SentencePiece-based tokenizers (DeBERTa v3)
import tiktoken  # optional converter; avoids fallback errors/warnings

# Hugging Face ecosystem
import spaces
from transformers import AutoTokenizer  # tokenizer loader (use_fast=False for SentencePiece)
from huggingface_hub import hf_hub_download  # fetch files/weights from the HF Hub
from safetensors.torch import load_file  # safe and fast loading of .safetensors weights

# UI / serving
import gradio as gr  # web UI for the demo / Spaces

# Project-specific modules
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
    load_model_and_tokenizer,
    predict_flavours,
    # predict_is_review,
    TARGET_COLUMNS,
    ICONS,
)
from lib.wheel import build_svg_with_values
from lib.examples import EXAMPLES

### Settings #####################################################################

MODEL_BASE = "microsoft/deberta-v3-base"
REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"

# (optional) needed when the model repo is private:
HF_TOKEN = os.getenv("HF_TOKEN")  # store in the Space secrets
MODEL_FILE = os.getenv("MODEL_FILE")  # store in the Space secrets

##################################################################################

# Fail early with a clear message instead of an obscure error from the Hub client.
if not MODEL_FILE:
    raise RuntimeError(
        "Environment variable MODEL_FILE is not set; it must name the weights "
        f"file inside the Hub repo {REPO_ID!r}."
    )

# --- Download weights ---
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    token=HF_TOKEN,
)

# --- Tokenizer (SentencePiece!) ---
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE,
    use_fast=False,
)

model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE,
)

state = load_file(weights_path)  # safetensors -> dict[str, Tensor]

# strict=False tolerates key mismatches; report them so that a partially
# initialised model does not go unnoticed (use strict=True if keys match exactly).
_load_result = model_flavours.load_state_dict(state, strict=False)
if _load_result.missing_keys or _load_result.unexpected_keys:
    print(
        "load_state_dict: missing keys:", _load_result.missing_keys,
        "| unexpected keys:", _load_result.unexpected_keys,
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()

##################################################################################

# Mirror repo providing lid.176.*
lid_path = hf_hub_download(
    repo_id="julien-c/fasttext-language-id",
    filename="lid.176.ftz",
)

lid_model = fasttext.load_model(lid_path)


def _predict_np2_compat(self, text, k=1, threshold=0.0, on_unicode_error="strict"):
    """Replacement for fastText's ``predict`` that works with NumPy 2.

    The stock wrapper calls ``np.array(..., copy=False)``, which NumPy 2
    rejects; this version talks to the low-level binding (``self.f``) directly
    and normalises whatever shape it returns.

    Args:
        text: A single piece of text. Embedded newlines are replaced by
            spaces because the binding only reads up to the first newline.
        k: Number of labels to request.
        threshold: Minimum probability for a label to be returned.
        on_unicode_error: Forwarded unchanged to the binding.

    Returns:
        ``(labels, probs)`` where ``labels`` is a list of label strings and
        ``probs`` is a NumPy array with one entry per label.
    """
    # Mirror the upstream wrapper: one line per call, terminated by a newline.
    text = text.replace("\n", " ") + "\n"

    out = self.f.predict(text, k, threshold, on_unicode_error)

    if isinstance(out, tuple) and len(out) == 2:
        # Case 1: binding already returns (labels, probs).
        labels, probs = out
        labels, probs = list(labels), list(probs)
    else:
        entries = list(out) if isinstance(out, (list, tuple)) else [out]
        is_pair_list = bool(entries) and all(
            isinstance(entry, (list, tuple)) and len(entry) == 2 for entry in entries
        )
        if is_pair_list:
            # Case 2: list of (prob, label) pairs. The string element is the
            # label, so the order inside the pair does not matter here.
            labels, probs = [], []
            for first, second in entries:
                if isinstance(first, str):
                    label, prob = first, second
                else:
                    label, prob = second, first
                labels.append(label)
                probs.append(prob)
        else:
            # Case 3: labels only (some builds) -> no probabilities available,
            # fall back to 1.0 for each label.
            labels = entries
            probs = [1.0] * len(entries)

    return labels, np.asarray(probs)  # np.asarray instead of np.array(copy=False)


# Patch the instance (not the class) so only this model is affected.
lid_model.predict = types.MethodType(_predict_np2_compat, lid_model)


### Check if lang is english #####################################################

def is_eng(review: str):
    """Detect whether *review* is written in English.

    Returns:
        A 3-tuple ``(is_english, label, probability)``. ``label`` is the raw
        fastText label (e.g. ``"__label__en"``) or ``""`` when the language
        model returned no label at all, in which case the probability is 0.0.
    """
    lang_labels, lang_probs = lid_model.predict(review)
    print(lang_labels, lang_probs)

    if len(lang_labels) == 0:  # no label returned
        # Same arity as the regular return so callers can always unpack three values.
        return False, "", 0.0

    lang_label = lang_labels[0]
    lang_prob = float(lang_probs[0])
    return lang_label == "__label__en", lang_label, lang_prob


### Do actual prediction #########################################################

@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review: str):
    """Predict flavour intensities for a whisky review.

    Args:
        review: Free-text tasting note; ``None`` and blank input are handled.

    Returns:
        Always two outputs ``(html, json)`` for the Gradio components: either
        a message plus a (possibly empty) info dict when the input is
        rejected, or the SVG wheel markup plus the result payload.
    """
    review = (review or "").strip()

    # Abort if no text is given
    if not review:
        return "Please enter a review.", {}

    # Check the language of the text
    review_is_eng, review_lang_label, review_lang_prob = is_eng(review)

    # Abort if the text is not English
    if not review_is_eng:
        return "Only English reviews are supported.", {
            "lang": review_lang_label.replace("__label__", ""),
            "prob": review_lang_prob,
        }

    # perf_counter is monotonic, so the measured duration cannot go negative.
    t_start_flavours = time.perf_counter()
    prediction_flavours = predict_flavours(
        review, model_flavours, tokenizer_flavours, device
    )
    prediction_flavours_list = list(prediction_flavours.values())
    t_end_flavours = time.perf_counter()

    html_out = build_svg_with_values(prediction_flavours_list)

    json_out = {
        "results": [
            {
                # 'icon': ICONS[name],
                "name": name,
                "score": score,
                # 'level': get_level(score)
            }
            for name, score in prediction_flavours.items()
        ],
        "review": review,
        "model": MODEL_FILE,
        "device": device,
        "duration": round((t_end_flavours - t_start_flavours), 3),
    }

    return html_out, json_out


##################################################################################

def replace_text():
    """Return a randomly chosen example review for the input box."""
    return random.choice(EXAMPLES)


def get_device_info():
    """Return a short status line naming the device the model runs on."""
    if torch.cuda.is_available():
        return f"◉ Runs on GPU: {torch.cuda.get_device_name(0)}"
    else:
        return "◎ Runs on CPU (May be slow)"


### Create Form interface with Gradio Framework ##################################

with gr.Blocks() as demo:
    # NOTE(review): the HTML strings below contain no tags; markup (e.g. a
    # heading element) may have been lost — confirm against the deployed app.
    gr.HTML("\n\nMulti-Axis Regression of Whisky Tasting Notes\n\n")
    gr.HTML("This model is a fine-tuned version of microsoft/deberta-v3-base designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — grainy, grassy, fragrant, fruity, peated, woody, winey and off-notes — on a continuous scale from 0 (none) to 4 (extreme).")
    gr.HTML(f"{get_device_info()}")

    with gr.Row():  # everything side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Review",
                lines=8,
                placeholder="Enter whisky review",
                value=random.choice(EXAMPLES),
            )
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):  # right side: output
            html_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Events
    submit_btn.click(predict, inputs=review_box, outputs=[html_out, json_out])
    replace_btn.click(replace_text, outputs=review_box)

demo.launch()