kdevoe commited on
Commit
654e5aa
·
verified ·
1 Parent(s): 0b3395a

Adding bar chart for predicted probabilites

Browse files
Files changed (1) hide show
  1. app.py +18 -8
app.py CHANGED
@@ -1,4 +1,5 @@
1
  import gradio as gr
 
2
  import numpy as np
3
  import cv2
4
  import tensorflow as tf
@@ -45,14 +46,22 @@ def preprocess_frame(frame):
45
 
46
 
47
  def predict_asl(frame):
48
- """Predict the ASL sign from the webcam frame."""
49
- # Preprocess the frame
50
  processed_frame = preprocess_frame(frame)
51
- # Make a prediction
52
- predictions = model.predict(processed_frame)
53
- # Get the class with the highest probability
54
- predicted_label = labels[np.argmax(predictions)]
55
- return predicted_label
 
 
 
 
 
 
 
 
 
56
 
57
  css = """.my-group {max-width: 500px !important; max-height: 500px !important;}
58
  .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
@@ -62,7 +71,8 @@ with gr.Blocks(css=css) as demo:
62
  with gr.Group(elem_classes=["my-group"]):
63
  input_img = gr.Image(sources=["webcam"], type="numpy", streaming=True, label="Webcam Input")
64
  output_label = gr.Label(label="Predicted ASL Sign")
 
65
 
66
- input_img.stream(predict_asl, [input_img], [output_label], time_limit=300, stream_every=0.1)
67
 
68
  demo.launch()
 
1
  import gradio as gr
2
+ import matplotlib.pyplot as plt
3
  import numpy as np
4
  import cv2
5
  import tensorflow as tf
 
46
 
47
 
48
  def predict_asl(frame):
49
+ """Predict the ASL sign and return the label and probabilities."""
 
50
  processed_frame = preprocess_frame(frame)
51
+ predictions = model.predict(processed_frame) # Predict probabilities
52
+ predicted_label = labels[np.argmax(predictions)] # Get the class with the highest probability
53
+
54
+ # Generate a bar chart for probabilities
55
+ fig, ax = plt.subplots(figsize=(6, 4))
56
+ ax.bar(labels, predictions[0])
57
+ ax.set_title("Class Probabilities")
58
+ ax.set_ylabel("Probability")
59
+ ax.set_xlabel("ASL Classes")
60
+ ax.set_xticks(range(len(labels)))
61
+ ax.set_xticklabels(labels, rotation=45)
62
+ plt.tight_layout()
63
+
64
+ return predicted_label, fig
65
 
66
  css = """.my-group {max-width: 500px !important; max-height: 500px !important;}
67
  .my-column {display: flex !important; justify-content: center !important; align-items: center !important};"""
 
71
  with gr.Group(elem_classes=["my-group"]):
72
  input_img = gr.Image(sources=["webcam"], type="numpy", streaming=True, label="Webcam Input")
73
  output_label = gr.Label(label="Predicted ASL Sign")
74
+ output_plot = gr.Plot(label="Class Probabilities")
75
 
76
+ input_img.stream(predict_asl, [input_img], [output_label], time_limit=300, stream_every=0.5)
77
 
78
  demo.launch()