Spaces:
Running on CPU Upgrade

File size: 7,104 Bytes
9a5d8ec
 
 
 
 
177a8e6
2954446
9a5d8ec
 
 
 
 
 
 
 
 
 
177a8e6
9a5d8ec
 
 
 
 
 
d53315e
bb3d05e
 
 
 
 
 
 
 
 
8be8777
2fe0796
bb3d05e
9cfe75e
 
3d475b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cfe75e
a8b9f74
 
 
 
 
 
 
 
 
9ea7979
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8b9f74
9ea7979
 
a8b9f74
9cfe75e
a8b9f74
9ea7979
a8b9f74
577c02d
9cfe75e
 
577c02d
 
 
a8b9f74
 
 
f1b3d54
9ea7979
9cfe75e
defc447
c921016
9ea7979
 
 
f1b3d54
f35e5d2
 
 
177a8e6
f1b3d54
 
 
 
 
 
 
 
 
 
 
177a8e6
d5e14ad
177a8e6
f1b3d54
 
 
 
 
f35e5d2
f1b3d54
defc447
9799c94
c921016
 
 
defc447
c921016
 
 
defc447
 
 
c921016
 
 
 
 
 
 
cd07d29
79079d1
 
 
2954446
79079d1
32c703c
 
 
 
 
 
9cfe75e
2954446
 
 
c50c843
 
32c703c
c50c843
 
 
 
2fe0796
c50c843
 
 
 
 
 
 
 
79079d1
 
2954446
 
ac240ce
2954446
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
# Standardbibliotheken
import os        # Umgebungsvariablen (z.B. HF_TOKEN)
import types     # für Instanz-Monkeypatch (fastText .predict)
import html      # HTML-Escaping für Ausgabe/Gradio
import numpy as np  # Numerik (z.B. für Wahrscheinlichkeiten)
import time
import random

# Machine Learning / NLP
import torch                      # PyTorch (Model, Tensor, Device)
import fasttext                   # Sprach-ID (lid.176)
# Diese beiden werden oft nicht direkt aufgerufen, müssen aber installiert sein,
# damit Hugging Face/Tokenizer korrekt funktionieren (SentencePiece-Backends, Converter).
import sentencepiece              # Required für SentencePiece-basierte Tokenizer (DeBERTa v3)
import tiktoken                   # Optionaler Converter; verhindert Fallback-Fehler/Warnungen

# Hugging Face / Ökosystem
import spaces
from transformers import AutoTokenizer   # Tokenizer-Lader (mit use_fast=False für SentencePiece)
from huggingface_hub import hf_hub_download  # Dateien/Weights aus dem HF Hub laden
from safetensors.torch import load_file      # Sicheres & schnelles Laden von Weights (.safetensors)

# UI / Serving
import gradio as gr               # Web-UI für Demo/Spaces

# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
    load_model_and_tokenizer,
    predict_flavours,
    #predict_is_review,
    TARGET_COLUMNS,
    ICONS
)
from lib.wheel import build_svg_with_values
from lib.examples import EXAMPLES

### Settings #####################################################################

MODEL_BASE = "microsoft/deberta-v3-base"
REPO_ID    = "ziem-io/deberta_flavour_regressor_multi_head"
FILENAME   = "deberta_flavour_regressor_multi_head_20250914_1020.safetensors"

# (optional) only needed if the model repo is private:
HF_TOKEN = os.getenv("HF_TOKEN")  # store it in the Space secrets

##################################################################################

# --- Download weights ---
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=FILENAME,
    token=HF_TOKEN,
)

# --- Tokenizer (SentencePiece-based, hence the slow tokenizer) ---
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE,
    use_fast=False,
)

# --- Model: build the architecture, then load the fine-tuned weights ---
model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE
)
state = load_file(weights_path)  # safetensors -> dict[str, Tensor]

# strict=False tolerates key mismatches (use strict=True once keys match exactly).
# Report any mismatch instead of silently ignoring it: parameters listed as
# missing keep whatever values they had before loading.
load_result = model_flavours.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"WARNING: state_dict mismatch while loading {FILENAME}: "
        f"missing_keys={load_result.missing_keys} "
        f"unexpected_keys={load_result.unexpected_keys}"
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()

##################################################################################

# fastText language-identification model (lid.176, .ftz variant), fetched
# from a mirror repo on the Hugging Face Hub.
lid_path = hf_hub_download(repo_id="julien-c/fasttext-language-id", filename="lid.176.ftz")
lid_model = fasttext.load_model(lid_path)

