Spaces:
Running on CPU Upgrade

whisky-wheel / app.py
ziem-io's picture
Update: Refactor random text
fa07bfe
raw
history blame
6.68 kB
# Standardbibliotheken
import os # Umgebungsvariablen (z.B. HF_TOKEN)
import types # für Instanz-Monkeypatch (fastText .predict)
import html # HTML-Escaping für Ausgabe/Gradio
import numpy as np # Numerik (z.B. für Wahrscheinlichkeiten)
import time
import random
# Machine Learning / NLP
import torch # PyTorch (Model, Tensor, Device)
import fasttext # Sprach-ID (lid.176)
# Diese beiden werden oft nicht direkt aufgerufen, müssen aber installiert sein,
# damit Hugging Face/Tokenizer korrekt funktionieren (SentencePiece-Backends, Converter).
import sentencepiece # Required für SentencePiece-basierte Tokenizer (DeBERTa v3)
import tiktoken # Optionaler Converter; verhindert Fallback-Fehler/Warnungen
from langid.langid import LanguageIdentifier, model
# Hugging Face / Ökosystem
import spaces
from transformers import AutoTokenizer # Tokenizer-Lader (mit use_fast=False für SentencePiece)
from huggingface_hub import hf_hub_download # Dateien/Weights aus dem HF Hub laden
from safetensors.torch import load_file # Sicheres & schnelles Laden von Weights (.safetensors)
# UI / Serving
import gradio as gr # Web-UI für Demo/Spaces
# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
#load_model_and_tokenizer,
predict_flavours,
#predict_is_review,
#TARGET_COLUMNS,
#ICONS
)
from lib.wheel import build_svg_with_values
from lib.examples import EXAMPLES
### Settings #####################################################################
# Base checkpoint that both the tokenizer and the regressor are built on.
MODEL_BASE = "microsoft/deberta-v3-base"
# Hub repository that hosts the fine-tuned regression weights.
REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"

# Both values come from the environment (configure them as Space secrets).
# HF_TOKEN is only required when the model repository is private; MODEL_FILE
# is the name of the weights file inside REPO_ID.
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_FILE = os.environ.get("MODEL_FILE")
##################################################################################
# --- Download weights -----------------------------------------------------------
# MODEL_FILE has no default. Fail early with an actionable message instead of
# letting hf_hub_download fail obscurely on filename=None.
if not MODEL_FILE:
    raise RuntimeError(
        "MODEL_FILE is not set. Define it (e.g. as a Space secret) with the name "
        f"of the weights file in the '{REPO_ID}' repository."
    )

weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    token=HF_TOKEN,
)

# --- Tokenizer (SentencePiece!) -------------------------------------------------
# use_fast=False selects the slow, SentencePiece-based tokenizer.
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE,
    use_fast=False,
)

# --- Model ----------------------------------------------------------------------
model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE,
)
state = load_file(weights_path)  # safetensors -> dict[str, Tensor]

# strict=False tolerates key mismatches (use strict=True once the keys match
# exactly). A mismatch silently leaves the affected parameters at their initial
# values, so report it instead of discarding the result.
incompatible = model_flavours.load_state_dict(state, strict=False)
if incompatible.missing_keys or incompatible.unexpected_keys:
    print(
        f"[load_state_dict] missing keys: {incompatible.missing_keys}; "
        f"unexpected keys: {incompatible.unexpected_keys}"
    )
# load_state_dict copies the tensors into the model's parameters. Drop the
# loaded dict so a second full copy of the weights is not kept alive for the
# lifetime of the app.
del state

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()
### Language check (English only) ################################################
# langid identifier created with normalized probabilities, so classify() yields
# a confidence in [0, 1] that is_eng() compares against its threshold.
ID = LanguageIdentifier.from_modelstring(
    model,
    norm_probs=True,
)
def is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    """Decide whether *text* should be treated as English.

    Texts shorter than ``min_chars`` after stripping whitespace (``None`` counts
    as empty) are accepted without running the classifier.

    Args:
        text: Input text; may be ``None``.
        min_chars: Minimum stripped length required before classifying.
        threshold: Minimum classifier probability required for the "en" label.

    Returns:
        ``(is_english, prob)``. For short texts this is ``(True, 0.0)``.
        Otherwise it is whether the top label is "en" with probability
        >= ``threshold``, together with that top-label probability as a float.
    """
    cleaned = text.strip() if text else ""
    if len(cleaned) >= min_chars:
        label, confidence = ID.classify(cleaned)  # confidence in [0, 1]
        accepted = label == "en" and confidence >= threshold
        return accepted, float(confidence)
    # Too short to classify reliably: give the text the benefit of the doubt.
    return True, 0.0
