Spaces:
Sleeping
Sleeping
File size: 9,277 Bytes
df23474 fbfd3a3 df23474 fbfd3a3 b69511f df23474 00341ad df23474 b69511f df23474 6b216c6 b69511f df23474 b69511f df23474 9b3cd13 df23474 9b3cd13 df23474 b69511f df23474 b69511f df23474 b69511f df23474 b69511f df23474 b69511f df23474 b69511f df23474 b69511f df23474 b69511f df23474 b69511f df23474 b69511f df23474 b69511f df23474 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
#importing the libraries
import streamlit as st
from PIL import Image
import torch
from transformers import AutoModelForImageClassification, AutoImageProcessor
import numpy as np
import pandas as pd
import time
import os
model_repository_id = "Dusduo/Pokemon-classification-1stGen"


# Streamlit re-executes this whole script on every widget interaction, so the
# expensive resources below are cached and loaded only once per server
# process instead of on every click/upload.
# show_spinner=False keeps these calls from rendering anything before
# st.set_page_config() further down (some Streamlit versions require that to
# be the first Streamlit command of the script).
@st.cache_resource(show_spinner=False)
def _load_classifier(repository_id):
    """Load the image processor and the classification model for *repository_id*.

    Returns:
        tuple: ``(image_processor, model)`` from the transformers
        ``from_pretrained`` factories.
    """
    processor = AutoImageProcessor.from_pretrained(repository_id)
    classifier = AutoModelForImageClassification.from_pretrained(repository_id)
    return processor, classifier


@st.cache_data(show_spinner=False)
def _load_pokemon_info(csv_path):
    """Read the Pokemon information table stored at *csv_path*."""
    return pd.read_csv(csv_path)


# Loading the pokemon classifier model and its processor
image_processor, model = _load_classifier(model_repository_id)
# Loading the pokemon information table
pokemon_info_df = _load_pokemon_info('pokemon_info.csv')
# Small pokeball icon used on the Pokedex card; resize() returns a new image,
# so the source file can be closed right away.
with Image.open('pokeball.png') as _pokeball_src:
    pokeball_image = _pokeball_src.resize((20, 20))
#functions to predict image
def preprocess(processor: AutoImageProcessor, image):
    """Prepare a PIL image for the classifier.

    The image is converted to RGB mode and resized to 200x200 pixels, then
    passed to *processor* with ``return_tensors="pt"`` (PyTorch tensors).
    Whatever the processor returns is handed back unchanged.
    """
    rgb_image = image.convert("RGB")
    resized_image = rgb_image.resize((200, 200))
    return processor(resized_image, return_tensors="pt")
