# Standard library
import os           # environment variables (e.g. HF_TOKEN)
import types        # for instance-level monkeypatching (fastText .predict)
import html         # HTML escaping for output/Gradio
import numpy as np  # numerics (e.g. for probabilities)
import time
import random
# Machine learning / NLP
import torch     # PyTorch (model, tensors, device)
import fasttext  # language ID (lid.176)
# These two are often not called directly, but must be installed
# so that Hugging Face tokenizers work correctly (SentencePiece backends, converters).
import sentencepiece  # required for SentencePiece-based tokenizers (DeBERTa v3)
import tiktoken       # optional converter; avoids fallback errors/warnings
from langid.langid import LanguageIdentifier, model
# Hugging Face / ecosystem
import spaces
from transformers import AutoTokenizer       # tokenizer loader (use_fast=False for SentencePiece)
from huggingface_hub import hf_hub_download  # download files/weights from the HF Hub
from safetensors.torch import load_file      # safe & fast loading of weights (.safetensors)
# UI / serving
import gradio as gr  # web UI for the demo/Space
# Project-specific modules
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
    #load_model_and_tokenizer,
    predict_flavours,
    #predict_is_review,
    #TARGET_COLUMNS,
    #ICONS
)
from lib.wheel import build_svg_with_values
from lib.examples import EXAMPLES
### Settings #####################################################################

MODEL_BASE = "microsoft/deberta-v3-base"
REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"

# (optional) in case the model repo is private:
HF_TOKEN = os.getenv("HF_TOKEN")      # set this in the Space secrets
MODEL_FILE = os.getenv("MODEL_FILE")  # set this in the Space secrets

##################################################################################
# --- Download Weights ---
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    token=HF_TOKEN
)
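
# Note: hf_hub_download caches the file locally and returns the local file path,
# so subsequent Space restarts reuse the cached weights instead of downloading again.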
# --- Tokenizer (SentencePiece!) ---
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE,
    use_fast=False
)

model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE
)

state = load_file(weights_path)                          # safetensors -> dict[str, Tensor]
_ = model_flavours.load_state_dict(state, strict=False)  # strict=True if the keys match exactly

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()
### Check if language is English #################################################

ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)

def is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    t = (text or "").strip()
    if len(t) < min_chars:
        return True, 0.0
    lang, prob = ID.classify(t)  # prob ∈ [0, 1]
    return (lang == "en" and prob >= threshold), float(prob)
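
# Illustrative calls (probabilities are approximate, not exact values):
#   is_eng("Rich sherry, dried fruit and a whiff of peat smoke")  # -> (True, ~0.99)
#   is_eng("Ein rauchiger Whisky mit viel Sherry")                # -> (False, prob of the detected language)
#   is_eng("ok")                                                  # -> (True, 0.0): below min_chars, accepted by default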
### Do actual prediction #########################################################

# seconds of GPU time per call
def predict(review: str):
    review = (review or "").strip()

    # Abort if no text is given
    if not review:
        # always return two outputs
        return "Please enter a review.", {}

    # Check the language of the text
    review_is_eng, review_lang_prob = is_eng(review)

    # Abort if the text is not English
    if not review_is_eng:
        # always return two outputs
        return "Currently, only English reviews are supported.", {
            "is_eng": review_is_eng,
            "prob": review_lang_prob
        }

    prediction_flavours = {}
    prediction_flavours_list = [0, 0, 0, 0, 0, 0, 0, 0]

    # Do the actual prediction if the text is English and a whisky note
    t_start_flavours = time.time()
    prediction_flavours = predict_flavours(review, model_flavours, tokenizer_flavours, device)
    prediction_flavours_list = list(prediction_flavours.values())
    t_end_flavours = time.time()

    #html_out = f"<b>{html.escape(review)}</b>"
    html_out = build_svg_with_values(prediction_flavours_list)

    json_out = {
        "results": [
            {
                #'icon': ICONS[name],
                "name": name,
                "score": score,
                #'level': get_level(score)
            }
            for name, score in prediction_flavours.items()
        ],
        "review": review,
        "model": MODEL_FILE,
        "device": device,
        "duration": round((t_end_flavours - t_start_flavours), 3),
    }

    return html_out, json_out
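
# For illustration only: json_out has roughly this shape (scores are made-up
# example values, and the category names depend on what predict_flavours returns):
# {
#   "results": [{"name": "grainy", "score": 0.8}, ..., {"name": "off-notes", "score": 0.1}],
#   "review": "<input text>",
#   "model": "<MODEL_FILE>",
#   "device": "cpu",
#   "duration": 0.42
# }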
##################################################################################

def random_text():
    return random.choice(EXAMPLES)

def get_device_info():
    if torch.cuda.is_available():
        return f"◉ Runs on GPU: {torch.cuda.get_device_name(0)}"
    else:
        return "◎ Runs on CPU (May be slower)"
### Create form interface with the Gradio framework ##############################

with gr.Blocks() as demo:

    gr.HTML("<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>")
    gr.HTML("""
        <h3>Automatically turn Whisky Tasting Notes into Flavour Wheels.</h3>
        <p>This model is a fine-tuned version of <a href='https://huggingface.co/microsoft/deberta-v3-base'>microsoft/deberta-v3-base</a> designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — <strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, <strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, <strong>winey</strong> and <strong>off-notes</strong> — on a continuous scale from 0 (none) to 4 (extreme).</p>
    """)
    gr.HTML(f"<span style='color: var(--block-title-text-color)'>{get_device_info()}</span>")

    with gr.Row():  # everything side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Review (English only)",
                lines=8,
                placeholder="Enter whisky review",
                value=random_text(),
            )
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)
        with gr.Column(scale=1):  # right side: output
            html_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Events
    submit_btn.click(predict, inputs=review_box, outputs=[html_out, json_out])
    replace_btn.click(random_text, outputs=review_box)

demo.launch()