### Prediction ###################################################################
@spaces.GPU(duration=10)  # GPU time budget per call, in seconds
def predict(review: str):
    """Predict flavour intensities for a review and render them for the UI.

    Args:
        review: Free-text tasting note from the UI; ``None`` is treated like an
            empty string. Surrounding whitespace is stripped.

    Returns:
        A 2-tuple ``(html, payload)`` matching the two Gradio outputs:

        - empty input: a prompt message and ``{}``;
        - non-English input: a notice and ``{"is_eng": ..., "prob": ...}``;
        - otherwise: the markup from ``build_svg_with_values`` and a dict with
          per-flavour ``results``, the ``review``, the ``model`` file name, the
          ``device`` and the inference ``duration`` in seconds.
    """
    review = (review or "").strip()

    # Abort if no text was given (always return two outputs).
    if not review:
        return "Please enter a review.", {}

    # Abort if the text is not recognised as English (always two outputs).
    review_is_eng, review_lang_prob = is_eng(review)
    if not review_is_eng:
        return "Currently, only English reviews are supported.", {
            "is_eng": review_is_eng,
            "prob": review_lang_prob,
        }

    # perf_counter is monotonic and high-resolution. Unlike time.time(), the
    # measured duration cannot be skewed by wall-clock adjustments.
    t_start_flavours = time.perf_counter()
    prediction_flavours = predict_flavours(
        review, model_flavours, tokenizer_flavours, device
    )
    t_end_flavours = time.perf_counter()

    # The wheel receives the scores positionally, in the dict's iteration order.
    # NOTE(review): assumes predict_flavours yields keys in the axis order that
    # build_svg_with_values expects; confirm against lib.
    html_out = build_svg_with_values(list(prediction_flavours.values()))

    json_out = {
        "results": [
            {"name": name, "score": score}
            for name, score in prediction_flavours.items()
        ],
        "review": review,
        "model": MODEL_FILE,
        "device": device,
        "duration": round(t_end_flavours - t_start_flavours, 3),
    }
    return html_out, json_out
##################################################################################
def random_text():
    """Return one example review picked at random from ``EXAMPLES``."""
    example = random.choice(EXAMPLES)
    return example
def get_device_info():
    """Return a short status line describing the device used for inference.

    Reports the name of CUDA device 0 when CUDA is available, otherwise a CPU
    notice.
    """
    if not torch.cuda.is_available():
        return "◎ Runs on CPU (May be slower)"
    gpu_name = torch.cuda.get_device_name(0)
    return f"◉ Runs on GPU: {gpu_name}"
### Gradio user interface ########################################################
with gr.Blocks() as demo:
    gr.HTML("<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>")
    # Link target and label come from MODEL_BASE so the text stays in sync with
    # the checkpoint that is actually loaded.
    gr.HTML(f"""
<h3>Automatically turn Whisky Tasting Notes into Flavour Wheels.</h3>
<p>This model is a fine-tuned version of <a href='https://huggingface.co/{MODEL_BASE}'>{MODEL_BASE}</a> designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — <strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, <strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, <strong>winey</strong> and <strong>off-notes</strong> — on a continuous scale from 0 (none) to 4 (extreme).</p>
""")
    gr.HTML(
        f"<span style='color: var(--block-title-text-color)'>{get_device_info()}</span>"
    )

    with gr.Row():  # input and output side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Review (English only)",
                lines=8,
                placeholder="Enter whisky review",
                # Pass the function itself (not its result): Gradio then calls it
                # on every page load, so each visitor gets a fresh random example
                # instead of the single one drawn at startup.
                value=random_text,
            )
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)
        with gr.Column(scale=1):  # right side: output
            html_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Events (must be registered inside the Blocks context)
    submit_btn.click(predict, inputs=review_box, outputs=[html_out, json_out])
    replace_btn.click(random_text, outputs=review_box)

# Launch only when run as a script (as Spaces does with `python app.py`), so
# importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()