# Author: Eli Orille
# Commit adfdb45: "added csv file" (file size at that commit: 4.76 kB)
import numpy as np
import pandas as pd
from dotenv import load_dotenv
from langchain_community.document_loaders import TextLoader
from langchain_huggingface import HuggingFaceEmbeddings
# from langchain_text_splitters import CharacterTextSplitter
from langchain_chroma import Chroma
from langchain.schema import Document
import gradio as gr
import os
import requests
# Load variables from a local .env file into the environment (presumably
# including OMDB_API_KEY, which recommend_movies reads via os.getenv).
load_dotenv()
# Movie metadata table. Columns read later in this file include "Wiki Page",
# "simple_genre", "Plot", "Director", "Cast", "Title", "Release Year" and the
# emotion scores "joy", "surprise", "anger", "fear", "sadness".
# NOTE(review): all paths here are relative to the current working directory.
movies = pd.read_csv("movies_with_emotions.csv")
# NOTE(review): raw_documents and documents are not referenced anywhere else in
# the visible code. The vector store below is opened from "chroma_db" rather
# than built from them. Confirm they are unused before removing these reads.
loader = TextLoader("../tagged_plot.txt", encoding="utf-8")
raw_documents = loader.load()
# One Document per non-blank line of the tagged plot file.
with open("../tagged_plot.txt", encoding="utf-8") as f:
    lines = f.read().splitlines()
documents = [Document(page_content=line) for line in lines if line.strip()]
# Sentence-transformer model used to embed queries for similarity search.
embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-mpnet-base-v2"
)
# Open the persisted Chroma store in ./chroma_db.
# Presumably it was populated with the same embedding model. TODO confirm:
# a model mismatch would make the similarity scores meaningless.
db_movies = Chroma(
    persist_directory="chroma_db",
    embedding_function=embeddings
)
def retrieve_semantic_recommendations(
    query: str,
    genre: str = "All",
    tone: str = "All",
    initial_top_k: int = 50,
    final_top_k: int = 16,
) -> pd.DataFrame:
    """Return movies semantically similar to *query*, optionally filtered and re-sorted.

    The Chroma store is searched for the ``initial_top_k`` most similar
    documents. The first whitespace-separated token of each hit, with
    leading double quotes stripped, is used as a "Wiki Page" key into
    ``movies``. Matching rows are kept in similarity order.

    Args:
        query: Free-text description to search for.
        genre: Value of the "simple_genre" column to keep. "All" (or None)
            keeps every genre.
        tone: "Happy", "Surprising", "Angry", "Suspenseful" or "Sad" sorts
            the results by the matching emotion column, highest first. Any
            other value (e.g. "All") keeps similarity order.
        initial_top_k: Number of nearest neighbours fetched from the store.
        final_top_k: Maximum number of rows returned.

    Returns:
        A DataFrame containing at most ``final_top_k`` rows of ``movies``.
    """
    hits = db_movies.similarity_search(query, k=initial_top_k)

    # Wiki Page key of every hit. Blank documents have no first token and are
    # skipped instead of raising IndexError.
    keys = [
        hit.page_content.split()[0].lstrip('"')
        for hit in hits
        if hit.page_content.split()
    ]
    # Similarity rank of each key (first occurrence wins).
    rank = {key: position for position, key in enumerate(dict.fromkeys(keys))}

    # isin() alone returns rows in CSV order. Truncating that would drop the
    # closest matches whenever they appear late in the file, so re-sort the
    # rows by similarity rank (stable, so duplicate keys keep CSV order).
    candidates = movies[movies["Wiki Page"].isin(list(rank))]
    candidates = candidates.sort_values(
        by="Wiki Page", key=lambda column: column.map(rank), kind="mergesort"
    )

    # Apply the genre filter before truncating, so it selects from the whole
    # candidate pool rather than only the first final_top_k rows.
    if genre is not None and genre != "All":
        candidates = candidates[candidates["simple_genre"] == genre]
    candidates = candidates.head(final_top_k)

    emotion_column = {
        "Happy": "joy",
        "Surprising": "surprise",
        "Angry": "anger",
        "Suspenseful": "fear",
        "Sad": "sadness",
    }.get(tone)
    if emotion_column is not None:
        # Not in place: avoids SettingWithCopyWarning on a sliced frame.
        # A stable sort keeps similarity order among equal emotion scores.
        candidates = candidates.sort_values(
            by=emotion_column, ascending=False, kind="mergesort"
        )
    return candidates
