whisky-wheel / app.py
ziem-io's picture
Fix: Abort if text is not English
f1b3d54
raw
history blame
7.1 kB
# Standardbibliotheken
import os # Umgebungsvariablen (z.B. HF_TOKEN)
import types # für Instanz-Monkeypatch (fastText .predict)
import html # HTML-Escaping für Ausgabe/Gradio
import numpy as np # Numerik (z.B. für Wahrscheinlichkeiten)
import time
import random
# Machine Learning / NLP
import torch # PyTorch (Model, Tensor, Device)
import fasttext # Sprach-ID (lid.176)
# Diese beiden werden oft nicht direkt aufgerufen, müssen aber installiert sein,
# damit Hugging Face/Tokenizer korrekt funktionieren (SentencePiece-Backends, Converter).
import sentencepiece # Required für SentencePiece-basierte Tokenizer (DeBERTa v3)
import tiktoken # Optionaler Converter; verhindert Fallback-Fehler/Warnungen
# Hugging Face / Ökosystem
import spaces
from transformers import AutoTokenizer # Tokenizer-Lader (mit use_fast=False für SentencePiece)
from huggingface_hub import hf_hub_download # Dateien/Weights aus dem HF Hub laden
from safetensors.torch import load_file # Sicheres & schnelles Laden von Weights (.safetensors)
# UI / Serving
import gradio as gr # Web-UI für Demo/Spaces
# Projektspezifische Module
from lib.bert_regressor import BertMultiHeadRegressor
from lib.bert_regressor_utils import (
load_model_and_tokenizer,
predict_flavours,
#predict_is_review,
TARGET_COLUMNS,
ICONS
)
from lib.wheel import build_svg_with_values
from lib.examples import EXAMPLES
### Settings #####################################################################
# Base checkpoint that both the tokenizer and the regressor backbone are built from.
MODEL_BASE = "microsoft/deberta-v3-base"

# Hub repository and file containing the fine-tuned multi-head regressor weights.
REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"
FILENAME = "deberta_flavour_regressor_multi_head_20250914_1020.safetensors"

# Optional access token, only needed if the model repo is private
# (store it as a Space secret). None when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN")
##################################################################################
# --- Download weights (token is only required for a private repo) ---
weights_path = hf_hub_download(
    repo_id=REPO_ID,
    filename=FILENAME,
    token=HF_TOKEN,
)

# --- Tokenizer (SentencePiece!) ---
# use_fast=False selects the slow, SentencePiece-based tokenizer for DeBERTa v3.
tokenizer_flavours = AutoTokenizer.from_pretrained(
    MODEL_BASE,
    use_fast=False,
)

model_flavours = BertMultiHeadRegressor(
    pretrained_model_name=MODEL_BASE,
)

state = load_file(weights_path)  # safetensors -> dict[str, Tensor]

# strict=False tolerates key mismatches between checkpoint and model. If the
# mismatch were ignored silently, affected layers would keep their initial
# (untrained) weights without any sign. Report mismatches so they show up in
# the logs. Use strict=True once the keys are known to match exactly.
load_report = model_flavours.load_state_dict(state, strict=False)
missing_keys = list(getattr(load_report, "missing_keys", []))
unexpected_keys = list(getattr(load_report, "unexpected_keys", []))
if missing_keys or unexpected_keys:
    print(
        f"[warn] load_state_dict(strict=False): {len(missing_keys)} missing, "
        f"{len(unexpected_keys)} unexpected keys "
        f"(e.g. missing={missing_keys[:5]}, unexpected={unexpected_keys[:5]})"
    )

