#importing the libraries
import streamlit as st
from PIL import Image
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
import numpy as np
import pandas as pd
import time
import os

model_repository_id = "Dusduo/Pokemon-classification-1stGen"


# Streamlit re-executes this whole script on every widget interaction, so the
# model load is cached and shared across reruns instead of being repeated.
# show_spinner=False keeps the cached call from drawing anything on the page
# before st.set_page_config() runs further down.
@st.cache_resource(show_spinner=False)
def _load_classifier(repository_id):
    """Return the ``(image_processor, model)`` pair for ``repository_id``.

    Both objects are created with ``from_pretrained`` and are meant to be
    used read-only (inference), since the same instances are handed back on
    every rerun.
    """
    processor = AutoImageProcessor.from_pretrained(repository_id)
    classifier = AutoModelForImageClassification.from_pretrained(repository_id)
    return processor, classifier


# Loading the pokemon classifier model and its processor
image_processor, model = _load_classifier(model_repository_id)
# Loading the pokemon information table (looked up by its 'name' column)
pokemon_info_df = pd.read_csv('pokemon_info.csv')

# 20x20 pokeball icon displayed in the header row of the pokedex card
pokeball_image = Image.open('pokeball.png').resize((20,20))

#functions to predict image
def preprocess(processor: AutoImageProcessor, image):
    """Turn an image into the tensors expected by the classifier.

    Args:
        processor: Callable image processor; invoked with the prepared image
            and ``return_tensors="pt"``.
        image: Image object exposing ``convert`` and ``resize``
            (e.g. a ``PIL.Image.Image``).

    Returns:
        Whatever ``processor`` returns for the RGB, 200x200 version of
        ``image``.
    """
    # Normalise the colour mode first so palette / RGBA / greyscale uploads
    # all reach the processor as 3-channel images.
    rgb_image = image.convert("RGB")
    resized_image = rgb_image.resize((200, 200))
    return processor(resized_image, return_tensors="pt")

def predict(model: AutoModelForImageClassification, inputs, k=5, variance_threshold=0.001):
    """Classify the first image of a preprocessed batch.

    Args:
        model: Classifier called as ``model(**inputs)``; its output must expose
            ``.logits`` and the model must expose ``.config.id2label``.
        inputs: Mapping of keyword arguments forwarded to ``model``.
        k: Currently unused; kept so existing callers that pass it keep working.
        variance_threshold: Minimum variance of the class probabilities for the
            prediction to be trusted.

    Returns:
        ``(label, probability)`` where ``probability`` is a percentage rounded
        to 2 decimals, or ``('not a pokemon', -1)`` when the model is too
        unsure to name a class.
    """
    # Forward the image to the model and retrieve the logits
    with torch.no_grad():
        logits = model(**inputs).logits

    # Only the first item of the batch is classified; the probabilities and the
    # argmax below are both taken from that same row.
    first_logits = logits[0]

    # Convert the retrieved logits into a vector of probabilities for each class
    probabilities = torch.softmax(first_logits, dim=0).tolist()

    # Discriminate whether or not the inputted image was an image of a Pokemon.
    # The probabilities sum to 1, so their variance is low when the mass is
    # spread evenly over many classes (the model is confused) and high when it
    # is concentrated on a single class (the model is confident).
    if np.var(probabilities) < variance_threshold:
        # Nearly flat distribution: the image likely matches no known class.
        return 'not a pokemon', -1

    # Retrieve the predicted class (pokemon)
    predicted_id = first_logits.argmax().item()
    predicted_label = model.config.id2label[predicted_id]
    # Retrieve the probability for the predicted class, as a percentage with 2 decimals
    probability = round(probabilities[predicted_id] * 100, 2)

    return predicted_label, probability



# Designing the interface ------------------------------------------

# Use the full page instead of a narrow central column
st.set_page_config(layout="wide")

# Define the title
st.title("Gotta Classify 'Em All - 1st Generation Pokedex -")
# For newline
st.write('\n')

# Banner picture shown until the user uploads an image (opened once only)
image = Image.open('anime1.jpeg')

# col1 (wide) hosts the banner / uploaded picture;
# col2 (narrow) receives the pokedex card after a classification.
col1, col2 = st.columns([3,1])

with col1:
    # Keep the element handle: the upload handling replaces its content later.
    show = st.image(image, use_column_width=True)



# Display Sample images  ----
st.subheader('Sample images')

sample_imgs_dir = "sample_imgs/"
# os.listdir() returns entries in arbitrary order, so sort them to keep the
# gallery layout stable between reruns and machines. Hidden files and
# sub-directories are skipped because st.image cannot render them.
sample_imgs = sorted(
    entry for entry in os.listdir(sample_imgs_dir)
    if not entry.startswith('.')
    and os.path.isfile(os.path.join(sample_imgs_dir, entry))
)

# Lay the samples out as rows of n_cols pictures
n_cols = 4
groups = [sample_imgs[i:i+n_cols] for i in range(0, len(sample_imgs), n_cols)]

for group in groups:
    # Always create n_cols columns so a shorter last row keeps the same cell width
    cols = st.columns(n_cols)
    for col, image_file in zip(cols, group):
        col.image(os.path.join(sample_imgs_dir, image_file))
    


