Update app.py
Browse files
app.py
CHANGED
|
@@ -9,6 +9,7 @@ from sklearn.preprocessing import MultiLabelBinarizer
|
|
| 9 |
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 10 |
|
| 11 |
# Load the trained model and tokenizer
|
|
|
|
| 12 |
@st.cache_resource
|
| 13 |
def load_model():
|
| 14 |
model = AutoModelForSequenceClassification.from_pretrained(
|
|
@@ -16,13 +17,13 @@ def load_model():
|
|
| 16 |
num_labels=8, # Adjust based on your label count
|
| 17 |
problem_type="multi_label_classification"
|
| 18 |
)
|
| 19 |
-
model
|
|
|
|
| 20 |
model.eval()
|
| 21 |
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract")
|
| 22 |
-
model = model.to(device) # Move the model to the correct device
|
| 23 |
-
|
| 24 |
return model, tokenizer
|
| 25 |
|
|
|
|
| 26 |
@st.cache_resource
|
| 27 |
def load_mlb():
|
| 28 |
# Define the classes based on your label set
|
|
@@ -62,7 +63,9 @@ if st.button('Predict'):
|
|
| 62 |
inputs = tokenizer(clinical_note, truncation=True, padding="max_length", max_length=512, return_tensors='pt')
|
| 63 |
|
| 64 |
# Move inputs to the GPU if available
|
| 65 |
-
inputs = {key: val.to(device) for key, val in inputs.items()}
|
|
|
|
|
|
|
| 66 |
|
| 67 |
# Model inference
|
| 68 |
with torch.no_grad():
|
|
@@ -91,30 +94,4 @@ if st.button('Predict'):
|
|
| 91 |
# st.write("Please enter clinical notes for prediction.")
|
| 92 |
|
| 93 |
|
| 94 |
-
|
| 95 |
-
# if st.button('Predict'):
|
| 96 |
-
# if clinical_note:
|
| 97 |
-
# # Tokenize the input clinical note
|
| 98 |
-
# inputs = tokenizer(clinical_note, truncation=True, padding="max_length", max_length=512, return_tensors='pt')
|
| 99 |
-
|
| 100 |
-
# # Move inputs to the GPU if available
|
| 101 |
-
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 102 |
-
# inputs = {key: val.to(device) for key, val in inputs.items()}
|
| 103 |
-
|
| 104 |
-
# # Model inference
|
| 105 |
-
# with torch.no_grad():
|
| 106 |
-
# outputs = model(**inputs)
|
| 107 |
-
# logits = outputs.logits
|
| 108 |
-
|
| 109 |
-
# # Apply sigmoid and threshold the output (0.5 for multi-label classification)
|
| 110 |
-
# pred_labels = (torch.sigmoid(logits) > 0.5).cpu().numpy()
|
| 111 |
-
|
| 112 |
-
# # Get the predicted ICD and CPT codes
|
| 113 |
-
# predicted_codes = mlb.inverse_transform(pred_labels)
|
| 114 |
-
|
| 115 |
-
# # Show the results
|
| 116 |
-
# st.write("Predicted ICD and CPT Codes:")
|
| 117 |
-
# st.write(predicted_codes)
|
| 118 |
-
|
| 119 |
-
# else:
|
| 120 |
-
# st.write("Please enter clinical notes for prediction.")
|
|
|
|
| 9 |
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 10 |
|
| 11 |
# Load the trained model and tokenizer
|
| 12 |
+
|
| 13 |
@st.cache_resource
|
| 14 |
def load_model():
|
| 15 |
model = AutoModelForSequenceClassification.from_pretrained(
|
|
|
|
| 17 |
num_labels=8, # Adjust based on your label count
|
| 18 |
problem_type="multi_label_classification"
|
| 19 |
)
|
| 20 |
+
# Map the model to the appropriate device
|
| 21 |
+
model.load_state_dict(torch.load('best_model_v2.pth', map_location=torch.device('cpu')))
|
| 22 |
model.eval()
|
| 23 |
tokenizer = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract")
|
|
|
|
|
|
|
| 24 |
return model, tokenizer
|
| 25 |
|
| 26 |
+
|
| 27 |
@st.cache_resource
|
| 28 |
def load_mlb():
|
| 29 |
# Define the classes based on your label set
|
|
|
|
| 63 |
inputs = tokenizer(clinical_note, truncation=True, padding="max_length", max_length=512, return_tensors='pt')
|
| 64 |
|
| 65 |
# Move inputs to the GPU if available
|
| 66 |
+
# inputs = {key: val.to(device) for key, val in inputs.items()}
|
| 67 |
+
inputs = {key: val.to(torch.device('cpu')) for key, val in inputs.items()}
|
| 68 |
+
|
| 69 |
|
| 70 |
# Model inference
|
| 71 |
with torch.no_grad():
|
|
|
|
| 94 |
# st.write("Please enter clinical notes for prediction.")
|
| 95 |
|
| 96 |
|
| 97 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|