Spaces:
Sleeping
Sleeping
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import ast | |
import random | |
import torch | |
import time | |
from joblib import load | |
from transformers import BertTokenizer, BertModel | |
from sklearn.metrics.pairwise import cosine_similarity | |
# import faiss | |
""" | |
## Сервис умного поиска сериалов 📽️ | |
""" | |
# Читаем вектора сериалов | |
embeddings = np.loadtxt('data/embs.txt') | |
# Указываем пути к сохраненным модели и токенизатору | |
model_path = "model" | |
tokenizer_path = "tokenizer" | |
# Загружаем модель | |
loaded_model = BertModel.from_pretrained(model_path) | |
# Загружаем токенизатор | |
loaded_tokenizer = BertTokenizer.from_pretrained(tokenizer_path) | |
df = pd.read_csv('data/data.csv') | |
df['ganres'] = df['ganres'].apply(lambda x: ast.literal_eval(x)) | |
df['description'] = df['description'].astype(str) | |
st.write(f'<p style="font-family: Arial, sans-serif; font-size: 24px; ">Наш сервис насчитывает \ | |
{len(df)} лучших сериалов</p>', unsafe_allow_html=True) | |
st.image('images/ser2.png') | |
ganres_lst = sorted(['драма', 'документальный', 'биография', 'комедия', 'фэнтези', 'приключения', 'для детей', 'мультсериалы', | |
'мелодрама', 'боевик', 'детектив', 'фантастика', 'триллер', 'семейный', 'криминал', 'исторический', 'музыкальные', | |
'мистика', 'аниме', 'ужасы', 'спорт', 'скетч-шоу', 'военный', 'для взрослых', 'вестерн']) | |
st.sidebar.header('Панель инструментов :gear:') | |
choice_g = st.sidebar.multiselect("Выберите жанры", options=ganres_lst) | |
n = st.sidebar.selectbox("Количество отображаемых элементов на странице", options=[5, 10, 15, 20, 30]) | |
st.sidebar.info("📚 Для наилучшего соответствия, запрос должен быть максимально развернутым") | |
text = st.text_input('Введите описание для рекомендации') | |
# Векторизуем запрос | |
loaded_model.eval() | |
tokens = loaded_tokenizer(text, return_tensors="pt", padding=True, truncation=True) | |
start_time = time.time() | |
tokens = {key: value.to(loaded_model.device) for key, value in tokens.items()} | |
# Передача токенов в модель для получения эмбеддингов | |
with torch.no_grad(): | |
output = loaded_model(**tokens) | |
# Эмбеддинги получаются из последнего скрытого состояния | |
user_embedding = output.last_hidden_state.mean(dim=1).squeeze().cpu().detach().numpy() | |
cosine_similarities = cosine_similarity(embeddings, user_embedding.reshape(1, -1)) | |
button = st.button('Отправить запрос', type="primary") | |
if text and button: | |
if len(choice_g) == 0: | |
choice_g = ganres_lst | |
# random = random.sample(range(len(df)), 50) | |
top_ind = np.unravel_index(np.argsort(cosine_similarities, axis=None)[-30:][::-1], cosine_similarities.shape) | |
confidence = cosine_similarities[top_ind] | |
top_ind = list(top_ind[0]) | |
conf_dict = {} | |
for value, conf in zip(top_ind, confidence): | |
conf_dict[int(value)] = conf | |
# st.write(conf_dict) | |
output_dict = {} | |
for i in top_ind: | |
for ganre in df['ganres'][i]: | |
if ganre in choice_g: | |
output_dict[i] = df['ganres'][i] | |
# st.write('output_dict') | |
sorted_lst = sorted(output_dict.items(), key=lambda x: len(set(x[1]) & set(choice_g)), reverse=True) | |
n_lst = [i[0] for i in sorted_lst[:n]] | |
st.write(f'<p style="font-family: Arial, sans-serif; font-size: 18px; text-align: center;"><strong>Всего подобранных \ | |
рекомендаций {len(sorted_lst)}</strong></p>', unsafe_allow_html=True) | |
st.write('\n') | |
# Отображение изображений и названий | |
for i in n_lst: | |
col1, col2 = st.columns([2, 5]) | |
with col1: | |
st.image(df['poster'][i], width=200) | |
with col2: | |
st.write(f"***Название:*** {df['title'][i]}") | |
st.write(f"***Жанр:*** {', '.join(df['ganres'][i])}") | |
st.write(f"***Описание:*** {df['description'][i]}") | |
# similarity = float(confidence) | |
# st.write(f"***Cosine Similarity : {round(similarity, 3)}***") | |
st.markdown(f"[***ссылка на сериал***]({df['url'][i]})") | |
st.write(f"") | |
end_time = time.time() | |
st.write(f"<small>*Степень соответствия по косинусному сходству: {conf_dict[i]:.4f}*</small>", unsafe_allow_html=True) | |
st.markdown( | |
"<hr style='border: 2px solid #000; margin-top: 10px; margin-bottom: 10px;'>", | |
unsafe_allow_html=True | |
) |