def _format_name_list(raw) -> str:
    """Render a comma-separated name string as readable English.

    "A" -> "A", "A, B" -> "A and B", "A, B, C" -> "A, B, and C".
    Non-string input (e.g. a NaN cell) or a string with no names gives "N/A".
    """
    if not isinstance(raw, str):
        return "N/A"
    names = [name.strip() for name in raw.split(",") if name.strip()]
    if not names:
        return "N/A"
    if len(names) == 1:
        return names[0]
    if len(names) == 2:
        return f"{names[0]} and {names[1]}"
    return f"{', '.join(names[:-1])}, and {names[-1]}"


def _truncate_words(text, max_words: int = 20) -> str:
    """Return the first *max_words* words of *text*, adding "..." only if words were cut."""
    if not isinstance(text, str):
        return ""
    words = text.split()
    if len(words) <= max_words:
        return " ".join(words)
    return " ".join(words[:max_words]) + "..."


def _fetch_poster(title, year, api_key) -> str:
    """Look up a poster URL on OMDb, falling back to the local placeholder image.

    Any network error, non-JSON reply, unsuccessful lookup, or missing ("N/A")
    poster yields "cover-not-found.jpg", so one bad lookup cannot abort the
    whole recommendation list.
    """
    fallback = "cover-not-found.jpg"
    if not api_key:
        return fallback
    try:
        # params= URL-encodes the title, so titles containing '&', '#', '?'
        # or spaces are sent correctly. The timeout keeps the UI from hanging.
        response = requests.get(
            "http://www.omdbapi.com/",
            params={"apikey": api_key, "t": title, "y": year},
            timeout=10,
        )
        data = response.json()
    except (requests.RequestException, ValueError):
        return fallback
    if not isinstance(data, dict) or data.get("Response") != "True":
        return fallback
    poster = data.get("Poster")
    if not poster or poster == "N/A":
        return fallback
    # The URL is returned unchanged. The previous "&fife=w800" suffix is a
    # Google Books image parameter. OMDb poster URLs typically have no query
    # string, so the suffix just became part of the file path.
    return poster


def recommend_movies(
    query: str,
    genre: str = "All",
    tone: str = "All",
):
    """Build the (poster, caption) pairs displayed in the Gradio gallery.

    Args:
        query: Free-text movie description entered by the user.
        genre: Genre filter passed to retrieve_semantic_recommendations.
        tone: Emotional-tone sort passed to retrieve_semantic_recommendations.

    Returns:
        A list of ``(poster, caption)`` tuples. ``poster`` is an OMDb URL or
        "cover-not-found.jpg". ``caption`` holds the title, director(s), cast
        and the first 20 words of the plot.
    """
    recommendations = retrieve_semantic_recommendations(query, genre, tone)
    # Read the key once instead of once per row.
    api_key = os.getenv("OMDB_API_KEY")
    results = []
    for _, row in recommendations.iterrows():
        directors = _format_name_list(row["Director"])
        cast = _format_name_list(row["Cast"])
        plot = _truncate_words(row["Plot"], 20)
        caption = (f"{row['Title']} (Directed by {directors})\n"
                   f"Cast: {cast}\nPlot: {plot}")
        poster = _fetch_poster(row["Title"], row["Release Year"], api_key)
        results.append((poster, caption))
    return results
# Dropdown choices. "All" disables the corresponding filter/sort in
# retrieve_semantic_recommendations. Missing genres are dropped before
# sorting: sorted() cannot compare a float NaN with str and would raise
# TypeError at import time.
categories = ["All"] + sorted(movies["simple_genre"].dropna().unique())
tones = ["All", "Happy", "Surprising", "Angry", "Suspenseful", "Sad"]
# Gradio UI: a query box, genre/tone dropdowns and a button that fills a
# gallery with the (poster, caption) pairs returned by recommend_movies.
with gr.Blocks(theme = gr.themes.Glass()) as dashboard:
    gr.Markdown("# Movie Recommendation System")
    with gr.Row():
        user_query = gr.Textbox(label = "Please enter description of a movie:",
                                placeholder = "e.g. A movie about animals")
        category_dropdown = gr.Dropdown(choices = categories, label = "Select a genre: ", value = "All")
        tone_dropdown = gr.Dropdown(choices = tones, label = "Select an emotional tone: ", value="All")
        submit_button = gr.Button("Find Recommendations")
    gr.Markdown("## Recommendations")
    # The 8 x 2 grid matches the default final_top_k of 16 results.
    output = gr.Gallery(label = "Recommended movies", columns=8, rows=2)
    # Input values are passed positionally as (query, genre, tone).
    submit_button.click(fn = recommend_movies,
                        inputs=[user_query, category_dropdown,tone_dropdown],
                        outputs=output)
# Start the Gradio server only when run as a script, not when imported.
if __name__ == "__main__":
    dashboard.launch()