File size: 945 Bytes
44ee98d
 
 
 
f4107c5
44ee98d
 
84152d3
44ee98d
 
 
 
 
 
 
 
 
 
5a69625
44ee98d
5a69625
44ee98d
5a69625
 
 
 
 
 
 
 
f4107c5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from transformers import  AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import pipeline, Pipeline
from joblib import load


def load_model(path2chkpt: str, path2mapping: str):
    """Build a text-classification pipeline and load its label mapping.

    The classifier weights are read from ``path2chkpt``; the tokenizer is
    always the stock ``distilbert/distilbert-base-cased`` one, regardless
    of what the checkpoint directory contains.

    Args:
        path2chkpt: Checkpoint location passed to
            ``AutoModelForSequenceClassification.from_pretrained``.
        path2mapping: Path of the joblib-serialized mapping object.

    Returns:
        A ``(pipeline, mapping)`` tuple: the ``text-classification``
        pipeline and whatever object ``joblib.load`` deserialized
        (presumably a label -> readable-name dict; verify against the
        file that was dumped).
    """
    classifier = AutoModelForSequenceClassification.from_pretrained(path2chkpt)
    base_tokenizer = AutoTokenizer.from_pretrained(
        "distilbert/distilbert-base-cased"
    )
    text_pipe = pipeline(
        "text-classification",
        model=classifier,
        tokenizer=base_tokenizer,
    )

    # NOTE(review): joblib.load is pickle-based, so path2mapping should only
    # ever point at a trusted file.
    label_names = load(path2mapping)
    return text_pipe, label_names


def top_95_labels(
    pipe: Pipeline,
    class2name: dict[str, str],
    title: str,
    abstract: str,
    *,
    threshold: float = 0.95,
    top_k: int = 20,
) -> list[str]:
    """Return class names whose cumulative score first reaches ``threshold``.

    The title and abstract are joined with a single ``"."`` and passed to
    ``pipe``. Predictions are consumed in the order the pipeline returns
    them (assumed to be descending score -- confirm for the pipeline in
    use) and collected until their summed score reaches ``threshold`` or
    the predictions run out, whichever comes first.

    Args:
        pipe: Callable classifier; invoked as ``pipe(text, top_k=top_k)``
            and expected to return a sequence of dicts with ``"label"``
            and ``"score"`` keys.
        class2name: Mapping from the pipeline's label to the name to
            return.
        title: Document title.
        abstract: Document abstract.
        threshold: Cumulative score at which collection stops.
        top_k: Maximum number of predictions requested from ``pipe``.

    Returns:
        The mapped names of the collected labels, in pipeline order. May
        cover less than ``threshold`` of the score mass when the returned
        predictions do not add up to it.

    Raises:
        KeyError: If a predicted label is missing from ``class2name``.
    """
    inputs = ".".join([title, abstract])
    result = pipe(inputs, top_k=top_k)

    proba = 0.0
    labels = []
    # Iterate over the predictions themselves instead of indexing with a
    # free-running counter: when the available scores never reach the
    # threshold, the loop ends cleanly rather than raising IndexError.
    for prediction in result:
        if proba >= threshold:
            break
        proba += prediction["score"]
        labels.append(prediction["label"])
    return [class2name[label] for label in labels]