# Standard library
import os  # Environment variables (e.g. HF_TOKEN)
import time  # Timing / performance measurement
import random  # Random choices (e.g. example reviews)
import html  # HTML escaping for safe output in Gradio
import types  # Monkeypatching instances (fastText .predict); not referenced below
import logging  # Startup diagnostics (e.g. weight-key mismatches)
import numpy as np  # Numeric arrays and probabilities; not referenced below

# Machine learning / NLP
import torch  # PyTorch: models, tensors, devices
import fasttext  # Language-ID model (lid.176); not referenced below

# The following are required even though they are not used explicitly:
import sentencepiece  # Required for SentencePiece-based tokenizers (e.g. DeBERTa v3)
import tiktoken  # Optional converter (avoids fallback errors in the tokenizer)
from langid.langid import LanguageIdentifier, model  # Alternative language ID

# Hugging Face ecosystem
import spaces  # HF Spaces decorators (@spaces.GPU)
from transformers import AutoTokenizer  # Tokenizer loading (use_fast=False for DeBERTa v3)
from huggingface_hub import hf_hub_download  # Download files/weights from the HF Hub
from safetensors.torch import load_file  # Safe & fast loading of .safetensors weights

# Translation
import deepl  # DeepL API for automatic translation

# UI / serving
import gradio as gr  # Web UI for the demo / Spaces

# Project-specific modules
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
    predict_flavours,  # Main entry point: predicts the 8 flavour axes
)
from lib.wheel import build_svg_with_values  # SVG rendering for the flavour wheel
from lib.examples import EXAMPLES  # Predefined example reviews

logger = logging.getLogger(__name__)

### Settings #####################################################################

MODEL_BASE = "microsoft/deberta-v3-base"
REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"

# All three values are expected to be stored as Space secrets.
HF_TOKEN = os.getenv("HF_TOKEN")  # only needed if the model repo is private
MODEL_FILE = os.getenv("MODEL_FILE")
DEEPL_API_KEY = os.getenv("DEEPL_API_KEY")

##################################################################################

# Fail fast with a clear message instead of an obscure error from hf_hub_download.
if not MODEL_FILE:
    raise RuntimeError("MODEL_FILE is not set; add it to the Space secrets.")

# --- Download weights ---
weights_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, token=HF_TOKEN)

# --- Tokenizer (SentencePiece!) ---
tokenizer_flavours = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)

model_flavours = BertMultiHeadRegressor(pretrained_model_name=MODEL_BASE)
state = load_file(weights_path)  # safetensors -> dict[str, Tensor]

# strict=False tolerates key mismatches (use strict=True once keys match exactly).
# Report any mismatch instead of silently discarding the result.
load_result = model_flavours.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    logger.warning(
        "Weight keys did not match the model exactly (missing: %s; unexpected: %s)",
        load_result.missing_keys,
        load_result.unexpected_keys,
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()

### Check if lang is english #####################################################

ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)


def _is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    """Decide whether *text* is English according to langid.

    Args:
        text: Input text; ``None`` is treated as empty.
        min_chars: Stripped texts shorter than this are not classified and are
            treated as English (returns ``(True, 0.0)``).
        threshold: Minimum langid probability for the ``"en"`` label to count
            as English. NOTE(review): 0.1 is very permissive — presumably
            chosen to avoid needless translations; confirm.

    Returns:
        Tuple ``(is_english, probability)`` where probability is the
        normalized langid score of the detected language.
    """
    t = (text or "").strip()
    if len(t) < min_chars:
        return True, 0.0
    lang, prob = ID.classify(t)  # prob in [0, 1] because norm_probs=True
    return (lang == "en" and prob >= threshold), float(prob)


def _translate_en(text: str, target_lang: str = "EN-GB"):
    """Translate *text* with DeepL into *target_lang* and return the translated string."""
    deepl_client = deepl.Translator(DEEPL_API_KEY)
    result = deepl_client.translate_text(text, target_lang=target_lang)
    return result.text


### Do actual prediction #########################################################


@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review: str):
    """Predict the flavour profile of a whisky review for the Gradio UI.

    Non-English input (per ``_is_eng``) is translated via DeepL first.

    Returns:
        Always a 3-tuple matching the UI outputs:
        ``(info_html, wheel_markup, json_dict)``. For empty input the result is
        ``("Please enter a review.", "", {})``.
    """
    review = (review or "").strip()
    is_translated = False
    html_info_out = ""

    # Abort if no text is given (still return all three outputs).
    if not review:
        return "Please enter a review.", "", {}

    # Translate non-English text so the English-trained model can score it.
    review_is_eng, _ = _is_eng(review)
    if not review_is_eng:
        review = _translate_en(review)
        # The translation derives from user input and is rendered by gr.HTML,
        # so it must be escaped to prevent HTML/script injection.
        html_info_out = (
            "Your text has been automatically translated:\n\n"
            f"{html.escape(review)}\n\n"
        )
        is_translated = True

    # Time only the model inference (translation is excluded).
    t_start_flavours = time.perf_counter()
    prediction_flavours = predict_flavours(review, model_flavours, tokenizer_flavours, device)
    prediction_flavours_list = list(prediction_flavours.values())
    t_end_flavours = time.perf_counter()

    html_wheel_out = build_svg_with_values(prediction_flavours_list)

    json_out = {
        "results": [
            {"name": name, "score": score}
            for name, score in prediction_flavours.items()
        ],
        "review": review,
        "model": MODEL_FILE,
        "device": device,
        "translated": is_translated,
        "duration": round(t_end_flavours - t_start_flavours, 3),
    }

    return html_info_out, html_wheel_out, json_out


##################################################################################


def random_text():
    """Return a randomly chosen predefined example review."""
    return random.choice(EXAMPLES)


def _get_device_info():
    """Return a short human-readable label for the device torch sees."""
    if torch.cuda.is_available():
        return f"◉ Runs on GPU: {torch.cuda.get_device_name(0)}"
    return "◎ Runs on CPU (May be slower)"


### Create Form interface with Gradio Framework ##################################

custom_css = """
@media (prefers-color-scheme: dark) {
    svg#wheel > text {
        fill: rgb(200, 200, 200);
    }
}
"""

with gr.Blocks(css=custom_css) as demo:
    gr.HTML("\n\nMulti-Axis Regression of Whisky Tasting Notes\n\n")
    gr.HTML("""

Automatically turns Whisky Tasting Notes into Flavour Wheels.

This model is a fine-tuned version of microsoft/deberta-v3-base designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — grainy, grassy, fragrant, fruity, peated, woody, winey and off-notes — on a continuous scale from 0 (none) to 4 (extreme).

""")
    gr.HTML(_get_device_info())

    with gr.Row():  # everything side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Review",
                lines=8,
                placeholder="Enter whisky review",
                value=random_text(),
            )
            gr.HTML("\nNote: Non-English texts will be automatically translated.\n")
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):  # right side: output
            html_info_out = gr.HTML(label="Info")
            html_wheel_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Events (must be registered inside the Blocks context)
    submit_btn.click(
        predict,
        inputs=review_box,
        outputs=[html_info_out, html_wheel_out, json_out],
    )
    replace_btn.click(random_text, outputs=review_box)

demo.launch(show_api=False)