Spaces:
Sleeping
Sleeping
import tensorflow as tf | |
from PIL import Image | |
import numpy as np | |
from fastapi import FastAPI | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import FileResponse | |
import io | |
import base64 | |
import uvicorn | |
from fastapi.middleware.cors import CORSMiddleware | |
import pydantic | |
app = FastAPI()

# CORS: accept cross-origin requests from any origin, using any method and any
# request header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Serve frontend assets from the ./static directory (relative to the working
# directory) under the /static URL prefix.
app.mount(
    path="/static",
    app=StaticFiles(directory="static"),
    name="static",
)
@app.get("/")
def read_root():
    """Serve the single-page frontend at the site root.

    Returns ``static/index.html`` as a file response. The path is relative to
    the process working directory.
    """
    return FileResponse("static/index.html")
# Load the trained Keras classifier once, at import time; predict() reuses
# this module-level instance for every call.
model = tf.keras.models.load_model(filepath="model/k49_model.h5")
# Request schema for a prediction call: an encoded image string plus the
# expected class index.
# NOTE(review): predict() below takes a plain dict and does not reference this
# model -- confirm whether it is meant to be wired in as the request type.
class ImageRequest(pydantic.BaseModel):
    image: str  # encoded image; presumably a canvas data URL -- confirm
    label: int  # expected class index; presumably compared to the model output
# NOTE(review): the route path is an assumption -- confirm it matches the URL
# the frontend (static/index.html) posts to.
@app.post("/predict")
def predict(request: dict):
    """Classify a drawn image and report whether it matches the expected label.

    Args:
        request: Parsed JSON body with keys ``"image"`` and ``"label"``.
            ``"image"`` may be a base64 string, optionally prefixed with a
            data-URL header such as ``"data:image/png;base64,"``. It may also
            be a dict holding that string under its own ``"image"`` key, which
            is the nested shape the original code required. ``"label"`` is the
            expected class index (anything ``int()`` accepts).

    Returns:
        ``{"prediction": bool}``, True when the model's argmax class equals
        ``int(label)``.

    Side effects:
        Writes the preprocessed 28x28 model input to ``image.png`` in the
        working directory and prints the label and the prediction to stdout.
    """
    image = request["image"]
    label = request["label"]

    # Accept the nested {"image": {"image": "..."}} shape the original code
    # assumed, and also a plain string (as declared by ImageRequest.image).
    # Indexing a str with "image" would raise TypeError.
    if isinstance(image, dict):
        image = image["image"]

    # Drop an optional data-URL header. split(",", 1)[-1] also works for a bare
    # base64 string, where split(",")[1] would raise IndexError.
    image_bytes = base64.b64decode(image.split(",", 1)[-1])

    X = preprocess(image_bytes)

    # Save what the model actually sees, as a debugging aid. Cast to uint8 so
    # PIL receives a standard 8-bit grayscale array instead of float64.
    debug_pixels = (X[0].reshape(28, 28) * 255).round().astype(np.uint8)
    Image.fromarray(debug_pixels).convert("RGB").save("image.png")

    scores = model.predict(X)
    prediction = int(np.argmax(scores))

    print("label: ", label)
    print("prediction: ", prediction)
    return {'prediction': prediction == int(label)}
def preprocess(image_bytes):
    """Convert encoded image bytes into a normalized single-image model batch.

    Steps:
    1. Flatten any transparency onto a white background.
    2. Convert to grayscale.
    3. Invert the pixel values. This assumes dark strokes on a light or
       transparent canvas, which become light strokes on black.
    4. Resize to 28x28.
    5. Scale to [0, 1].

    Args:
        image_bytes: Raw encoded image file bytes (e.g. PNG). These are already
            base64-decoded by the caller.

    Returns:
        numpy array of shape (1, 28, 28, 1) with float values in [0, 1].
    """
    with Image.open(io.BytesIO(image_bytes)) as img:
        # Normalize every transparent format to RGBA so the alpha channel is
        # honored below. Without this, convert('L') turns the transparent area
        # of LA/PA images, and of P/L/RGB images with a transparency key, into
        # black. After inversion that black area would become white.
        if img.mode in ('LA', 'PA') or (
            img.mode != 'RGBA' and 'transparency' in img.info
        ):
            img = img.convert('RGBA')

        # Composite the image onto white, using its alpha as the paste mask.
        if img.mode == 'RGBA':
            background = Image.new('RGB', img.size, (255, 255, 255))
            background.paste(img, mask=img.getchannel('A'))
            img = background

        # convert() returns a fully loaded, independent image, so it stays
        # usable after the source file handle is closed.
        gray = img.convert('L')

    # uint8 inversion (255 - v): light background -> black, dark strokes -> bright.
    inverted = np.invert(np.array(gray))

    resized = Image.fromarray(inverted).resize((28, 28))

    # Add batch and channel axes, then normalize to [0, 1].
    return np.array(resized).reshape(1, 28, 28, 1) / 255
if __name__ == '__main__':
    # Start the development server.
    # reload=True requires an import string ("module:attribute") rather than
    # the app object. This one assumes the file is importable as module `app`
    # -- confirm.
    uvicorn.run("app:app", reload=True, port=8000, host="localhost")