File size: 9,277 Bytes
df23474
fbfd3a3
df23474
 
 
 
 
 
 
fbfd3a3
b69511f
df23474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
00341ad
 
 
df23474
 
 
 
b69511f
 
 
 
df23474
 
6b216c6
b69511f
 
 
df23474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b69511f
df23474
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9b3cd13
df23474
 
 
 
 
9b3cd13
 
 
 
 
 
df23474
 
b69511f
df23474
b69511f
df23474
 
 
 
 
 
 
 
 
 
b69511f
 
df23474
 
 
b69511f
df23474
 
 
 
 
b69511f
df23474
 
b69511f
 
df23474
 
 
b69511f
df23474
b69511f
df23474
 
b69511f
df23474
 
 
b69511f
df23474
 
 
b69511f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df23474
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#importing the libraries
import streamlit as st
from PIL import Image
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
import numpy as np
import pandas as pd
import time
import os


model_repository_id = "Dusduo/Pokemon-classification-1stGen"


# Streamlit re-executes this whole script on every widget interaction, so the
# heavy resources are loaded through Streamlit's caches instead of at every
# rerun. show_spinner=False keeps the loaders from emitting any Streamlit
# element before st.set_page_config() is called further down.
@st.cache_resource(show_spinner=False)
def _load_classifier(repository_id: str):
    """Load the image processor and classifier for *repository_id* once per process.

    Returns:
        ``(image_processor, model)`` as returned by ``from_pretrained``.
    """
    processor = AutoImageProcessor.from_pretrained(repository_id)
    classifier = AutoModelForImageClassification.from_pretrained(repository_id)
    return processor, classifier


@st.cache_data(show_spinner=False)
def _load_pokemon_info(csv_path: str) -> pd.DataFrame:
    """Read the Pokemon information table from *csv_path* (cached across reruns)."""
    return pd.read_csv(csv_path)


# Loading the pokemon classifier model and its processor
image_processor, model = _load_classifier(model_repository_id)
# Loading the pokemon information table
pokemon_info_df = _load_pokemon_info('pokemon_info.csv')

# Small icon shown on the Pokedex card
pokeball_image = Image.open('pokeball.png').resize((20,20))

#functions to predict image
def preprocess(processor: "AutoImageProcessor", image, size=(200, 200)):
    """Turn an image into model-ready PyTorch inputs.

    The image is first converted to 3-channel RGB, because uploads may be
    RGBA or palette images. It is then resized to *size* and handed to
    *processor* with ``return_tensors="pt"``.

    Args:
        processor: image processor matching the classifier (callable).
        image: PIL image to classify.
        size: ``(width, height)`` the image is resized to before processing;
            defaults to the 200x200 the app has always used.

    Returns:
        Whatever *processor* returns. ``predict`` unpacks it as keyword
        arguments into the model.
    """
    rgb_image = image.convert("RGB").resize(tuple(size))
    return processor(rgb_image, return_tensors="pt")

