vit-sports-cls / app.py
vieanh's picture
last fix
061d60e
raw
history blame contribute delete
2.64 kB
from PIL import Image
import io
import requests
import os
from dotenv import load_dotenv
import gradio as gr
# Tải biến môi trường từ tệp .env
load_dotenv()
# Định nghĩa API URL và headers
API_URL = "https://api-inference.huggingface.co/models/vieanh/vit-sports-cls"
headers = {"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}"}
def predict(image):
"""
Hàm xử lý ảnh đầu vào và gửi đến API để dự đoán thể loại thể thao.
Parameters:
image: Ảnh được upload (PIL Image hoặc bytes).
Returns:
dict: Nhãn dự đoán và xác suất tương ứng.
"""
try:
# Chuyển đổi ảnh về định dạng PIL nếu cần thiết
if isinstance(image, bytes):
pil_image = Image.open(io.BytesIO(image)).convert("RGB")
elif isinstance(image, Image.Image):
pil_image = image.convert("RGB")
elif isinstance(image, str):
pil_image = Image.open(image).convert("RGB")
else:
return {"error": "Ảnh không hợp lệ. Vui lòng upload ảnh đúng định dạng."}
# Lưu ảnh vào buffer dưới định dạng JPEG
buffered = io.BytesIO()
pil_image.save(buffered, format="JPEG")
img_bytes = buffered.getvalue()
# Gửi request đến API Hugging Face
response = requests.post(API_URL, headers=headers, data=img_bytes)
response.raise_for_status() # Kiểm tra nếu có lỗi HTTP
# Xử lý kết quả trả về từ API
predictions = response.json()
highest_prediction = max(predictions, key=lambda x: x["score"]) # Tìm nhãn có xác suất cao nhất
label = highest_prediction["label"]
score = highest_prediction["score"]
# Trả về nhãn và xác suất
return {label: score}
except requests.exceptions.RequestException as e:
# Trường hợp xảy ra lỗi khi gửi request
return {"error": str(e)}
except Exception as e:
# Trường hợp lỗi không xác định
return {"error": str(e)}
# Tạo giao diện Gradio
interface = gr.Interface(
fn=predict, # Hàm xử lý dự đoán
inputs=gr.Image(type="pil", label="Upload an Image"), # Đầu vào: Ảnh
outputs=gr.Label(label="Predicted Sport"), # Đầu ra: Nhãn và xác suất
title="Sports Image Classification", # Tiêu đề giao diện
description="Upload an image of a sport and get the predicted category with the highest score." # Mô tả giao diện
)
# Khởi chạy giao diện
interface.launch()