Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 7,104 Bytes
9a5d8ec 177a8e6 2954446 9a5d8ec 177a8e6 9a5d8ec d53315e bb3d05e 8be8777 2fe0796 bb3d05e 9cfe75e 3d475b8 9cfe75e a8b9f74 9ea7979 a8b9f74 9ea7979 a8b9f74 9cfe75e a8b9f74 9ea7979 a8b9f74 577c02d 9cfe75e 577c02d a8b9f74 f1b3d54 9ea7979 9cfe75e defc447 c921016 9ea7979 f1b3d54 f35e5d2 177a8e6 f1b3d54 177a8e6 d5e14ad 177a8e6 f1b3d54 f35e5d2 f1b3d54 defc447 9799c94 c921016 defc447 c921016 defc447 c921016 cd07d29 79079d1 2954446 79079d1 32c703c 9cfe75e 2954446 c50c843 32c703c c50c843 2fe0796 c50c843 79079d1 2954446 ac240ce 2954446 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
# Standardbibliotheken
import os # Umgebungsvariablen (z.B. HF_TOKEN)
import types # für Instanz-Monkeypatch (fastText .predict)
import html # HTML-Escaping für Ausgabe/Gradio
import numpy as np # Numerik (z.B. für Wahrscheinlichkeiten)
import time
import random
# Machine Learning / NLP
import torch # PyTorch (Model, Tensor, Device)
import fasttext # Sprach-ID (lid.176)
# Diese beiden werden oft nicht direkt aufgerufen, müssen aber installiert sein,
# damit Hugging Face/Tokenizer korrekt funktionieren (SentencePiece-Backends, Converter).
import sentencepiece # Required für SentencePiece-basierte Tokenizer (DeBERTa v3)
import tiktoken # Optionaler Converter; verhindert Fallback-Fehler/Warnungen
# Hugging Face / Ökosystem
import spaces
from transformers import AutoTokenizer # Tokenizer-Lader (mit use_fast=False für SentencePiece)
from huggingface_hub import hf_hub_download # Dateien/Weights aus dem HF Hub laden
from safetensors.torch import load_file # Sicheres & schnelles Laden von Weights (.safetensors)
# UI / Serving
import gradio as gr # Web-UI für Demo/Spaces
# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
load_model_and_tokenizer,
predict_flavours,
#predict_is_review,
TARGET_COLUMNS,
ICONS
)
from lib.wheel import build_svg_with_values
from lib.examples import EXAMPLES
### Settings #####################################################################
MODEL_BASE = "microsoft/deberta-v3-base"
REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"
FILENAME = "deberta_flavour_regressor_multi_head_20250914_1020.safetensors"

# Optional: only needed when the model repo is private (store it in the Space secrets).
HF_TOKEN = os.getenv("HF_TOKEN")
##################################################################################

# --- Download weights ---
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=FILENAME,
    token=HF_TOKEN,
)

# --- Tokenizer (SentencePiece, hence use_fast=False) ---
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE,
    use_fast=False,
)

# --- Model ---
model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE,
)
state = load_file(weights_path)  # safetensors -> dict[str, Tensor]

