Create app.py
Browse files
    	
        app.py
    ADDED
    
    | @@ -0,0 +1,51 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import streamlit as st
         | 
| 2 | 
            +
            from transformers import pipeline
         | 
| 3 | 
            +
            import streamlit as st
         | 
| 4 | 
            +
            import json
         | 
| 5 | 
            +
            import torch
         | 
| 6 | 
            +
            from torch.nn import functional as F
         | 
| 7 | 
            +
            from transformers import AutoTokenizer, AutoModelForSequenceClassification
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            @st.cache_resource
            def load_dicts():
                """Load the label <-> index lookup tables from disk.

                Reads ``label2ind.json`` and ``ind2label.json`` relative to the
                current working directory. The function is wrapped in
                ``st.cache_resource``.

                Returns:
                    tuple: ``(label2ind, ind2label)`` exactly as parsed by
                    ``json.load``.

                Raises:
                    FileNotFoundError: if either JSON file is missing.
                    json.JSONDecodeError: if a file does not contain valid JSON.
                """
                # JSON text is UTF-8; pin the encoding so label names with
                # non-ASCII characters do not depend on the platform default.
                with open("label2ind.json", "r", encoding="utf-8") as file:
                    label2ind = json.load(file)
                with open("ind2label.json", "r", encoding="utf-8") as file:
                    ind2label = json.load(file)
                return label2ind, ind2label
         | 
| 16 | 
            +
             | 
| 17 | 
            +
            @st.cache_resource
            def load_model():
                """Create the tokenizer and the sequence-classification model.

                The tokenizer is loaded from the PubMedBERT identifier and the
                model from the ``my_model/checkpoint-23000`` checkpoint. The
                number of output labels is taken from the module-level
                ``label2ind``, so that name must be bound before this is called.

                Returns:
                    tuple: ``(tokenizer, model)``.
                """
                tokenizer_name = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
                checkpoint_path = "my_model/checkpoint-23000"

                tok = AutoTokenizer.from_pretrained(tokenizer_name)
                classifier = AutoModelForSequenceClassification.from_pretrained(
                    checkpoint_path,
                    problem_type="single_label_classification",
                    num_labels=len(label2ind),
                )
                return tok, classifier
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            # Shared resources. The lookup tables are loaded first because
            # load_model() reads the module-level label2ind.
            label2ind, ind2label = load_dicts()
            tokenizer, model = load_model()

            # Input widgets, pre-filled with sample values.
            title = st.text_input(label="Title", value="Math")
            abstract = st.text_input(label="Abstract", value="Random variable")
         | 
| 32 | 
            +
             | 
| 33 | 
            +
            def get_logits(title, abstract, max_length=512):
                """Run the classifier on one title/abstract pair.

                Args:
                    title: Title text.
                    abstract: Abstract text.
                    max_length: Maximum number of tokens fed to the model; longer
                        input is truncated. The default of 512 assumes a
                        BERT-base style encoder -- TODO confirm against the
                        checkpoint's config.

                Returns:
                    The ``logits`` entry of the model output for this single input.
                """
                # The two fields are joined with the "###" separator.
                text = title + "###" + abstract
                # Without truncation an over-long abstract would exceed the
                # encoder's position limit and fail inside the model.
                encoded = tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=max_length,
                )
                # Inference only: skip building the autograd graph.
                with torch.no_grad():
                    logits = model(encoded['input_ids'])['logits']
                return logits
         | 
| 37 | 
            +
             | 
| 38 | 
            +
            def get_ans(logits, threshold=0.95):
                """Write the most probable labels until they cover ``threshold``.

                Labels are written with ``st.write`` in descending order of
                probability. Output stops once the probabilities already written
                sum to at least ``threshold``, or when every label has been shown.

                Args:
                    logits: Raw model scores; if the tensor has more than one
                        dimension, only the first row is used.
                    threshold: Cumulative probability mass to cover (default 0.95).

                Returns:
                    None. The results are emitted through ``st.write``.
                """
                # Normalise over the class axis explicitly; the implicit-dim form
                # of softmax is deprecated.
                probs = F.softmax(logits, dim=-1)
                if probs.dim() > 1:
                    probs = probs[0]
                order = torch.argsort(probs, descending=True)

                cum_sum = 0.0
                # Iterating over the sorted indices bounds the loop by the number
                # of classes, so rounding error can never run past the end.
                for idx in order.tolist():
                    if cum_sum >= threshold:
                        break
                    prob = probs[idx].item()
                    cum_sum += prob
                    # A dict parsed from JSON has string keys; a list is indexed
                    # by position.
                    if isinstance(ind2label, dict):
                        label = ind2label[str(idx)]
                    else:
                        label = ind2label[idx]
                    st.write(f"label {label} with probability {prob * 100:.2f}%")
         | 
| 47 | 
            +
             | 
| 48 | 
            +
            # Classify as soon as at least one of the two fields is non-empty.
            if title or abstract:
                logits = get_logits(title, abstract)
                get_ans(logits)
         |