File size: 6,681 Bytes
9a5d8ec
 
 
 
 
177a8e6
2954446
9a5d8ec
 
 
 
 
 
 
 
b6ecc6e
9a5d8ec
 
177a8e6
9a5d8ec
 
 
 
 
 
d53315e
bb3d05e
 
 
d463b27
bb3d05e
 
d463b27
 
bb3d05e
8be8777
2fe0796
bb3d05e
9cfe75e
 
3d475b8
 
 
 
 
f677a94
3d475b8
 
 
 
 
 
f677a94
3d475b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cfe75e
325ed03
b6ecc6e
6370ba9
7fa1794
b6ecc6e
 
 
 
 
9ea7979
9cfe75e
defc447
c921016
9ea7979
 
 
f1b3d54
f35e5d2
 
b3f79ff
177a8e6
f1b3d54
eded7fe
f1b3d54
 
 
 
7dec0d9
eded7fe
f1b3d54
5c5c1da
f1b3d54
177a8e6
d5e14ad
177a8e6
f1b3d54
 
 
 
 
f35e5d2
f1b3d54
defc447
9799c94
c921016
 
 
defc447
 
 
c921016
 
 
 
4c56373
f677a94
4c56373
 
c921016
 
 
cd07d29
79079d1
 
fa07bfe
2954446
79079d1
32c703c
 
f6cd476
32c703c
3a6ecfa
32c703c
9cfe75e
2954446
06353d7
296279c
06353d7
296279c
06353d7
 
2954446
c50c843
 
 
7dec0d9
c50c843
 
fa07bfe
c50c843
 
ad4809c
c50c843
 
 
5c5c1da
c50c843
79079d1
 
2954446
fa07bfe
ac240ce
2954446
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
# Standardbibliotheken
import os        # Umgebungsvariablen (z.B. HF_TOKEN)
import types     # für Instanz-Monkeypatch (fastText .predict)
import html      # HTML-Escaping für Ausgabe/Gradio
import numpy as np  # Numerik (z.B. für Wahrscheinlichkeiten)
import time
import random

# Machine Learning / NLP
import torch                      # PyTorch (Model, Tensor, Device)
import fasttext                   # Sprach-ID (lid.176)
# Diese beiden werden oft nicht direkt aufgerufen, müssen aber installiert sein,
# damit Hugging Face/Tokenizer korrekt funktionieren (SentencePiece-Backends, Converter).
import sentencepiece              # Required für SentencePiece-basierte Tokenizer (DeBERTa v3)
import tiktoken                   # Optionaler Converter; verhindert Fallback-Fehler/Warnungen
from langid.langid import LanguageIdentifier, model

# Hugging Face / Ökosystem
import spaces
from transformers import AutoTokenizer   # Tokenizer-Lader (mit use_fast=False für SentencePiece)
from huggingface_hub import hf_hub_download  # Dateien/Weights aus dem HF Hub laden
from safetensors.torch import load_file      # Sicheres & schnelles Laden von Weights (.safetensors)

# UI / Serving
import gradio as gr               # Web-UI für Demo/Spaces

# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
    #load_model_and_tokenizer,
    predict_flavours,
    #predict_is_review,
    #TARGET_COLUMNS,
    #ICONS
)
from lib.wheel import build_svg_with_values
from lib.examples import EXAMPLES

### Settings #####################################################################

MODEL_BASE = "microsoft/deberta-v3-base"
REPO_ID    = "ziem-io/deberta_flavour_regressor_multi_head"

# Both values are read from the environment (configured as Space secrets).
# HF_TOKEN is only needed if the model repo is private; MODEL_FILE names the
# weights file inside REPO_ID and is required.
HF_TOKEN = os.getenv("HF_TOKEN")
MODEL_FILE = os.getenv("MODEL_FILE")

if not MODEL_FILE:
    # Fail fast with an actionable message instead of handing filename=None
    # to hf_hub_download and getting an unrelated-looking error from it.
    raise RuntimeError(
        "Environment variable MODEL_FILE is not set; it must name the "
        f"weights file in the Hugging Face repo {REPO_ID!r}."
    )

##################################################################################

# --- Download weights ---
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=MODEL_FILE,
    token=HF_TOKEN,
)

# --- Tokenizer (SentencePiece-based, hence the slow tokenizer) ---
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE,
    use_fast=False,
)

