import streamlit as st
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModel
import pandas as pd
from time import time
import numpy as np

from src.A_Preprocess import clean_text
from src.E_Summarization import simple_summarize_text
from src.E_Model_utils import get_transformes_embeddings, load_model, get_embeddings
from src.E_Faiss_utils import load_faiss_index, normalize_embeddings

import warnings

warnings.filterwarnings("ignore", category=FutureWarning)


st.header('Watson Assistant VDF TOBi improvement')
st.write('The model is trained on the TOBi 🤖 intents in the Romanian language.')
'---'
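
# App flow: pick an embedding model in the sidebar, load its FAISS index from
# embeddings/<model_name>_vector_db.index, upload the intents CSV, then embed a
# user utterance and look up the nearest intent. A confirmed utterance/intent
# pair can be appended to a working copy of the intents CSV, and a whole CSV of
# unlabelled utterances can be auto-labelled with the same similarity search.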

model_name = st.sidebar.radio(
    "Select the model 👇",
    [
        "MiniLM-L12-v2",
        "llama3.2-1b",
        "all-MiniLM-L6-v2",
        "bert-base-romanian-cased-v1",
        "multilingual-e5-small",
        "e5_small_fine_tuned_model",
        "all-distilroberta-v1",
    ],
)

if model_name:
    # Map the sidebar choice to the Hugging Face checkpoint (or local path) to load.
    if model_name == "bert-base-romanian-cased-v1":
        transformer_model_name = "dumitrescustefan/bert-base-romanian-cased-v1"
    elif model_name == "llama3.2-1b":
        infloat_model_name = "AlexHung29629/sgpt-llama3.2-1b-stage1"
    elif model_name == "MiniLM-L12-v2":
        infloat_model_name = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
        model_name = "paraphrase-multilingual-MiniLM-L12-v2"
    elif model_name == "multilingual-e5-small":
        infloat_model_name = "intfloat/multilingual-e5-small"
    elif model_name == "e5_small_fine_tuned_model":
        infloat_model_name = r"output\fine-tuned-model"
        local_only = True  # the fine-tuned model is loaded from local files only
    elif model_name == "all-MiniLM-L6-v2":
        infloat_model_name = "sentence-transformers/all-MiniLM-L6-v2"
    elif model_name == "all-distilroberta-v1":
        infloat_model_name = "sentence-transformers/all-distilroberta-v1"
    else:
        st.write("Choose a model")

    st.write(f"Model **{model_name}** selected!")
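
# Note: `model_name` is also used below as the prefix of the FAISS index file
# (embeddings/<model_name>_vector_db.index), which is why the MiniLM-L12-v2
# option is remapped to its full checkpoint name above.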

if 'index_loaded' not in st.session_state:
    st.session_state.index_loaded = False
if 'index' not in st.session_state:
    st.session_state.index = None
if 'pdf_button_enabled' not in st.session_state:
    st.session_state.pdf_button_enabled = False
if 'data' not in st.session_state:
    st.session_state.data = None
if 'intent_button_clicked' not in st.session_state:
    st.session_state.intent_button_clicked = False
if 'intent' not in st.session_state:
    st.session_state.intent = None
if 'similarity' not in st.session_state:
    st.session_state.similarity = None
if 'model' not in st.session_state:
    st.session_state.model = None
if 'summar_model' not in st.session_state:
    st.session_state.summar_model = None
if 'summarized_text' not in st.session_state:
    st.session_state.summarized_text = None
if 'csv_copied' not in st.session_state:
    st.session_state.csv_copied = False
if 'csv_file_path' not in st.session_state:
    st.session_state.csv_file_path = r'C:\Users\ZZ029K826\Documents\GitHub\LLM_Intent_Recognition\data\Pager_Intents_cleaned.csv'
if 'copied_csv_file_path' not in st.session_state:
    st.session_state.copied_csv_file_path = r'C:\Users\ZZ029K826\Documents\GitHub\LLM_Intent_Recognition\data\Pager_Intents_cleaned_Copy.csv'
if 'user_text' not in st.session_state:
    st.session_state.user_text = ""
if 'user_utterance_updated' not in st.session_state:
    st.session_state.user_utterance_updated = r'C:\Users\ZZ029K826\Documents\GitHub\LLM_Intent_Recognition\data\User_utterances_updated.csv'
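

# Helpers for persisting a confirmed (utterance, intent) pair: the intents CSV is
# first copied, then the new row is appended to the copy.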
def create_csv_copy():
    # Copy the original intents CSV so the source file is never modified directly.
    df = pd.read_csv(st.session_state.csv_file_path)
    df.to_csv(st.session_state.copied_csv_file_path, index=False)
    st.session_state.csv_copied = True
    st.success("CSV file copied successfully.")


def add_user_text_and_intent():
    # Append the current utterance, its predicted intent and the similarity score
    # to the copied CSV.
    if st.session_state.csv_copied:
        df = pd.read_csv(st.session_state.copied_csv_file_path)
        new_row = {
            'utterance': st.session_state.user_text,
            'intent': st.session_state.intent,
            'similarity': st.session_state.similarity,
        }
        st.write(new_row)
        df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)

        df.to_csv(st.session_state.copied_csv_file_path, index=False)
        st.success("User text and intent added to the copied CSV file successfully.")


if st.button("Load Embeddings and Index"):
    if model_name == "e5_small_fine_tuned_model":
        model = SentenceTransformer(
            r'C:\Users\ZZ029K826\Documents\GitHub\LLM_Intent_Recognition\src\output\fine-tuned-model\e5_small_fine_tuned_model',
            local_files_only=True,
        )

        vocab_size = model.tokenizer.vocab_size
        st.write(f"**Vocab Size:** {vocab_size}")

        max_len = model.max_seq_length
        st.write(f"**Max Sequence Length:** {max_len}")

        st.session_state.model = model
    elif model_name == "bert-base-romanian-cased-v1":
        tokenizer = AutoTokenizer.from_pretrained("dumitrescustefan/bert-base-romanian-cased-v1")
        model = AutoModel.from_pretrained("dumitrescustefan/bert-base-romanian-cased-v1")
        st.session_state.model = model
    else:
        model = SentenceTransformer(infloat_model_name)
        st.session_state.model = model

    # Load the prebuilt FAISS index that matches the selected model.
    index = load_faiss_index(f"embeddings/{model_name}_vector_db.index")
    st.session_state.index = index
    st.session_state.index_loaded = True
    st.write("Embeddings and index loaded successfully!")
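
# The index file above is assumed to be built offline from embeddings of the
# intent utterances. A minimal sketch of such a build step, assuming a flat L2
# FAISS index, the same normalize_embeddings helper, and a list `utterances`
# (illustrative only, not part of this app):
#
#   import faiss
#   corpus_embeddings = normalize_embeddings(get_embeddings(model, utterances))
#   index = faiss.IndexFlatL2(corpus_embeddings.shape[1])
#   index.add(np.asarray(corpus_embeddings, dtype='float32'))
#   faiss.write_index(index, f"embeddings/{model_name}_vector_db.index")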

if st.session_state.index_loaded:
    '-------------------'
    st.write('✨ Load the csv file?')
    uploaded_file = st.file_uploader("Search the csv file", type="csv")

    if uploaded_file is not None:
        st.session_state.data = pd.read_csv(uploaded_file)
        st.write("CSV file successfully uploaded!")
        st.write(st.session_state.data)
    elif st.session_state.data is not None:
        st.write("Previously uploaded data:")
        st.write(st.session_state.data[:5])

    data = st.session_state.data
    ...

    if st.session_state.data is not None:
        '-------------------'

        user_text = st.text_area("👇 Enter user utterance text:", placeholder='User text')
        st.write(f'Text length: {len(user_text)}')

        if user_text:
            if len(user_text) > 150:
                st.write("The text is too long. Please summarize it.")
                summarize_button = st.button("Summarize")
                if summarize_button:
                    st.session_state.summarized_text = simple_summarize_text(user_text)
                    user_text = st.session_state.summarized_text
                    st.write(f"The summarized text: {user_text}")

            st.session_state.user_text = user_text
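
            # Embed the (possibly summarized) utterance: clean it, encode it with the
            # selected model, then normalize the vector so it is comparable with the
            # embeddings stored in the FAISS index.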

            start = time()

            cleaned_text = clean_text(user_text)

            model = st.session_state.model

            if model_name == "bert-base-romanian-cased-v1":
                tokenizer = AutoTokenizer.from_pretrained("dumitrescustefan/bert-base-romanian-cased-v1")
                model = AutoModel.from_pretrained("dumitrescustefan/bert-base-romanian-cased-v1")
                input_embedding = get_transformes_embeddings([cleaned_text], model, tokenizer)
            else:
                input_embedding = get_embeddings(model, [cleaned_text])

            normalized_embedding = normalize_embeddings(input_embedding)

            st.session_state.input_embedding = normalized_embedding
            st.session_state.cleaned_text = cleaned_text

            intent_button = st.button("Calculate Intent and Similarity")

            if intent_button:
                st.session_state.intent_button_clicked = True

            if st.session_state.intent_button_clicked and st.session_state.input_embedding is not None:
                start = time()

                index = st.session_state.index
                # Nearest-neighbour lookup: retrieve the single closest indexed utterance.
                D, I = index.search(st.session_state.input_embedding, 1)

                intents = st.session_state.data['intent'].tolist()
                intent = intents[I[0][0]]
                distance = D[0][0]
                similarity = 1 / (1 + distance)
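
                # The FAISS search returns a distance; 1 / (1 + distance) maps it to a
                # score in (0, 1] that is displayed below as "Confidence". This assumes
                # the index stores raw distances (e.g. L2); if it were built on inner
                # products, larger D would already mean "more similar" and this
                # conversion would not apply.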

                st.session_state.intent = intent
                st.session_state.similarity = similarity

                st.write(f"Intent: {intent}")
                st.write(f"Confidence: {similarity:.4f}")
                st.write(f"Response time: {time() - start:.4f} seconds")

                '-------------------'
                st.write(f'✨ Correct Intent: **{intent}**?')
                if st.button("Append User Text and Intent"):
                    create_csv_copy()
                    add_user_text_and_intent()

            '-------------------'

            if 'utt_csv_file' not in st.session_state:
                st.session_state.utt_csv_file = None
            if 'utt_intent_results_df' not in st.session_state:
                st.session_state.utt_intent_results_df = None
            if 'utt_csv_file_df' not in st.session_state:
                st.session_state.utt_csv_file_df = None

            def apply_similarity_search(df):
                """Predict an intent for every utterance in df via the FAISS index.

                The labelled DataFrame is also written to 'Updated_<uploaded file name>'.
                """
                if 'utterance' not in df.columns:
                    raise KeyError("The column 'utterance' does not exist in the DataFrame.")

                utterances = df['utterance'].tolist()
                embeddings = st.session_state.model.encode(utterances)
                embeddings = np.array(embeddings).astype('float32')

                intents = st.session_state.data['intent'].tolist()
                for i, embedding in enumerate(embeddings):
                    D, I = st.session_state.index.search(np.expand_dims(embedding, axis=0), 1)
                    intent = intents[I[0][0]]
                    df.at[i, 'intent'] = intent

                csv_file_name = st.session_state.utt_csv_file.name
                df.to_csv(f'Updated_{csv_file_name}', index=False)

                return df
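
            # Note: FAISS can also search all rows in one call; an equivalent batched
            # variant of the loop above (illustrative sketch, not used here) would be:
            #
            #   D, I = st.session_state.index.search(embeddings, 1)
            #   df['intent'] = [intents[idx[0]] for idx in I]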

            if st.session_state.similarity and st.session_state.utt_csv_file is None:
                st.header('✨ Auto-update the utterances list without intent')
                csv_file = st.file_uploader("Load User utterances file", type="csv")
                if csv_file is not None:
                    st.session_state.utt_csv_file = csv_file

                    df = pd.read_csv(csv_file, encoding='windows-1252')
                    st.session_state.utt_csv_file_df = df

                    display_df = df[['utterance', 'intent']]
                    st.write(display_df)
                    st.success("Utterance file loaded successfully.")
            elif st.session_state.similarity and st.session_state.utt_csv_file_df is not None:
                st.write("Utterance file already loaded.")
                df = st.session_state.utt_csv_file_df

                display_df = df[['utterance', 'intent']]
                st.write(display_df)

            if st.session_state.utt_csv_file is not None and st.button("Apply Similarity Search to CSV"):
                st.write("Performing similarity search on the uploaded CSV file...")
                df = st.session_state.utt_csv_file_df
                results_df = apply_similarity_search(df)
                st.session_state.utt_intent_results_df = results_df

            if st.session_state.utt_intent_results_df is not None:
                st.write("Results:")

                df = st.session_state.utt_intent_results_df

                display_results_df = df[['utterance', 'intent']]
                st.write(display_results_df)
                st.write(f"Response time: {time() - start:.4f} seconds")

st.stop()