arichar14 commited on
Commit
ad97e4d
·
verified ·
1 Parent(s): cde329f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -7,7 +7,7 @@ from PIL import Image
7
  model = tf.keras.models.load_model("my_model.keras")
8
 
9
  # Define class labels (update if your order is reversed)
10
- class_names = ["Dog", "Cat"]
11
 
12
  def predict(image):
13
  # Preprocess image (resize to your model's input size)
@@ -16,11 +16,12 @@ def predict(image):
16
  img = np.expand_dims(img, axis=0)
17
 
18
  # Run prediction
19
- predictions = model.predict(img)
20
- predicted_class = class_names[np.argmax(predictions)]
21
- confidence = float(np.max(predictions))
22
-
23
- return {predicted_class: confidence}
 
24
 
25
  # Build Gradio interface
26
  demo = gr.Interface(
 
7
  model = tf.keras.models.load_model("my_model.keras")
8
 
9
  # Define class labels (update if your order is reversed)
10
+ class_names = ["Cat", "Dog"]
11
 
12
  def predict(image):
13
  # Preprocess image (resize to your model's input size)
 
16
  img = np.expand_dims(img, axis=0)
17
 
18
  # Run prediction
19
+ pred = model.predict(img)[0][0] # get scalar value
20
+
21
+ if pred < 0.5:
22
+ return {"Cat": 1 - float(pred), "Dog": float(pred)}
23
+ else:
24
+ return {"Cat": 1 - float(pred), "Dog": float(pred)}
25
 
26
  # Build Gradio interface
27
  demo = gr.Interface(