# Standard library
import os  # environment variables (e.g. HF_TOKEN, MODEL_FILE)
import types  # instance monkeypatching (fastText .predict)
import html  # HTML escaping for output/Gradio
import time
import random
import warnings

# Numerics
import numpy as np  # numerics (e.g. for probabilities)

# Machine learning / NLP
import torch  # PyTorch (model, tensors, device)
import fasttext  # language ID (lid.176)

# These two are often not called directly, but they must be installed so that
# Hugging Face tokenizers work correctly (SentencePiece backends, converters).
import sentencepiece  # required for SentencePiece-based tokenizers (DeBERTa v3)
import tiktoken  # optional converter; avoids fallback errors/warnings
from langid.langid import LanguageIdentifier, model

# Hugging Face ecosystem
import spaces
from transformers import AutoTokenizer  # tokenizer loader (use_fast=False for SentencePiece)
from huggingface_hub import hf_hub_download  # download files/weights from the HF Hub
from safetensors.torch import load_file  # safe and fast loading of .safetensors weights

# UI / serving
import gradio as gr  # web UI for the demo/Space

# Project-specific modules
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import predict_flavours
from lib.wheel import build_svg_with_values
from lib.examples import EXAMPLES


### Settings #####################################################################

MODEL_BASE = "microsoft/deberta-v3-base"
REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"

# Optional: only needed if the model repo is private (store it in the Space secrets).
HF_TOKEN = os.getenv("HF_TOKEN")
# Name of the .safetensors weights file inside REPO_ID (store it in the Space secrets).
MODEL_FILE = os.getenv("MODEL_FILE")

if not MODEL_FILE:
    # Fail fast with an actionable message instead of passing None on to
    # hf_hub_download and getting an opaque error from deep inside the library.
    raise RuntimeError(
        "MODEL_FILE is not set; add it to the Space secrets / environment "
        f"(name of the .safetensors weights file in {REPO_ID})."
    )


### Model setup ##################################################################

# Download the fine-tuned weights from the Hub.
weights_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILE, token=HF_TOKEN)

# Slow tokenizer on purpose: DeBERTa v3 uses a SentencePiece vocabulary.
tokenizer_flavours = AutoTokenizer.from_pretrained(MODEL_BASE, use_fast=False)

model_flavours = BertMultiHeadRegressor(pretrained_model_name=MODEL_BASE)
state = load_file(weights_path)  # safetensors -> dict[str, Tensor]

# strict=False tolerates key mismatches. However, any parameter missing from the
# checkpoint keeps the value the freshly constructed model had. Surface the
# mismatch instead of silently discarding it (use strict=True once keys match exactly).
incompatible = model_flavours.load_state_dict(state, strict=False)
missing_keys = list(incompatible.missing_keys)
unexpected_keys = list(incompatible.unexpected_keys)
if missing_keys or unexpected_keys:
    warnings.warn(
        f"load_state_dict: {len(missing_keys)} missing and "
        f"{len(unexpected_keys)} unexpected keys "
        f"(first missing: {missing_keys[:5]}, first unexpected: {unexpected_keys[:5]})"
    )
# The tensors have been copied into the model; drop the duplicate CPU copy.
del state

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()


### Language check ###############################################################

ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)


def is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    """Return whether *text* is classified as English, plus langid's probability.

    Args:
        text: Input text; ``None`` is treated as an empty string.
        min_chars: Texts shorter than this (after stripping) are not classified
            at all and are accepted as English with probability 0.0.
        threshold: Minimum normalized langid probability required for an
            "en" verdict.

    Returns:
        A tuple ``(is_english, probability)``, where ``probability`` is the
        normalized langid score (in [0, 1]) of the detected language.
    """
    t = (text or "").strip()
    if len(t) < min_chars:
        return True, 0.0
    lang, prob = ID.classify(t)  # prob in [0, 1] because norm_probs=True
    return (lang == "en" and prob >= threshold), float(prob)


### Prediction ###################################################################

@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review: str):
    """Score a whisky review and build the outputs for the Gradio UI.

    Args:
        review: Free-text tasting note; ``None`` is treated as empty.

    Returns:
        A pair ``(html, payload)`` matching the two Gradio outputs.
        - Normal input: the flavour-wheel markup from ``build_svg_with_values``
          and a dict with per-flavour scores, the review, model file, device and
          inference duration in seconds.
        - Empty input: a message string and an empty dict.
        - Text not classified as English: a message string and
          ``{"is_eng", "prob"}`` language diagnostics.
    """
    review = (review or "").strip()

    # Gradio expects exactly two outputs on every path.
    if not review:
        return "Please enter a review.", {}

    review_is_eng, review_lang_prob = is_eng(review)
    if not review_is_eng:
        return "Currently, only English reviews are supported.", {
            "is_eng": review_is_eng,
            "prob": review_lang_prob,
        }

    # perf_counter is monotonic, so the measured duration is unaffected by
    # wall-clock adjustments (unlike time.time()).
    t_start = time.perf_counter()
    scores = predict_flavours(review, model_flavours, tokenizer_flavours, device)
    t_end = time.perf_counter()

    # NOTE(review): build_svg_with_values receives a plain list, so this assumes
    # the dict order from predict_flavours matches the wheel's axis order — confirm.
    wheel_svg = build_svg_with_values(list(scores.values()))

    payload = {
        "results": [
            {"name": name, "score": score} for name, score in scores.items()
        ],
        "review": review,
        "model": MODEL_FILE,
        "device": device,
        "duration": round(t_end - t_start, 3),
    }
    return wheel_svg, payload


### UI helpers ###################################################################

def random_text():
    """Return a randomly chosen example review from EXAMPLES."""
    return random.choice(EXAMPLES)


def get_device_info():
    """Return a one-line status string describing where inference runs."""
    if torch.cuda.is_available():
        return f"◉ Runs on GPU: {torch.cuda.get_device_name(0)}"
    return "◎ Runs on CPU (May be slower)"


### Gradio interface #############################################################

with gr.Blocks() as demo:
    # NOTE(review): any HTML markup around this description appears to have been
    # lost from the pasted source; re-add tags here if needed.
    gr.HTML("""
This model is a fine-tuned version of microsoft/deberta-v3-base designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — grainy, grassy, fragrant, fruity, peated, woody, winey and off-notes — on a continuous scale from 0 (none) to 4 (extreme).
""")
    gr.HTML(f"{get_device_info()}")

    with gr.Row():  # input and output side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Review (English only)",
                lines=8,
                placeholder="Enter whisky review",
                value=random_text(),
            )
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):  # right side: output
            html_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Events
    submit_btn.click(predict, inputs=review_box, outputs=[html_out, json_out])
    replace_btn.click(random_text, outputs=review_box)

demo.launch()