# robustes predict mit NumPy-2-Fix + Fallback, falls fastText nur Labels liefert
def _predict_np2_compat(self, text, k=1, threshold=0.0, on_unicode_error='strict'):
    """NumPy-2-safe replacement for fastText's ``predict`` on a single string.

    Calls the low-level binding ``self.f.predict`` directly and normalises its
    output to the same shape the stock wrapper returns, building the
    probability array with ``np.asarray`` (the stock ``np.array(copy=False)``
    is rejected by NumPy 2).

    Args:
        text: A single string to classify.
        k: Number of labels to return.
        threshold: Minimum probability for a label to be returned.
        on_unicode_error: Passed through to fastText unchanged.

    Returns:
        ``(labels, probs)``: a tuple of label strings (e.g. ``"__label__en"``)
        and a float NumPy array of the matching probabilities.
    """
    # fastText classifies one line at a time and stops at the first newline,
    # so flatten multi-line input; the trailing "\n" mirrors the stock wrapper.
    line = text.replace("\n", " ") + "\n"
    out = self.f.predict(line, k, threshold, on_unicode_error)

    if isinstance(out, list) and all(isinstance(p, tuple) and len(p) == 2 for p in out):
        # Standard binding output: a list of (probability, label) pairs.
        probs = [float(prob) for prob, _ in out]
        labels = [label for _, label in out]
    elif isinstance(out, tuple) and len(out) == 2:
        # Some builds already return (labels, probs).
        labels, probs = out
    elif isinstance(out, (list, tuple)):
        # Labels only: no probabilities available, fall back to 1.0.
        labels = list(out)
        probs = [1.0] * len(labels)
    else:
        labels = [out]
        probs = [1.0]
    return tuple(labels), np.asarray(probs, dtype=float)

# Patch only this model instance; fastText's class itself stays untouched.
lid_model.predict = types.MethodType(_predict_np2_compat, lid_model)

### Check if lang is english #####################################################

def is_eng(review: str):
    """Detect whether ``review`` is English using the fastText lid.176 model.

    Returns:
        ``(is_english, label, probability)`` where ``label`` is the raw fastText
        label of the top prediction (e.g. ``"__label__en"``) and
        ``probability`` its score. If no label is returned the result is
        ``(False, "", 0.0)``, keeping the 3-tuple shape callers unpack.
    """
    lang_labels, lang_probs = lid_model.predict(review)

    if len(lang_labels) == 0:  # no label returned
        return False, "", 0.0

    lang_label = lang_labels[0]
    return lang_label == "__label__en", lang_label, float(lang_probs[0])

### Do actual prediction #########################################################
@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review: str):
    """Score a whisky review and build both Gradio outputs.

    Args:
        review: Free-text review from the UI; ``None`` is treated as empty.

    Returns:
        An ``(html, json)`` pair, always two values to match the two outputs.
        On success ``html`` is the flavour-wheel SVG from
        ``build_svg_with_values`` and ``json`` holds the review, model file,
        device, inference duration in seconds and the per-flavour scores.
        For empty input ``json`` is ``{}``; for non-English input it reports
        the detected language and its probability.
    """
    review = (review or "").strip()

    # Abort if no text is given.
    if not review:
        return "<i>Please enter a review.</i>", {}

    # Abort if the text is not English.
    review_is_eng, review_lang_label, review_lang_prob = is_eng(review)
    if not review_is_eng:
        return "<i>English texts only</i>", {"is_en": {
            "lang": review_lang_label.replace("__label__", ""),
            "prob": review_lang_prob,
        }}

    # perf_counter is monotonic, so wall-clock adjustments cannot skew the duration.
    t_start_flavours = time.perf_counter()
    prediction_flavours = predict_flavours(review, model_flavours, tokenizer_flavours, device)
    t_end_flavours = time.perf_counter()

    html_out = build_svg_with_values(list(prediction_flavours.values()))

    json_out = {
        "review": review,
        "model": FILENAME,
        "device": device,
        "duration": round(t_end_flavours - t_start_flavours, 3),
        "results": [
            {"name": name, "score": score}
            for name, score in prediction_flavours.items()
        ],
    }

    return html_out, json_out

##################################################################################

def replace_text():
    """Return a randomly chosen example review to fill the input box."""
    example = random.choice(EXAMPLES)
    return example

def get_device_info():
    """Return a short, human-readable note on the hardware the app runs on."""
    if not torch.cuda.is_available():
        return "Runs on CPU"
    gpu_name = torch.cuda.get_device_name(0)
    return f"Runs on GPU: {gpu_name}"

### Create Form interface with Gradio Framework ##################################
with gr.Blocks() as demo:
    gr.Markdown("## Submit Whisky Review for Classification")

    # Two columns side by side: input on the left, results on the right.
    with gr.Row():
        with gr.Column(scale=1):
            # Hardware note, computed once when the UI is built.
            device_label = gr.Markdown(get_device_info())
            review_box = gr.Textbox(
                label="Whisky Review",
                lines=8,
                placeholder="Enter whisky review",
                value=replace_text(),  # pre-fill with a random example review
            )
            with gr.Row():
                replace_btn = gr.Button("Replace Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):
            html_out = gr.HTML(label="Table")
            json_out = gr.JSON(label="JSON")

    # Wire the buttons to their handlers.
    submit_btn.click(fn=predict, inputs=review_box, outputs=[html_out, json_out])
    replace_btn.click(fn=replace_text, outputs=review_box)

demo.launch()