AfshinMA commited on
Commit
243dba7
·
verified ·
1 Parent(s): 5c339d4

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +124 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,124 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import required libraries
2
+ import os
3
+ import keras
4
+ import numpy as np
5
+ import pandas as pd
6
+ import streamlit as st
7
+ from PIL import Image
8
+
9
+ # Function to safely load the models
10
+ def load_model_safely(path: str):
11
+ if not os.path.isfile(path) or not path.endswith('.keras'):
12
+ raise FileNotFoundError(f"The file '{path}' does not exist or is not a .keras file.")
13
+ return keras.saving.load_model(path)
14
+
15
+ # Retrieve the current directory and specify model paths
16
+ current_dir = os.getcwd() # Ensure correct initial directory
17
+ model_paths = {
18
+ 'CNN': os.path.join(current_dir, 'models', 'cnn_model.keras'),
19
+ 'VGG19': os.path.join(current_dir, 'models', 'vgg19_model.keras'),
20
+ 'ResNet50': os.path.join(current_dir, 'models', 'resnet50_model.keras'),
21
+ }
22
+
23
+ # Load models and handle potential exceptions
24
+ models = {}
25
+ for name, path in model_paths.items():
26
+ try:
27
+ models[name] = load_model_safely(path)
28
+ except Exception as e:
29
+ st.error(f"Error loading model {name} from {path}: {str(e)}")
30
+
31
+ # Define the class labels
32
+ classes = { 0:'Speed limit (20km/h)', 1:'Speed limit (30km/h)', 2:'Speed limit (50km/h)',
33
+ 3:'Speed limit (60km/h)', 4:'Speed limit (70km/h)', 5:'Speed limit (80km/h)',
34
+ 6:'End of speed limit (80km/h)', 7:'Speed limit (100km/h)', 8:'Speed limit (120km/h)',
35
+ 9:'No passing', 10:'No passing veh over 3.5 tons', 11:'Right-of-way at intersection',
36
+ 12:'Priority road', 13:'Yield', 14:'Stop', 15:'No vehicles',
37
+ 16:'Veh > 3.5 tons prohibited', 17:'No entry', 18:'General caution',
38
+ 19:'Dangerous curve left', 20:'Dangerous curve right', 21:'Double curve',
39
+ 22:'Bumpy road', 23:'Slippery road', 24:'Road narrows on the right',
40
+ 25:'Road work', 26:'Traffic signals', 27:'Pedestrians', 28:'Children crossing',
41
+ 29:'Bicycles crossing', 30:'Beware of ice/snow', 31:'Wild animals crossing',
42
+ 32:'End speed + passing limits', 33:'Turn right ahead', 34:'Turn left ahead',
43
+ 35:'Ahead only', 36:'Go straight or right', 37:'Go straight or left',
44
+ 38:'Keep right', 39:'Keep left', 40:'Roundabout mandatory',
45
+ 41:'End of no passing', 42:'End no passing veh > 3.5 tons' }
46
+
47
+ # Function to preprocess the image and predict the class
48
+ def preprocess_and_predict(image: Image.Image, size=(50, 50)) -> pd.DataFrame:
49
+ img_resized = image.resize(size)
50
+ img_array = np.array(img_resized).astype(np.float32) / 255.0
51
+ img_array = np.expand_dims(img_array, axis=0) # Shape (1, 50, 50, 3)
52
+
53
+ predictions = []
54
+ for name, model in models.items():
55
+ predicted_class_index = np.argmax(model.predict(img_array), axis=-1)[0]
56
+ predictions.append({'Model': name, 'Predicted Label': classes[predicted_class_index]})
57
+
58
+ return pd.DataFrame(predictions)
59
+
60
+ # Import Example images
61
+ images_dir = os.path.join(current_dir, 'images')
62
+
63
+ if os.path.exists(images_dir):
64
+ # Create a list of images and their corresponding classes
65
+ image_list = [img for img in os.listdir(images_dir) if img.lower().endswith('.png')]
66
+ image_dict = {classes[int(img.split('.')[0])] : os.path.join(images_dir, img) for img in image_list}
67
+ else:
68
+ st.error(f"The images directory does not exist: {images_dir}")
69
+
70
+ # Streamlit UI setup
71
+ st.set_page_config(page_title="Traffic Sign Detection App", page_icon="🚦", layout="wide")
72
+ st.title("🚦 Traffic Sign Recognition using CNN, VGG19, ResNet50")
73
+ st.markdown("Upload a traffic sign image or choose an example from below to get the recognition result.")
74
+ st.markdown("---")
75
+
76
+ # Sidebar for image upload and selection
77
+ st.sidebar.header("Input Options")
78
+ uploaded_file = st.sidebar.file_uploader("Upload an image (JPG, JPEG, PNG)", type=["jpg", "jpeg", "png"])
79
+
80
+ # Select an example image
81
+ selected_example = st.sidebar.selectbox("Or select an example image:", list(image_dict.keys()))
82
+ if selected_example:
83
+ example_image_path = image_dict[selected_example]
84
+
85
+ # Initialize a variable to hold the image for prediction
86
+ image_to_predict = None
87
+
88
+ # Check if user uploaded an image or selected an example image
89
+ if uploaded_file is not None:
90
+ image_to_predict = Image.open(uploaded_file)
91
+ st.image(image_to_predict.resize((256, 256)), caption='Uploaded Image', use_container_width=False, output_format="auto")
92
+ elif selected_example:
93
+ image_to_predict = Image.open(example_image_path)
94
+ st.image(image_to_predict.resize((256, 256)), caption='Example Image', use_container_width=False, output_format="auto")
95
+
96
+ # Add a predict button
97
+ if st.sidebar.button("🚀 Predict", key="predict_button") and image_to_predict is not None:
98
+ # Run prediction
99
+ st.write("Predicting ...")
100
+ results = preprocess_and_predict(image_to_predict)
101
+
102
+ # Display results
103
+ st.write("### Prediction Results")
104
+
105
+ # Style the output dataframe
106
+ st.dataframe(results)
107
+
108
+ # Add some custom CSS for better styling
109
+ st.markdown("""
110
+ <style>
111
+ .stButton > button:hover {
112
+ background-color: #0052cc; /* Darker blue on hover */
113
+ }
114
+ .stDataframe {
115
+ border: 1px solid #ddd; /* Light border for clarity */
116
+ border-radius: 10px; /* Rounded corners for the dataframe */
117
+ }
118
+ .stImage {
119
+ border: 2px solid #0066ff; /* Border for images */
120
+ border-radius: 10px; /* Rounded corners */
121
+ box-shadow: 0 0 8px rgba(0, 0, 0, 0.2); /* Subtle shadow */
122
+ }
123
+ </style>
124
+ """, unsafe_allow_html=True)
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pandas
2
+ numpy
3
+ pillow
4
+ scikit-learn
5
+ keras