device = "cuda" if torch.cuda.is_available() else "cpu"
model_flavours.to(device).eval()
##################################################################################
# fastText language-identification model (lid.176, compressed .ftz variant),
# fetched from a mirror repository on the Hugging Face Hub.
lid_path = hf_hub_download(repo_id="julien-c/fasttext-language-id", filename="lid.176.ftz")
lid_model = fasttext.load_model(lid_path)
# robustes predict mit NumPy-2-Fix + Fallback, falls fastText nur Labels liefert
def _predict_np2_compat(self, text, k=1, threshold=0.0, on_unicode_error='strict'):
out = self.f.predict(text, k, threshold, on_unicode_error)
# Fälle:
# 1) (labels, probs)
# 2) labels-only (einige Builds/SWIG-Versionen)
if isinstance(out, tuple) and len(out) == 2:
labels, probs = out
else:
labels = out
# sinnvolle Defaults, falls keine Wahrscheinlichkeiten vorliegen
if isinstance(labels, (list, tuple)):
probs = [1.0] * len(labels)
else:
labels = [labels]
probs = [1.0]
return labels, np.asarray(probs) # np.asarray statt np.array(copy=False)
# Patch this instance only (the fastText class itself stays untouched): bind the
# compat function so that `lid_model.predict(...)` is routed through it.
setattr(lid_model, "predict", types.MethodType(_predict_np2_compat, lid_model))
### Check if lang is english #####################################################
def is_eng(review: str):
    """Detect whether *review* is English using the fastText language-ID model.

    Args:
        review: Text to classify.

    Returns:
        A 3-tuple ``(is_english, label, prob)``. ``label`` is the raw fastText
        label of the top prediction (e.g. ``"__label__en"``) and ``prob`` is
        its score as a float. If the model returns no label, the result is
        ``(False, "", 0.0)``, so callers can always unpack three values.
    """
    # fastText classifies line by line, so flatten line breaks. Otherwise a
    # multi-line review could be judged by its first line only.
    single_line = " ".join(review.splitlines())
    lang_labels, lang_probs = lid_model.predict(single_line)

    if len(lang_labels) == 0:  # no label returned
        return False, "", 0.0

    top = lang_labels[0]
    if isinstance(top, (tuple, list)) and len(top) == 2:
        # Raw binding output: a (probability, label) pair. Take the score from
        # the pair, because the parallel probs array may only hold placeholders.
        lang_prob, lang_label = top
    else:
        # Plain label string with the probability in the parallel array.
        lang_label, lang_prob = top, lang_probs[0]

    lang_label = str(lang_label)
    lang_prob = float(lang_prob)
    return lang_label == "__label__en", lang_label, lang_prob
### Do actual prediction #########################################################
@spaces.GPU(duration=10)  # GPU time budget per call, in seconds
def predict(review: str):
    """Classify one whisky review and return the two Gradio outputs.

    Args:
        review: Raw review text from the input box; ``None`` is treated as "".

    Returns:
        Always a 2-tuple ``(html, json)`` so both output components are filled:

        * empty input -> a hint message and ``{}``;
        * non-English input -> a hint message, plus the detected language code
          and its probability under ``"is_en"``;
        * otherwise -> the markup produced by ``build_svg_with_values`` for the
          predicted flavour scores, and a JSON summary (review, model file,
          device, inference duration in seconds, per-flavour scores).
    """
    review = (review or "").strip()

    # Abort if no text is given.
    if not review:
        return "<i>Please enter a review.</i>", {}

    # Abort if the text is not English.
    review_is_eng, review_lang_label, review_lang_prob = is_eng(review)
    if not review_is_eng:
        return "<i>English texts only</i>", {"is_en": {
            "lang": review_lang_label.replace("__label__", ""),
            "prob": review_lang_prob,
        }}

    # Flavour regression. perf_counter() is a monotonic clock, so the reported
    # duration cannot be skewed by system clock adjustments (unlike time.time()).
    t_start = time.perf_counter()
    prediction_flavours = predict_flavours(review, model_flavours, tokenizer_flavours, device)
    t_end = time.perf_counter()

    # Local names deliberately differ from the module-level UI components
    # `html_out` / `json_out` to avoid shadowing them.
    wheel_html = build_svg_with_values(list(prediction_flavours.values()))
    summary = {
        "review": review,
        "model": FILENAME,
        "device": device,
        "duration": round(t_end - t_start, 3),
        "results": [
            {"name": name, "score": score}
            for name, score in prediction_flavours.items()
        ],
    }
    return wheel_html, summary
##################################################################################
def replace_text():
    """Pick a random review from EXAMPLES to put into the input box."""
    example = random.choice(EXAMPLES)
    return example
def get_device_info():
    """Return a one-line, human-readable description of the compute device."""
    if not torch.cuda.is_available():
        return "Runs on CPU"
    gpu_name = torch.cuda.get_device_name(0)
    return f"Runs on GPU: {gpu_name}"
### Create Form interface with Gradio Framework ##################################
with gr.Blocks() as demo:
    gr.Markdown("## Submit Whisky Review for Classification")

    with gr.Row():  # input and output side by side
        # Left column: device info, review input and action buttons.
        with gr.Column(scale=1):
            device_label = gr.Markdown(get_device_info())
            review_box = gr.Textbox(
                label="Whisky Review",
                lines=8,
                placeholder="Enter whisky review",
                value=replace_text(),  # start with a random example review
            )
            with gr.Row():
                replace_btn = gr.Button("Replace Example", variant="secondary", scale=1)
                submit_btn = gr.Button("Submit", variant="primary", scale=1)

        # Right column: rendered result and raw JSON.
        with gr.Column(scale=1):
            html_out = gr.HTML(label="Table")
            json_out = gr.JSON(label="JSON")

    # Event wiring (must happen inside the Blocks context).
    submit_btn.click(fn=predict, inputs=review_box, outputs=[html_out, json_out])
    replace_btn.click(fn=replace_text, outputs=review_box)

demo.launch()