|
import tensorflow as tf |
|
import numpy as np |
|
from urllib.request import urlretrieve |
|
import gradio as gr |
|
import numpy as np |
|
|
|
# Fetch the saved Keras model from the Hugging Face Hub into the working
# directory, then load it once at module level so that every prediction
# reuses the same in-memory model.
_MODEL_FILE = "mnist_model.keras"
_MODEL_URL = "https://huggingface.co/guiwitz/mnist2023/resolve/main/mnist_model.keras"

urlretrieve(_MODEL_URL, _MODEL_FILE)
model = tf.keras.models.load_model(_MODEL_FILE)
|
|
|
def recognize_digit(image):
    """Classify a sketched digit and return a score for each class.

    Args:
        image: 2-D array delivered by the sketchpad input, or ``None`` when
            there is nothing drawn. NOTE(review): presumably a grayscale
            28x28 array in the value range the model was trained on --
            confirm against the Gradio version and the model's training
            pipeline.

    Returns:
        A dict mapping each class index (as a string, ``"0"``, ``"1"``, ...)
        to the model's score for that class, or ``None`` when no image was
        supplied, which leaves the label output empty.
    """
    if image is None:
        # With live=True the callback can fire for an empty/cleared canvas;
        # indexing None below would raise TypeError, so bail out early.
        return None

    # The model is fed a 4-D batch: add a leading batch axis and a trailing
    # channel axis, i.e. (H, W) -> (1, H, W, 1).
    batch = image[np.newaxis, :, :, np.newaxis]

    # predict() returns one row per batch item; there is exactly one here.
    scores = model.predict(batch).tolist()[0]

    return {str(digit): score for digit, score in enumerate(scores)}
|
|
|
# Assemble the demo UI: a sketchpad feeds recognize_digit and the result is
# rendered as a label; live=True re-runs the function as the input changes
# instead of waiting for a submit click.
demo = gr.Interface(
    fn=recognize_digit,
    inputs="sketchpad",
    outputs=gr.Label(num_top_classes=3),
    live=True,
    description="Live MNIST.",
)
demo.launch()