import functools
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.applications import efficientnet
#import efficientnet
from tensorflow.keras.layers import TextVectorization
import matplotlib.pyplot as plt
import cv2


from models import EMBED_DIM, FF_DIM, SEQ_LENGTH, ImageCaptioningModel, TransformerDecoderBlock, TransformerEncoderBlock, get_cnn_model, image_augmentation, vectorization, valid_data, decode_and_resize

def _build_caption_model():
    """Create an ImageCaptioningModel (no trained weights yet) from the ``models`` parts.

    The hyper-parameters below presumably have to match the architecture that
    produced ``model_weights.h5`` for ``load_weights`` to succeed.
    """
    cnn_model = get_cnn_model()
    encoder = TransformerEncoderBlock(embed_dim=EMBED_DIM, dense_dim=FF_DIM, num_heads=1)
    decoder = TransformerDecoderBlock(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)
    return ImageCaptioningModel(
        cnn_model=cnn_model, encoder=encoder, decoder=decoder, image_aug=image_augmentation,
    )


def _load_weights(model, sample_img, weights_path="model_weights.h5"):
    """Create ``model``'s variables, then load trained weights into them.

    Keras sub-layers create their variables on first call, so one forward pass
    is run with exactly the call signatures ``_generate_caption`` uses. Feeding
    the real preprocessed image and a real token sequence (instead of
    hand-made dummy tensors) builds every layer for the input shapes it will
    actually receive; ``load_weights`` then overwrites the values.

    Args:
        model: the ImageCaptioningModel returned by ``_build_caption_model``.
        sample_img: an image tensor as returned by ``decode_and_resize``.
        weights_path: path of the saved weights file.
    """
    img = tf.expand_dims(sample_img, 0)
    encoded_img = model.encoder(model.cnn_model(img), training=False)
    tokens = vectorization(["<start> "])[:, :-1]
    model.decoder(tokens, encoded_img, training=False, mask=tf.math.not_equal(tokens, 0))
    model.load_weights(weights_path)


def _preprocess_image(bgr_image):
    """Turn a decoded OpenCV image into the model's input tensor.

    ``decode_and_resize`` takes a file path, so the image is round-tripped
    through a private temporary JPEG. A fixed path in the working directory
    would be shared between concurrent Streamlit sessions and left on disk.

    Raises:
        ValueError: if OpenCV cannot write the image.
    """
    import os
    import tempfile

    fd, path = tempfile.mkstemp(suffix=".jpg")
    os.close(fd)  # only the path is needed; OpenCV opens the file itself
    try:
        if not cv2.imwrite(path, bgr_image):
            raise ValueError("failed to write the uploaded image to a temporary file")
        return decode_and_resize(path)
    finally:
        os.remove(path)


def _generate_caption(model, sample_img):
    """Greedily decode a caption for one preprocessed image.

    Args:
        model: an ImageCaptioningModel whose weights have been loaded.
        sample_img: an image tensor as returned by ``decode_and_resize``.

    Returns:
        The caption text with the ``<start>`` / ``<end>`` markers removed.
    """
    import numpy as np

    vocab = vectorization.get_vocabulary()
    max_decoded_sentence_length = SEQ_LENGTH - 1

    # Image -> CNN features -> Transformer encoder.
    img = tf.expand_dims(sample_img, 0)
    encoded_img = model.encoder(model.cnn_model(img), training=False)

    decoded_caption = "<start> "
    for i in range(max_decoded_sentence_length):
        tokenized_caption = vectorization([decoded_caption])[:, :-1]
        mask = tf.math.not_equal(tokenized_caption, 0)
        predictions = model.decoder(
            tokenized_caption, encoded_img, training=False, mask=mask
        )
        sampled_token = vocab[int(np.argmax(predictions[0, i, :]))]
        # Tokens are produced by whitespace splitting, so they never carry a
        # leading space: the end marker must be compared as exactly "<end>".
        if sampled_token == "<end>":
            break
        decoded_caption += " " + sampled_token

    decoded_caption = decoded_caption.replace("<start> ", "")
    return decoded_caption.replace(" <end>", "").strip()


def display_UI():
    """Render the captioning page.

    The page lets the user upload an image, loads the captioning model, and
    shows the generated caption.
    """
    import streamlit as st
    import numpy as np

    st.markdown(""" <style> .appview-container .main .block-container {
        max-width: 100%;
        padding-top: 1rem;
        padding-right: 1rem;
        padding-left: 1rem;
        padding-bottom: 1rem;
    }</style> """, unsafe_allow_html=True)

    # Header text styled via CSS, with the logo in a narrow right-hand column.
    st.markdown(""" <style> .font {
        font-size:25px ; font-family: 'Cooper Black'; color: #FF9633;} 
        </style> """, unsafe_allow_html=True)
    col1, col2 = st.columns( [0.8, 0.2])
    with col1:
        st.markdown('<p class="font">Generate Caption of image</p>', unsafe_allow_html=True)
    with col2:
        st.image('./logo.png', width=50 )

    uploaded_file = st.file_uploader("Choose an Image File", type=[".jpg", ".jpeg", ".png", ".PNG"],
                            accept_multiple_files=False)
    if uploaded_file is None:
        return

    with st.spinner('Wait for it...'):
        file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
        opencv_image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    # imdecode returns None instead of raising when the bytes are not an image.
    if opencv_image is None:
        st.error("Could not read the uploaded file as an image", icon="🚨")
        return
    # OpenCV decodes to BGR; tell Streamlit so the colours are not swapped.
    st.image(opencv_image, channels="BGR", width=400)

    with st.spinner('Loading the model..'):
        sample_img = _preprocess_image(opencv_image)
        model = _build_caption_model()
        _load_weights(model, sample_img)
        st.success('Model  Loaded!', icon="✅")

    with st.spinner("Getting prediction..."):
        caption = _generate_caption(model, sample_img)
    if caption:
        st.success(caption, icon="✅")
    else:
        st.error("Error occured in caption generation", icon="🚨")

# `streamlit run` executes this file as ``__main__``; the guard keeps the UI
# (and the model loading it triggers) from running when the module is imported.
if __name__ == "__main__":
    display_UI()