Spaces:
Running
Running
File size: 6,681 Bytes
9a5d8ec 177a8e6 2954446 9a5d8ec b6ecc6e 9a5d8ec 177a8e6 9a5d8ec d53315e bb3d05e d463b27 bb3d05e d463b27 bb3d05e 8be8777 2fe0796 bb3d05e 9cfe75e 3d475b8 f677a94 3d475b8 f677a94 3d475b8 9cfe75e 325ed03 b6ecc6e 6370ba9 7fa1794 b6ecc6e 9ea7979 9cfe75e defc447 c921016 9ea7979 f1b3d54 f35e5d2 b3f79ff 177a8e6 f1b3d54 eded7fe f1b3d54 7dec0d9 eded7fe f1b3d54 5c5c1da f1b3d54 177a8e6 d5e14ad 177a8e6 f1b3d54 f35e5d2 f1b3d54 defc447 9799c94 c921016 defc447 c921016 4c56373 f677a94 4c56373 c921016 cd07d29 79079d1 fa07bfe 2954446 79079d1 32c703c f6cd476 32c703c 3a6ecfa 32c703c 9cfe75e 2954446 06353d7 296279c 06353d7 296279c 06353d7 2954446 c50c843 7dec0d9 c50c843 fa07bfe c50c843 ad4809c c50c843 5c5c1da c50c843 79079d1 2954446 fa07bfe ac240ce 2954446 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 |
# Standardbibliotheken
import os # Umgebungsvariablen (z.B. HF_TOKEN)
import types # für Instanz-Monkeypatch (fastText .predict)
import html # HTML-Escaping für Ausgabe/Gradio
import numpy as np # Numerik (z.B. für Wahrscheinlichkeiten)
import time
import random
# Machine Learning / NLP
import torch # PyTorch (Model, Tensor, Device)
import fasttext # Sprach-ID (lid.176)
# Diese beiden werden oft nicht direkt aufgerufen, müssen aber installiert sein,
# damit Hugging Face/Tokenizer korrekt funktionieren (SentencePiece-Backends, Converter).
import sentencepiece # Required für SentencePiece-basierte Tokenizer (DeBERTa v3)
import tiktoken # Optionaler Converter; verhindert Fallback-Fehler/Warnungen
from langid.langid import LanguageIdentifier, model
# Hugging Face / Ökosystem
import spaces
from transformers import AutoTokenizer # Tokenizer-Lader (mit use_fast=False für SentencePiece)
from huggingface_hub import hf_hub_download # Dateien/Weights aus dem HF Hub laden
from safetensors.torch import load_file # Sicheres & schnelles Laden von Weights (.safetensors)
# UI / Serving
import gradio as gr # Web-UI für Demo/Spaces
# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
#load_model_and_tokenizer,
predict_flavours,
#predict_is_review,
#TARGET_COLUMNS,
#ICONS
)
from lib.wheel import build_svg_with_values
from lib.examples import EXAMPLES
### Settings #####################################################################
# Base checkpoint that both the regressor backbone and the tokenizer are built on.
MODEL_BASE = "microsoft/deberta-v3-base"
# Hub repository hosting the fine-tuned multi-head regression weights.
REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"

# Both of the following are read from the environment (configure them as Space secrets).
# HF_TOKEN is only needed if the model repo is private; it is None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")
# Name of the weights file inside REPO_ID; None when the variable is unset.
MODEL_FILE = os.environ.get("MODEL_FILE")
##################################################################################
# --- Download weights ---
# MODEL_FILE comes from the environment (Space secrets). Without it,
# hf_hub_download would be called with filename=None and fail later inside the
# library, so fail fast with an actionable message instead.
if not MODEL_FILE:
    raise RuntimeError(
        "MODEL_FILE is not set: add the name of the weights file in "
        f"'{REPO_ID}' to the Space secrets / environment."
    )
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    token=HF_TOKEN,
)

# --- Tokenizer (SentencePiece!) ---
# use_fast=False selects the slow SentencePiece-based tokenizer for DeBERTa v3.
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE,
    use_fast=False,
)

model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE,
)

state = load_file(weights_path)  # safetensors -> dict[str, Tensor]
# strict=False tolerates key mismatches between checkpoint and model. Report them
# instead of discarding the result: a mismatch means some parameters were NOT loaded
# and keep their initial values (switch to strict=True once the keys match exactly).
load_result = model_flavours.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"[warn] state_dict mismatch while loading {MODEL_FILE}: "
        f"missing={load_result.missing_keys}, "
        f"unexpected={load_result.unexpected_keys}"
    )
# load_state_dict copies the tensors into the model's own parameters, so the raw
# state dict is no longer needed; drop it so a second full copy of the weights
# does not stay resident for the lifetime of the process.
del state

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()  # eval(): inference mode (e.g. disables dropout)