# --- Model: base architecture + fine-tuned weights ---
model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE
)
state = load_file(weights_path)  # safetensors -> dict[str, Tensor]
# strict=False tolerates key mismatches (use strict=True once keys match exactly).
# A silent mismatch would leave the affected parameters at their constructor
# values, so surface it in the logs instead of discarding the report.
load_report = model_flavours.load_state_dict(state, strict=False)
if load_report.missing_keys or load_report.unexpected_keys:
    print(
        f"[WARN] Key mismatch while loading {MODEL_FILE}: "
        f"missing={load_report.missing_keys}, "
        f"unexpected={load_report.unexpected_keys}"
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()

### Check if lang is english #####################################################

# langid identifier; norm_probs=True makes classify() return a normalised
# probability in [0, 1] instead of a raw score.
ID = LanguageIdentifier.from_modelstring(model, norm_probs=True)

def is_eng(text: str, min_chars: int = 6, threshold: float = 0.1) -> tuple[bool, float]:
    """Decide whether ``text`` is English using the module-level langid ``ID``.

    Args:
        text: Text to check; ``None`` is treated as an empty string.
        min_chars: Texts shorter than this (after stripping) are too short to
            classify reliably and are accepted as English with probability 0.0.
        threshold: Minimum langid probability required for an "en" verdict.

    Returns:
        ``(is_english, probability)`` where ``probability`` is the confidence
        langid reported for its top language (0.0 for too-short texts).
    """
    cleaned = text.strip() if text else ""
    if len(cleaned) >= min_chars:
        detected_lang, confidence = ID.classify(cleaned)  # confidence in [0, 1]
        return (detected_lang == "en" and confidence >= threshold), float(confidence)
    # Too short to judge: give the text the benefit of the doubt.
    return True, 0.0

### Do actual prediction #########################################################
@spaces.GPU(duration=10)  # seconds of GPU time requested per call
def predict(review: str):
    """Predict flavour intensities for a whisky tasting note.

    Args:
        review: Free-text tasting note entered by the user. ``None`` is
            treated like an empty string; surrounding whitespace is stripped.

    Returns:
        A 2-tuple ``(html_out, json_out)`` matching the two Gradio outputs:

        * empty input: a prompt message and ``{}``;
        * non-English input (per ``is_eng``): a notice and a dict with the
          language-check result under ``"is_eng"`` and ``"prob"``;
        * otherwise: the flavour wheel built by ``build_svg_with_values`` and
          a dict with per-flavour ``"results"``, the ``"review"`` text, the
          ``"model"`` file, the ``"device"`` and the inference ``"duration"``
          in seconds (rounded to 3 decimals).
    """
    review = (review or "").strip()

    # Abort if no text was given (always return two outputs).
    if not review:
        return "Please enter a review.", {}

    # Only English reviews are supported; report the detection result otherwise.
    review_is_eng, review_lang_prob = is_eng(review)
    if not review_is_eng:
        return "Currently, only English reviews are supported.", {
            "is_eng": review_is_eng,
            "prob": review_lang_prob,
        }

    # perf_counter() is monotonic and high-resolution, unlike time.time(),
    # which makes it the appropriate clock for measuring elapsed time.
    t_start_flavours = time.perf_counter()
    prediction_flavours = predict_flavours(review, model_flavours, tokenizer_flavours, device)
    t_end_flavours = time.perf_counter()

    # The wheel renderer receives the scores positionally, in the dict's
    # iteration order (presumably the model's target-column order — confirm
    # against lib.bert_regressor_utils.predict_flavours).
    prediction_flavours_list = list(prediction_flavours.values())
    html_out = build_svg_with_values(prediction_flavours_list)

    json_out = {
        "results": [
            {"name": name, "score": score}
            for name, score in prediction_flavours.items()
        ],
        "review": review,
        "model": MODEL_FILE,
        "device": device,
        "duration": round(t_end_flavours - t_start_flavours, 3),
    }

    return html_out, json_out

##################################################################################

def random_text():
    """Return one randomly selected example tasting note from ``EXAMPLES``.

    Used to pre-fill the review box and by the "Load Example" button.
    """
    example = random.choice(EXAMPLES)
    return example

def get_device_info():
    """Return a short UI label saying whether inference runs on GPU or CPU.

    On CUDA systems the label includes the name of device 0.
    """
    if not torch.cuda.is_available():
        return "◎ Runs on CPU (May be slower)"
    gpu_name = torch.cuda.get_device_name(0)
    return f"◉ Runs on GPU: {gpu_name}"

### Create Form interface with Gradio Framework ##################################
# Builds the web UI: a header, a device-info line, then two columns — the
# review input with its buttons on the left, the flavour wheel (HTML) and the
# raw JSON result on the right. The order of statements inside the `with`
# blocks determines the on-page layout.
with gr.Blocks() as demo:
    gr.HTML("<h2>Multi-Axis Regression of Whisky Tasting Notes</h2>")
    gr.HTML("""
    <h3>Automatically turn Whisky Tasting Notes into Flavour Wheels.</h3>
    <p>This model is a fine-tuned version of <a href='https://huggingface.co/microsoft/deberta-v3-base'>microsoft/deberta-v3-base</a> designed to analyze English whisky tasting notes. It predicts the intensity of eight sensory categories — <strong>grainy</strong>, <strong>grassy</strong>, <strong>fragrant</strong>, <strong>fruity</strong>, <strong>peated</strong>, <strong>woody</strong>, <strong>winey</strong> and <strong>off-notes</strong> — on a continuous scale from 0 (none) to 4 (extreme).</p>
""")
    # The device label is computed once, when the UI is built.
    gr.HTML(f"<span style='color: var(--block-title-text-color)'>{get_device_info()}</span>")

    with gr.Row():  # place both columns side by side
        with gr.Column(scale=1):  # left side: input
            review_box = gr.Textbox(
                label="Whisky Review (English only)",
                lines=8,
                placeholder="Enter whisky review",
                value=random_text(),  # pre-filled with one random example at build time
            )
            with gr.Row():
                replace_btn = gr.Button("Load Example", variant="secondary", scale=1)
                submit_btn  = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):  # right side: output
            html_out = gr.HTML(label="Flavour Wheel")
            json_out = gr.JSON(label="JSON")

    # Events: "Submit" runs predict() and fills both outputs; the outputs list
    # order must match predict()'s (html, json) return tuple.
    # "Load Example" replaces the input text with a random example.
    submit_btn.click(predict, inputs=review_box, outputs=[html_out, json_out])
    replace_btn.click(random_text, outputs=review_box)

# Start the Gradio app.
demo.launch()