# Sidebar work and model outputs ---------------

st.sidebar.title("Upload Image")

# Choose your own image.
# The widget gets a real label (Streamlit discourages empty ones for
# accessibility reasons); it is collapsed because the sidebar title above
# already describes the control.
uploaded_file = st.sidebar.file_uploader(
    "Upload Image",
    type=['png', 'jpg', 'jpeg'],
    accept_multiple_files=False,
    label_visibility="collapsed",
)

if uploaded_file is not None:

    u_img = Image.open(uploaded_file)
    # Replace the banner picture with the uploaded one
    show.image(u_img, 'Uploaded Image', width=400)

    # Preprocess the image for the model; model_inputs is consumed when the
    # classify button is pressed.
    model_inputs = preprocess(image_processor, u_img)

# For newline
st.sidebar.write('\n')
    
def _type_badge(type_name, color):
    """Return the HTML of a centered, rounded type badge filled with ``color``."""
    return (
        '<div style="display: flex; justify-content: center; align-items: center;">'
        f'<div style="display: inline-block; padding: 5px; margin: 0 5px; border-radius: 5px; background-color: {color}; color: white;">{type_name}</div>'
        '</div>'
    )


def _stat_line(label, value):
    """Return the HTML of one ``<b>label:</b> value`` line of the pokedex card."""
    # The style attribute must be quoted: its value contains a space.
    return f'<div style="font-size: 1.4rem;"><b>{label}:</b> {value}</div>'


if st.sidebar.button("Click Here to Classify"):

    if uploaded_file is None:

        st.sidebar.write("Please upload an Image to Classify")

    else:

        with st.spinner('Classifying ...'):
            # Get prediction; model_inputs was built when the file was uploaded
            prediction, probability = predict(model, model_inputs,5)
            # Fixed 2 s pause; presumably only there to keep the spinner visible
            time.sleep(2)
            st.sidebar.success('Done!')

        st.sidebar.header("Model predicts: ")

        # Display prediction

        if probability==-1:
            # predict() signals "no confident class" with a probability of -1
            st.sidebar.write("It seems like it is not a picture of a 1st Generation Pokemon alone.", '\n',
                             "There might be too many entities on the image." )

        else:
            st.sidebar.write(f" It's a(n) {prediction} picture.",'\n')

            st.sidebar.write('Probability:',probability,'%')

            # Retrieve predicted pokemon information
            matching_rows = pokemon_info_df[pokemon_info_df['name']==prediction]

            if matching_rows.empty:
                # The model label has no row in the info table: report it
                # instead of failing on an out-of-range index.
                st.sidebar.warning(f"No pokedex entry found for {prediction}.")

            else:
                # Positional unpacking: relies on the CSV column order
                (_, pokedex_number, english_name, romaji_name, katakana_name,
                 weight_kg, height_m, type1, type2, color1, color2,
                 classification, evolve_from, evolve_into,
                 is_legendary) = matching_rows.values[0]

                with col2:
                    # pokedex box
                    with st.container(border=True ):
                        # first row: icon, pokedex number, name
                        with st.container():
                            pokeball_image_col,pokedex_number_col, pokemon_name_col = st.columns([1,1,8])
                            pokeball_image_col.image(pokeball_image)
                            pokedex_number_col.markdown(f'<div style="text-align: left; font-size: 1.4rem;"><b>{pokedex_number}</b></div>', unsafe_allow_html=True)
                            pokemon_name_col.markdown(f'<div style="text-align: right; font-size: 1.4rem;"><b>{english_name}</b></div>', unsafe_allow_html=True)

                        # second row: classification, tinted with the primary type color
                        with st.container():
                            st.markdown(f'<div style="text-align: center; color: {color1}; font-size: 1.2rem;"><b>{classification}</b></div>', unsafe_allow_html=True)

                        # 3rd row: one badge per type (type2 is NaN for single-type pokemon)
                        with st.container():
                            if pd.isna(type2):
                                st.write('\n')
                                st.markdown(_type_badge(type1, color1), unsafe_allow_html=True)
                            else:
                                type1_col, type2_col = st.columns(2)
                                type1_col.markdown(_type_badge(type1, color1), unsafe_allow_html=True)
                                type2_col.markdown(_type_badge(type2, color2), unsafe_allow_html=True)
                            st.write('\n')

                        # 4th row: measurements and evolution chain
                        with st.container():
                            st.markdown(_stat_line('Height', f'{height_m}m'), unsafe_allow_html=True)
                            st.write('\n')
                            st.markdown(_stat_line('Weight', f'{weight_kg}kg'), unsafe_allow_html=True)
                            st.write('\n')
                            # Evolution lines are shown only when the table has a value for them
                            if not pd.isna(evolve_from):
                                st.markdown(_stat_line('Evolves from', evolve_from), unsafe_allow_html=True)
                                st.write('\n')
                            if not pd.isna(evolve_into):
                                st.markdown(_stat_line('Evolves into', evolve_into), unsafe_allow_html=True)
                                st.write('\n')