def predict(model: "AutoModelForImageClassification", inputs, k=5, threshold=0.001):
    """Classify preprocessed inputs and decide whether they show a known Pokemon.

    Args:
        model: classifier returning an object with ``.logits``; its
            ``config.id2label`` maps class indices to Pokemon names.
        inputs: mapping of model inputs (e.g. the output of ``preprocess``),
            unpacked as keyword arguments into ``model``.
        k: unused. It is kept so existing callers passing it keep working
            (it was intended for a top-k ranking that is not implemented).
        threshold: minimum variance of the class-probability vector for the
            prediction to be trusted. Below it, the distribution is considered
            too flat for the image to match any known class.

    Returns:
        ``(label, probability)``, where ``probability`` is the top class's
        probability in percent rounded to 2 decimals. Returns
        ``('not a pokemon', -1)`` when the variance is below *threshold*.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    # Only the first image of the batch is classified (the app sends one).
    scores = logits[0]
    probabilities = torch.softmax(scores, dim=0).tolist()

    # A confident model concentrates the probability mass on one class, which
    # gives a high variance. A near-uniform vector (variance close to 0) means
    # the model is confused, likely because the image matches no known class.
    variance = np.var(probabilities)
    if variance < threshold:
        return 'not a pokemon', -1

    predicted_id = int(scores.argmax())
    predicted_label = model.config.id2label[predicted_id]
    # Probability of the predicted class as a percentage, 2 decimals.
    probability = round(probabilities[predicted_id] * 100, 2)
    return predicted_label, probability



# Designing the interface ------------------------------------------

# Use the full page instead of a narrow central column.
# (st.set_page_config must be the first Streamlit command of the script.)
st.set_page_config(layout="wide")

# Define the title
st.title("Gotta Classify 'Em All")
st.subheader("Image classifier for Pokemons from the 1st generation.")

# For newline
st.write('\n')

# Left column: the image being classified (a placeholder until an upload
# replaces it). Right column: filled with the Pokedex card after classification.
col1, col2 = st.columns([1,2])

with col1:
    image = Image.open('base.jpg')
    # `show` is reused later to swap the placeholder for the uploaded image.
    show = st.image(image, use_column_width=True)





# Display Sample images  ----
st.subheader('Sample images')

sample_imgs_dir = "sample_imgs/"
# os.listdir returns entries in arbitrary order: sort them for a stable
# gallery, and skip hidden files (e.g. .DS_Store) that st.image cannot render.
sample_imgs = sorted(
    file_name
    for file_name in os.listdir(sample_imgs_dir)
    if not file_name.startswith('.')
)

n_cols = 4
# Chunk the file list into rows of `n_cols` images each.
groups = [sample_imgs[i:i + n_cols] for i in range(0, len(sample_imgs), n_cols)]

for group in groups:
    cols = st.columns(n_cols)
    # The last row may hold fewer than n_cols images; zip stops at the shorter.
    for col, image_file in zip(cols, group):
        col.image(os.path.join(sample_imgs_dir, image_file))
    


# Sidebar work and model outputs ---------------

st.sidebar.title("Upload Image")

# Choose your own image. Streamlit warns on empty widget labels (accessibility),
# so a real label is given and hidden: the sidebar title above already names it.
uploaded_file = st.sidebar.file_uploader(
    "Upload Image",
    type=['png', 'jpg', 'jpeg'],
    accept_multiple_files=False,
    label_visibility="collapsed",
)

if uploaded_file is not None:

    u_img = Image.open(uploaded_file)
    # Replace the placeholder image in the left column with the upload.
    show.image(u_img, 'Uploaded Image', use_column_width=True)

    # Preprocess the image for the model
    model_inputs = preprocess(image_processor, u_img)

# For newline
st.sidebar.write('\n')


def _type_badge_html(type_name, color):
    """Return the HTML of one centred, coloured badge showing a Pokemon type."""
    return (
        '<div style="display: flex; justify-content: center; align-items: center;">'
        '<div style="display: inline-block; padding: 1%; margin: 0 1%; border-radius: 5px; '
        f'background-color: {color}; color: white; font-size: 1.4em;"><b>{type_name}</b></div>'
        '</div>'
    )


def _stat_html(label, value):
    """Return the HTML of one large "label: value" line of the Pokedex card."""
    return f'<div style="font-size: 1.8em;"><b>{label}:</b> {value}</div>'


def _render_pokedex_card(info_row):
    """Draw the Pokedex card for one row of `pokemon_info_df` in the active container.

    The row is unpacked positionally, so this relies on the CSV column order
    (first column discarded, then the 14 fields below).
    NOTE(review): confirm against pokemon_info.csv when the file changes.
    """
    (_, pokedex_number, english_name, romaji_name, katakana_name, weight_kg,
     height_m, type1, type2, color1, color2, classification, evolve_from,
     evolve_into, is_legendary) = info_row

    # pokedex box
    with st.container(border=True):
        # first row: pokeball icon, Pokedex number, English and katakana names
        with st.container():
            pokeball_image_col, pokedex_number_col, pokemon_name_col = st.columns([1, 1, 8])
            pokeball_image_col.image(pokeball_image)
            pokedex_number_col.markdown(f'<div style="text-align: left; white-space: nowrap; font-size: 1.8em;"><b>Pokedex n°{pokedex_number}</b></div>', unsafe_allow_html=True)
            pokemon_name_col.markdown(f'<div style="text-align: right; font-size: 1.8em;"><b>{english_name} <br> {katakana_name}</b></div>', unsafe_allow_html=True)

        # second row: classification, tinted with the primary colour
        with st.container():
            st.markdown(f'<div style="text-align: center; color: {color1}; font-size: 1.6em;"><b>{classification}</b></div>', unsafe_allow_html=True)

        # third row: one badge per type (type2 is NaN when there is no second type)
        with st.container():
            if pd.isna(type2):
                st.write('\n')
                st.markdown(_type_badge_html(type1, color1), unsafe_allow_html=True)
            else:
                type1_col, type2_col = st.columns(2)
                type1_col.markdown(_type_badge_html(type1, color1), unsafe_allow_html=True)
                type2_col.markdown(_type_badge_html(type2, color2), unsafe_allow_html=True)
            st.write('\n')

        # fourth row: physical stats, then the evolution chain when present
        with st.container():
            st.write(_stat_html('Height', f'{height_m}m'), unsafe_allow_html=True)
            st.write('\n')
            st.write(_stat_html('Weight', f'{weight_kg}kg'), unsafe_allow_html=True)
            st.write('\n')
            if not pd.isna(evolve_from):
                st.markdown(_stat_html('Evolves from', evolve_from), unsafe_allow_html=True)
                st.write('\n')
            if not pd.isna(evolve_into):
                st.markdown(_stat_html('Evolves into', evolve_into), unsafe_allow_html=True)
                st.write('\n')


if st.sidebar.button("Click Here to Classify"):

    if uploaded_file is None:

        st.sidebar.write("Please upload an Image to Classify")

    else:

        with st.spinner('Classifying ...'):
            # Get prediction
            prediction, probability = predict(model, model_inputs, 5)
            # Fixed 2 s pause inside the spinner; the prediction is already done.
            time.sleep(2)
            st.sidebar.success('Done!')

        st.sidebar.header("Model response: ")

        # Display prediction
        if probability == -1:

            st.sidebar.write("""I am sorry I am having trouble finding a matching pokemon. <br> 
                             <b>Potential explanations: </b><br>
                                - The image provided is a Pokemon but not from the 1st Generation. <br>
                                - The image provided is not a Pokemon. <br>
                                - There are too many entities on the image. <br>
                                """, unsafe_allow_html=True)

        else:
            st.sidebar.write(f" It's a(n) <b>{prediction}</b> picture.", '\n', unsafe_allow_html=True)

            st.sidebar.write('<b>Probability:</b>', probability, '%', unsafe_allow_html=True)

            # Retrieve predicted pokemon information; guard against a label
            # that has no row in the table instead of raising IndexError.
            matches = pokemon_info_df[pokemon_info_df['name'] == prediction]
            if matches.empty:
                st.sidebar.warning(f"No Pokedex entry found for {prediction}.")
            else:
                with col2:
                    _render_pokedex_card(matches.values[0])
                    
# Footer: project links, then the author's contact details ---------------
st.sidebar.write('\n')

# Markdown bodies of the two info boxes, spelled out with explicit escapes.
_project_links_md = (
    "\n"
    "    - Web App URL: [url](https://huggingface.co/spaces/Dusduo/GottaClassifyEmAll)\n"
    "    - GitHub repository: [repository](https://github.com/A-Duss/GottaClassifyEmAll.git)\n"
    "    "
)
_contact_md = (
    "\n"
    "    Antoine Dussolle: [LinkedIn](https://www.linkedin.com/in/antoine-dussolle/) | [GitHub](https://github.com/A-Duss) \n"
    "    "
)

st.sidebar.info(_project_links_md)

st.sidebar.title("Contact")
st.sidebar.info(_contact_md)