def predict(model: "AutoModelForImageClassification", inputs, k=5, *, threshold=0.001):
    """Classify preprocessed inputs, rejecting images that match no known class.

    Args:
        model: Image-classification model. It is called as ``model(**inputs)``;
            its output must expose ``.logits`` and the model must expose
            ``config.id2label``.
        inputs: Keyword arguments for the model (e.g. what ``preprocess``
            returns). Only the first image of the batch is classified.
        k: Currently unused; kept so existing positional callers still work.
        threshold: Minimum variance of the class-probability vector for the
            prediction to be trusted (default 0.001, the original cut-off).

    Returns:
        tuple: ``(label, probability)``. ``probability`` is the top-class
        probability in percent, rounded to 2 decimals. When the distribution
        is too flat, returns ``('not a pokemon', -1)``.
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits
    # Scores of the first image; the probabilities and the argmax both come
    # from this same row so they always agree (and batches > 1 don't crash).
    scores = logits[0]
    probabilities = torch.softmax(scores, dim=0).tolist()
    # Confidence check: when the model hesitates between many classes the
    # probabilities are nearly equal and their variance is close to 0; a
    # sharply peaked distribution has a much larger variance. Below the
    # threshold the image most likely shows none of the known classes.
    if np.var(probabilities) < threshold:
        return 'not a pokemon', -1
    predicted_id = int(scores.argmax())
    predicted_label = model.config.id2label[predicted_id]
    probability = round(probabilities[predicted_id] * 100, 2)
    return predicted_label, probability
# Designing the interface ------------------------------------------
# Use the full page instead of a narrow central column
st.set_page_config(layout="wide")
# Define the title
st.title("Gotta Classify 'Em All")
st.subheader("Image classifier for Pokemons from the 1st generation.")
# For newline
st.write('\n')
# Two columns: the narrow left one holds the image preview and the sample
# gallery, the wider right one is filled later with the Pokedex card.
col1, col2 = st.columns([1, 2])
with col1:
    image = Image.open('base.jpg')
    # Keep a handle on this element so an uploaded image can replace it.
    show = st.image(image, use_column_width=True)
    # Display Sample images ----
    st.subheader('Sample images')
    sample_imgs_dir = "sample_imgs/"
    # os.listdir() returns entries in arbitrary order: sort them so the
    # gallery is stable, and skip hidden files (e.g. .gitkeep) that are not
    # images.
    sample_imgs = sorted(
        name for name in os.listdir(sample_imgs_dir) if not name.startswith('.')
    )
    n_cols = 4
    # Lay the samples out in rows of n_cols images each.
    for row_start in range(0, len(sample_imgs), n_cols):
        cols = st.columns(n_cols)
        row_files = sample_imgs[row_start:row_start + n_cols]
        for col, image_file in zip(cols, row_files):
            col.image(os.path.join(sample_imgs_dir, image_file))
# Sidebar work and model outputs ---------------
st.sidebar.title("Upload Image")
# Choose your own image. The widget gets a real label for accessibility
# (recent Streamlit versions warn about an empty one); it is collapsed
# because the sidebar title above already describes the widget.
uploaded_file = st.sidebar.file_uploader(
    "Upload an image to classify",
    type=['png', 'jpg', 'jpeg'],
    accept_multiple_files=False,
    label_visibility="collapsed",
)
if uploaded_file is not None:
    u_img = Image.open(uploaded_file)
    # Replace the placeholder image in the left column with the upload.
    show.image(u_img, 'Uploaded Image', use_column_width=True)
    # Preprocess the image for the model
    model_inputs = preprocess(image_processor, u_img)
# For newline
st.sidebar.write('\n')

# Message shown when predict() rejects the image (probability == -1).
_NO_MATCH_MESSAGE = (
    "I am sorry I am having trouble finding a matching pokemon. <br>\n"
    "<b>Potential explanations: </b><br>\n"
    "- The image provided is a Pokemon but not from the 1st Generation. <br>\n"
    "- The image provided is not a Pokemon. <br>\n"
    "- There are too many entities on the image. <br>\n"
)


def _type_badge_html(type_name, color):
    """Return the HTML of one centred, coloured Pokemon-type badge."""
    return (
        '<div style="display: flex; justify-content: center; align-items: center;">'
        '<div style="display: inline-block; padding: 1%; margin: 0 1%; border-radius: 5px; '
        f'background-color: {color}; color: white; font-size: 1.4em;"><b>{type_name}</b></div>'
        '</div>'
    )


def _stat_html(label, value):
    """Return the HTML of one large "label: value" line of the Pokedex card."""
    return f'<div style="font-size: 1.8em;"><b>{label}:</b> {value}</div>'


def _render_pokedex_card(entry):
    """Draw the Pokedex card for one row of ``pokemon_info_df``.

    *entry* is the row's values, unpacked positionally in the column order
    below (taken from the original code; assumed to match pokemon_info.csv —
    TODO confirm if the CSV layout changes).
    """
    (_, pokedex_number, english_name, romaji_name, katakana_name,
     weight_kg, height_m, type1, type2, color1, color2,
     classification, evolve_from, evolve_into, is_legendary) = entry
    # pokedex box
    with st.container(border=True):
        # first row: pokeball icon, Pokedex number, English/katakana names
        with st.container():
            pokeball_image_col, pokedex_number_col, pokemon_name_col = st.columns([1, 1, 8])
            pokeball_image_col.image(pokeball_image)
            pokedex_number_col.markdown(
                f'<div style="text-align: left; white-space: nowrap; font-size: 1.8em;"><b>Pokedex n°{pokedex_number}</b></div>',
                unsafe_allow_html=True,
            )
            pokemon_name_col.markdown(
                f'<div style="text-align: right; font-size: 1.8em;"><b>{english_name} <br> {katakana_name}</b></div>',
                unsafe_allow_html=True,
            )
        # second row: classification, in the Pokemon's main colour
        with st.container():
            st.markdown(
                f'<div style="text-align: center; color: {color1}; font-size: 1.6em;"><b>{classification}</b></div>',
                unsafe_allow_html=True,
            )
        # 3rd row: one badge per type (second type may be missing/NaN)
        with st.container():
            if pd.isna(type2):
                st.write('\n')
                st.markdown(_type_badge_html(type1, color1), unsafe_allow_html=True)
            else:
                type1_col, type2_col = st.columns(2)
                type1_col.markdown(_type_badge_html(type1, color1), unsafe_allow_html=True)
                type2_col.markdown(_type_badge_html(type2, color2), unsafe_allow_html=True)
            st.write('\n')
        # 4th row: physical stats and evolution chain
        with st.container():
            st.markdown(_stat_html('Height', f'{height_m}m'), unsafe_allow_html=True)
            st.write('\n')
            st.markdown(_stat_html('Weight', f'{weight_kg}kg'), unsafe_allow_html=True)
            st.write('\n')
            if not pd.isna(evolve_from):
                st.markdown(_stat_html('Evolves from', evolve_from), unsafe_allow_html=True)
                st.write('\n')
            if not pd.isna(evolve_into):
                st.markdown(_stat_html('Evolves into', evolve_into), unsafe_allow_html=True)
                st.write('\n')


if st.sidebar.button("Click Here to Classify"):
    if uploaded_file is None:
        st.sidebar.write("Please upload an Image to Classify")
    else:
        with st.spinner('Classifying ...'):
            # Get prediction
            prediction, probability = predict(model, model_inputs, 5)
            # Fixed 2-second pause that does no work; presumably kept so the
            # spinner stays visible to the user.
            time.sleep(2)
            st.sidebar.success('Done!')
        st.sidebar.header("Model response: ")
        # Display prediction; predict() signals "no match" with probability -1.
        if probability == -1:
            st.sidebar.write(_NO_MATCH_MESSAGE, unsafe_allow_html=True)
        else:
            st.sidebar.write(f" It's a(n) <b>{prediction}</b> picture.", '\n', unsafe_allow_html=True)
            st.sidebar.write('<b>Probability:</b>', probability, '%', unsafe_allow_html=True)
            # Retrieve predicted pokemon information; guard against a label
            # that has no row in the table instead of crashing on [0].
            matches = pokemon_info_df[pokemon_info_df['name'] == prediction]
            if matches.empty:
                st.sidebar.warning(f"No Pokedex entry found for {prediction}.")
            else:
                with col2:
                    _render_pokedex_card(matches.values[0])
# Sidebar footer: project links and author contact, shown on every run.
_PROJECT_LINKS_MD = (
    "\n"
    "- Web App URL: [url](https://huggingface.co/spaces/Dusduo/GottaClassifyEmAll)\n"
    "- GitHub repository: [repository](https://github.com/A-Duss/GottaClassifyEmAll.git)\n"
)
_CONTACT_MD = (
    "\n"
    "Antoine Dussolle: [LinkedIn](https://www.linkedin.com/in/antoine-dussolle/) | [GitHub](https://github.com/A-Duss)\n"
)
sidebar = st.sidebar
sidebar.write('\n')
sidebar.info(_PROJECT_LINKS_MD)
sidebar.title("Contact")
sidebar.info(_CONTACT_MD)
|