ziem-io committed on
Commit
3d475b8
·
1 Parent(s): 9cfe75e

New: Load prediction model

Browse files
Files changed (1) hide show
  1. app.py +33 -0
app.py CHANGED
@@ -1,8 +1,10 @@
 
1
  import gradio as gr
2
  import fasttext
3
  import html
4
  import numpy as np
5
  import types
 
6
  from huggingface_hub import hf_hub_download
7
  from safetensors.torch import load_file
8
  from transformers import AutoTokenizer
@@ -19,6 +21,37 @@ from lib.bert_regressor_utils import (
19
 
20
  ### Settings ####################################################################
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  ##################################################################################
23
 
24
  # offizielles Mirror-Repo mit lid.176.*
 
1
+ import os
2
  import gradio as gr
3
  import fasttext
4
  import html
5
  import numpy as np
6
  import types
7
+ import torch
8
  from huggingface_hub import hf_hub_download
9
  from safetensors.torch import load_file
10
  from transformers import AutoTokenizer
 
21
 
22
  ### Settings ####################################################################
23
 
24
+ MODEL_BASE = "microsoft/deberta-v3-base"
25
+ REPO_ID = "ziem-io/deberta_flavour_regressor_multi_head"
26
+ FILENAME = "deberta_flavour_regressor_multi_head_20250914_1020.safetensors"
27
+
28
+ # (optional) falls das Model-Repo privat ist:
29
+ HF_TOKEN = os.getenv("HF_TOKEN") # in Space-Secrets hinterlegen
30
+
31
+ ##################################################################################
32
+
33
+ # --- Download Weights ---
34
+ weights_path = hf_hub_download(
35
+ repo_id=REPO_ID,
36
+ filename=FILENAME,
37
+ token=HF_TOKEN
38
+ )
39
+
40
+ # --- Tokenizer (SentencePiece!) ---
41
+ tokenizer_flavours = AutoTokenizer.from_pretrained(
42
+ MODEL_BASE,
43
+ use_fast=False
44
+ )
45
+
46
+ model_flavours = BertMultiHeadRegressor(
47
+ pretrained_model_name=MODEL_BASE
48
+ )
49
+ state = load_file(weights_path) # safetensors -> dict[str, Tensor]
50
+ _ = model_flavours.load_state_dict(state, strict=False) # strict=True wenn Keys exakt passen
51
+
52
+ device = "cuda" if torch.cuda.is_available() else "cpu"
53
+ model_flavours.to(device).eval()
54
+
55
  ##################################################################################
56
 
57
  # offizielles Mirror-Repo mit lid.176.*