### Check if lang is english #####################################################
# `model` here is langid's bundled model string (imported from langid.langid), not the
# flavour model. norm_probs=True makes classify() return a normalized probability.
ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)
def is_eng(text: str, min_chars: int = 6, threshold: float = 0.1):
    """Decide whether *text* should be treated as English.

    Args:
        text: Input text; None or empty is treated as an empty string.
        min_chars: Texts shorter than this (after stripping) are not classified
            and are accepted as English.
        threshold: Minimum langid confidence required for an "en" label.

    Returns:
        A tuple ``(is_english, probability)``. ``probability`` is 0.0 when the text
        was too short to be classified.
    """
    cleaned = text.strip() if text else ""
    if len(cleaned) >= min_chars:
        label, confidence = ID.classify(cleaned)
        return (label == "en" and confidence >= threshold), float(confidence)
    # Too short to classify reliably: give the text the benefit of the doubt.
    return True, 0.0
### Do actual prediction #########################################################
@spaces.GPU(duration=10)  # GPU time budget per call, in seconds
def predict(review: str):
    """Predict flavour intensities for a whisky review.

    Args:
        review: Free-text tasting note entered by the user; None is treated as empty.

    Returns:
        A 2-tuple ``(html_out, json_out)`` matching the two Gradio outputs:

        - empty input: a prompt message and ``{}``;
        - non-English input: an explanatory message and the language check result
          ``{"is_eng": ..., "prob": ...}``;
        - otherwise: the flavour-wheel markup from ``build_svg_with_values`` and a
          dict with per-flavour scores, the echoed review, the model file, the device
          and the inference duration in seconds (rounded to milliseconds).
    """
    review = (review or "").strip()

    # Abort if no text is given (always return two outputs).
    if not review:
        return "Please enter a review.", {}

    # Abort if the text is not recognized as English (always return two outputs).
    review_is_eng, review_lang_prob = is_eng(review)
    if not review_is_eng:
        return "Currently, only English reviews are supported.", {
            "is_eng": review_is_eng,
            "prob": review_lang_prob,
        }

    # perf_counter is monotonic and high-resolution, unlike time.time(), so the
    # measured duration cannot be skewed by system clock adjustments.
    t_start_flavours = time.perf_counter()
    prediction_flavours = predict_flavours(review, model_flavours, tokenizer_flavours, device)
    t_end_flavours = time.perf_counter()

    # NOTE(review): assumes predict_flavours returns its scores in the axis order
    # that build_svg_with_values expects — confirm against lib.bert_regressor_utils.
    prediction_flavours_list = list(prediction_flavours.values())
    html_out = build_svg_with_values(prediction_flavours_list)

    json_out = {
        "results": [
            {"name": name, "score": score}
            for name, score in prediction_flavours.items()
        ],
        "review": review,
        "model": MODEL_FILE,
        "device": device,
        "duration": round(t_end_flavours - t_start_flavours, 3),
    }
    return html_out, json_out
##################################################################################
def random_text():
    """Return one example review picked at random from EXAMPLES."""
    example = random.choice(EXAMPLES)
    return example
def get_device_info():
    """Return a short, human-readable label describing the compute device in use."""
    if not torch.cuda.is_available():
        return "◎ Runs on CPU (May be slower)"
    gpu_name = torch.cuda.get_device_name(0)
    return f"◉ Runs on GPU: {gpu_name}"
### Create Form interface with Gradio Framework ##################################
# Layout: title/description on top, then one row with the review input and buttons
# on the left and the flavour wheel plus raw JSON result on the right.
with gr.Blocks() as demo:
    gr.HTML("<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>")
    gr.HTML("""
<h3>Automatically turn Whisky Tasting Notes into Flavour Wheels.</h3>
<p>This model is a fine-tuned version of <a href='https://huggingface.co/microsoft/deberta-v3-base'>microsoft/deberta-v3-base</a> designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — <strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, <strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, <strong>winey</strong> and <strong>off-notes</strong> — on a continuous scale from 0 (none) to 4 (extreme).</p>
""")
    gr.HTML(f"<span style='color: var(--block-title-text-color)'>{get_device_info()}</span>")
    with gr.Row():  # everything side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Review (English only)",
                lines=8,
                placeholder="Enter whisky review",
                # Pass the function itself (not its result): Gradio calls a callable
                # value on every page load, so each visitor gets a fresh random example
                # instead of the single one chosen once at startup.
                value=random_text,
            )
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)
        with gr.Column(scale=1):  # right side: output
            html_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")
    # Events
    submit_btn.click(predict, inputs=review_box, outputs=[html_out, json_out])
    replace_btn.click(random_text, outputs=review_box)

demo.launch()