# strict=False tolerates key mismatches between checkpoint and model. Report any
# mismatch instead of discarding it, so a wrong checkpoint cannot go unnoticed
# and leave parts of the model at their initial values.
load_result = model_flavours.load_state_dict(state, strict=False)
missing_keys = list(getattr(load_result, "missing_keys", ()))
unexpected_keys = list(getattr(load_result, "unexpected_keys", ()))
if missing_keys or unexpected_keys:
    print(
        f"[warn] state dict mismatch for {FILENAME}: "
        f"missing={missing_keys} unexpected={unexpected_keys}"
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()

##################################################################################
# Hub mirror repo hosting the fastText lid.176.* language-ID models
lid_path = hf_hub_download(
    repo_id="julien-c/fasttext-language-id",
    filename="lid.176.ftz",
)
lid_model = fasttext.load_model(lid_path)
# robustes predict mit NumPy-2-Fix + Fallback, falls fastText nur Labels liefert
def _predict_np2_compat(self, text, k=1, threshold=0.0, on_unicode_error='strict'):
out = self.f.predict(text, k, threshold, on_unicode_error)
# Fälle:
# 1) (labels, probs)
# 2) labels-only (einige Builds/SWIG-Versionen)
if isinstance(out, tuple) and len(out) == 2:
labels, probs = out
else:
labels = out
# sinnvolle Defaults, falls keine Wahrscheinlichkeiten vorliegen
if isinstance(labels, (list, tuple)):
probs = [1.0] * len(labels)
else:
labels = [labels]
probs = [1.0]
return labels, np.asarray(probs) # np.asarray statt np.array(copy=False)
# Patch this instance only: bind the compat wrapper as lid_model.predict
lid_model.predict = types.MethodType(_predict_np2_compat, lid_model)
### Check if lang is english #####################################################
def is_eng(review: str):
    """Check whether *review* is written in English.

    Uses the module-level fastText language-ID model ``lid_model``.

    Returns:
        A 3-tuple ``(is_english, label, probability)``. ``label`` is the raw
        fastText label (e.g. ``"__label__en"``), or ``""`` when the model
        returned no prediction. The tuple always has three items because the
        caller unpacks exactly three values.
    """
    # fastText classifies one line at a time; collapse newlines/whitespace so the
    # whole review is considered rather than only its first line.
    text = " ".join(review.split())

    lang_labels, lang_probs = lid_model.predict(text)
    print(lang_labels, lang_probs)

    if len(lang_labels) == 0:  # no label returned
        return False, "", 0.0

    top = lang_labels[0]
    if isinstance(top, str):
        # Plain label string with a separate probability array.
        lang_label = top
        lang_prob = float(lang_probs[0])
    else:
        # Pair-shaped entry with the label at index 1; index 0 is expected to be
        # its probability (fastText yields (probability, label) pairs). In that
        # case lang_probs only holds the wrapper's placeholder value.
        lang_label = top[1]
        try:
            lang_prob = float(top[0])
        except (TypeError, ValueError):
            lang_prob = float(lang_probs[0])

    return lang_label == "__label__en", lang_label, lang_prob
### Do actual prediction #########################################################
@spaces.GPU(duration=10)  # seconds of GPU time per call
def predict(review: str):
    """Run language check and flavour prediction for one review.

    Args:
        review: Free-text review entered by the user; ``None`` is treated as
            an empty string.

    Returns:
        Always two values ``(html, json)`` for the two Gradio outputs:
        * empty input -> a hint message and ``{}``;
        * non-English input -> a hint message and the detected language
          with its probability;
        * otherwise -> the result of ``build_svg_with_values`` and a dict
          with the review, model file name, device, inference duration in
          seconds and one ``{"name", "score"}`` entry per predicted flavour.
    """
    review = (review or "").strip()

    # Abort if no text is given
    if not review:
        return "<i>Please enter a review.</i>", {}

    # Detect the language of the text
    review_is_eng, review_lang_label, review_lang_prob = is_eng(review)

    # Abort if the text is not English
    if not review_is_eng:
        return "<i>English texts only</i>", {"is_en": {
            "lang": review_lang_label.replace("__label__", ""),
            "prob": review_lang_prob
        }}

    # perf_counter is monotonic, so the measured duration cannot be distorted
    # by wall-clock adjustments.
    t_start_flavours = time.perf_counter()
    prediction_flavours = predict_flavours(review, model_flavours, tokenizer_flavours, device)
    t_end_flavours = time.perf_counter()

    html_out = build_svg_with_values(list(prediction_flavours.values()))

    json_out = {
        "review": review,
        "model": FILENAME,
        "device": device,
        "duration": round((t_end_flavours - t_start_flavours), 3),
        "results": [
            {"name": name, "score": score}
            for name, score in prediction_flavours.items()
        ],
    }
    return html_out, json_out
##################################################################################
def replace_text():
    """Return one randomly chosen example review from ``EXAMPLES``."""
    picked_example = random.choice(EXAMPLES)
    return picked_example
def get_device_info():
    """Return a short status line naming the compute device in use."""
    # Guard clause: no CUDA device means the app runs on the CPU.
    if not torch.cuda.is_available():
        return "Runs on CPU"
    gpu_name = torch.cuda.get_device_name(0)
    return f"Runs on GPU: {gpu_name}"
### Create Form interface with Gradio Framework ##################################
with gr.Blocks() as demo:
    gr.Markdown("## Submit Whisky Review for Classification")

    with gr.Row():  # input and output side by side
        with gr.Column(scale=1):  # left side: input
            device_label = gr.Markdown(get_device_info())
            review_box = gr.Textbox(
                label="Whisky Review",
                lines=8,
                placeholder="Enter whisky review",
                value=random.choice(EXAMPLES),  # pre-fill with a random example
            )
            with gr.Row():
                replace_btn = gr.Button("Replace Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)

        with gr.Column(scale=1):  # right side: output
            html_out = gr.HTML(label="Table")
            json_out = gr.JSON(label="JSON")

    # Events (registered inside the Blocks context)
    submit_btn.click(predict, inputs=review_box, outputs=[html_out, json_out])
    replace_btn.click(replace_text, outputs=review_box)

demo.launch()