Upload app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,166 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import AlbertTokenizer, AlbertForSequenceClassification, AlbertModel
|
| 4 |
+
import numpy as np
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import os
|
| 7 |
+
from torch.nn.functional import softmax
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
|
| 10 |
+
# Paths
|
| 11 |
+
LEVEL_DIRS = {
|
| 12 |
+
1: 'level1',
|
| 13 |
+
2: 'level2',
|
| 14 |
+
3: 'level3',
|
| 15 |
+
4: 'level4',
|
| 16 |
+
5: 'level5',
|
| 17 |
+
6: 'level6',
|
| 18 |
+
7: 'level7'
|
| 19 |
+
}
|
| 20 |
+
MAPPING_FILE = 'mapping.csv'
|
| 21 |
+
MODEL_NAME = 'albert/albert-base-v2' # Define the base model name
|
| 22 |
+
|
| 23 |
+
# Load mapping
|
| 24 |
+
mapping_df = pd.read_csv(MAPPING_FILE)
|
| 25 |
+
|
| 26 |
+
def get_label_text(level, predicted_id):
|
| 27 |
+
level_map = {0: 0, 1: 1, 2: 2, 3: 3, 4: 4, 5: 5, 6: 6}
|
| 28 |
+
level_num = level_map.get(level)
|
| 29 |
+
if level_num is not None:
|
| 30 |
+
row = mapping_df[(mapping_df['level'] == level_num) & (mapping_df['id'] == predicted_id)]
|
| 31 |
+
return row['text'].iloc[0] if not row.empty else "Description not found"
|
| 32 |
+
return "Invalid Level"
|
| 33 |
+
|
| 34 |
+
def predict_level(level, text, parent_prediction_id=None, checkpoint_path=None):
|
| 35 |
+
level_dir = LEVEL_DIRS[level]
|
| 36 |
+
tokenizer = AlbertTokenizer.from_pretrained(checkpoint_path)
|
| 37 |
+
label_map = np.load(os.path.join(level_dir, 'label_map.npy'), allow_pickle=True).item()
|
| 38 |
+
num_labels = len(label_map)
|
| 39 |
+
|
| 40 |
+
if level == 1:
|
| 41 |
+
model = AlbertForSequenceClassification.from_pretrained(checkpoint_path)
|
| 42 |
+
else:
|
| 43 |
+
parent_level_dir = LEVEL_DIRS[level - 1]
|
| 44 |
+
parent_label_map = np.load(os.path.join(parent_level_dir, 'label_map.npy'), allow_pickle=True).item()
|
| 45 |
+
num_parent_labels = len(parent_label_map)
|
| 46 |
+
|
| 47 |
+
class TaxonomyClassifier(nn.Module):
|
| 48 |
+
def __init__(self, base_model_name, num_parent_labels, num_labels):
|
| 49 |
+
super().__init__()
|
| 50 |
+
self.albert = AlbertModel.from_pretrained(base_model_name)
|
| 51 |
+
self.dropout = nn.Dropout(0.1)
|
| 52 |
+
self.classifier = nn.Linear(self.albert.config.hidden_size + num_parent_labels, num_labels)
|
| 53 |
+
|
| 54 |
+
def forward(self, input_ids, attention_mask, parent_ids):
|
| 55 |
+
outputs = self.albert(input_ids, attention_mask=attention_mask)
|
| 56 |
+
pooled_output = outputs.pooler_output
|
| 57 |
+
pooled_output = self.dropout(pooled_output)
|
| 58 |
+
combined_features = torch.cat((pooled_output, parent_ids), dim=1)
|
| 59 |
+
logits = self.classifier(combined_features)
|
| 60 |
+
return logits
|
| 61 |
+
|
| 62 |
+
model = TaxonomyClassifier(MODEL_NAME, num_parent_labels, num_labels)
|
| 63 |
+
model.load_state_dict(torch.load(os.path.join(checkpoint_path, 'model.safetensors'), map_location=torch.device('cpu')))
|
| 64 |
+
|
| 65 |
+
model.eval()
|
| 66 |
+
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
|
| 67 |
+
|
| 68 |
+
if level > 1:
|
| 69 |
+
parent_label_map_current = np.load(os.path.join(LEVEL_DIRS[level - 1], 'label_map.npy'), allow_pickle=True).item()
|
| 70 |
+
num_parent_labels_current = len(parent_label_map_current)
|
| 71 |
+
parent_one_hot = torch.zeros(num_parent_labels_current)
|
| 72 |
+
if parent_prediction_id != 0:
|
| 73 |
+
parent_index = parent_label_map_current.get(parent_prediction_id)
|
| 74 |
+
if parent_index is not None:
|
| 75 |
+
parent_one_hot[parent_index] = 1.0
|
| 76 |
+
with torch.no_grad():
|
| 77 |
+
outputs = model(inputs.input_ids, attention_mask=inputs.attention_mask, parent_ids=parent_one_hot.unsqueeze(0))
|
| 78 |
+
else:
|
| 79 |
+
with torch.no_grad():
|
| 80 |
+
outputs = model(**inputs)
|
| 81 |
+
|
| 82 |
+
probabilities = softmax(outputs.logits if level == 1 else outputs, dim=-1)[0]
|
| 83 |
+
top3_prob, top3_indices = torch.topk(probabilities, 3)
|
| 84 |
+
index_to_label = {v: k for k, v in label_map.items()}
|
| 85 |
+
results = []
|
| 86 |
+
for prob, index in zip(top3_prob, top3_indices):
|
| 87 |
+
predicted_label_id = index_to_label[index.item()]
|
| 88 |
+
results.append((predicted_label_id, prob.item()))
|
| 89 |
+
return results
|
| 90 |
+
|
| 91 |
+
st.title("Taxonomy Model Inference")
|
| 92 |
+
|
| 93 |
+
input_text = st.text_area("Enter text to classify", "Experience the magic of music with the Clavinova CLP-800 series. This versatile range of digital pianos is designed to delight everyone, from budding musicians to seasoned pianists. Each model combines state-of-the-art technology with the realistic touch and tone of world-renowned grand pianos, enhanced by GrandTouch keyboard action and Virtual Resonance Modeling. With seamless Bluetooth® connectivity, built-in lessons, and elegant design, the CLP-800 series offers the perfect blend of tradition and innovation. Elevate your musical journey with the warmth and sophistication of the Yamaha Clavinova, our finest series of digital pianos.")
|
| 94 |
+
|
| 95 |
+
softmax_threshold = st.slider("Softmax Threshold", min_value=0.0, max_value=1.0, value=0.5, step=0.05)
|
| 96 |
+
|
| 97 |
+
# Checkpoint Selection
|
| 98 |
+
available_levels = []
|
| 99 |
+
level_checkpoints = {}
|
| 100 |
+
for level in LEVEL_DIRS:
|
| 101 |
+
level_dir = LEVEL_DIRS[level]
|
| 102 |
+
if os.path.exists(level_dir):
|
| 103 |
+
options = [d for d in os.listdir(level_dir) if os.path.isdir(os.path.join(level_dir, d))]
|
| 104 |
+
options = [d for d in options if 'step' in d or d == 'model']
|
| 105 |
+
options.sort(key=lambda x: (('step' not in x), int(x.split('step')[-1]) if 'step' in x else -1))
|
| 106 |
+
level_checkpoints[level] = [os.path.join(level_dir, opt) for opt in options]
|
| 107 |
+
if level_checkpoints[level]:
|
| 108 |
+
available_levels.append(level)
|
| 109 |
+
else:
|
| 110 |
+
level_checkpoints[level] = []
|
| 111 |
+
|
| 112 |
+
selected_checkpoints = {}
|
| 113 |
+
for level in available_levels:
|
| 114 |
+
selected_checkpoints[level] = st.selectbox(f"Select Level {level} Checkpoint", options=level_checkpoints[level])
|
| 115 |
+
|
| 116 |
+
if st.button("Run Inference"):
|
| 117 |
+
if input_text:
|
| 118 |
+
all_level_results = {}
|
| 119 |
+
current_prediction_id = None
|
| 120 |
+
last_level = 0
|
| 121 |
+
|
| 122 |
+
for level in sorted(available_levels):
|
| 123 |
+
if selected_checkpoints[level]:
|
| 124 |
+
checkpoint_path = selected_checkpoints[level]
|
| 125 |
+
if level == 1:
|
| 126 |
+
level_results = predict_level(level, input_text, checkpoint_path=checkpoint_path)
|
| 127 |
+
else:
|
| 128 |
+
if current_prediction_id == 0:
|
| 129 |
+
st.info(f"Taxonomy terminated at Level {last_level} with ID 0.")
|
| 130 |
+
break
|
| 131 |
+
level_results = predict_level(level, input_text, parent_prediction_id=current_prediction_id, checkpoint_path=checkpoint_path)
|
| 132 |
+
|
| 133 |
+
if level_results[0][1] < softmax_threshold:
|
| 134 |
+
st.info(f"Inference stopped at Level {level} due to softmax probability ({level_results[0][1]:.3f}) being below the threshold.")
|
| 135 |
+
break
|
| 136 |
+
|
| 137 |
+
all_level_results[level] = level_results
|
| 138 |
+
current_prediction_id = level_results[0][0]
|
| 139 |
+
last_level = level
|
| 140 |
+
else:
|
| 141 |
+
st.warning(f"Skipping Level {level} as no checkpoint is selected.")
|
| 142 |
+
break
|
| 143 |
+
|
| 144 |
+
data = []
|
| 145 |
+
for level in sorted(all_level_results.keys()):
|
| 146 |
+
results = all_level_results[level]
|
| 147 |
+
data.append({
|
| 148 |
+
'level': level,
|
| 149 |
+
'text': get_label_text(level - 1, results[0][0]),
|
| 150 |
+
'softmax': f"{results[0][1]:.3f}",
|
| 151 |
+
'runner_up_1_id': results[1][0],
|
| 152 |
+
'runner_up_1_text': get_label_text(level - 1, results[1][0]),
|
| 153 |
+
'runner_up_1_softmax': f"{results[1][1]:.3f}",
|
| 154 |
+
'runner_up_2_id': results[2][0],
|
| 155 |
+
'runner_up_2_text': get_label_text(level - 1, results[2][0]),
|
| 156 |
+
'runner_up_2_softmax': f"{results[2][1]:.3f}",
|
| 157 |
+
})
|
| 158 |
+
|
| 159 |
+
if data:
|
| 160 |
+
df = pd.DataFrame(data)
|
| 161 |
+
st.dataframe(df)
|
| 162 |
+
else:
|
| 163 |
+
st.info("No predictions made or inference stopped.")
|
| 164 |
+
|
| 165 |
+
else:
|
| 166 |
+
st.warning("Please enter text for classification.")
|