Spaces:
Runtime error
Runtime error
from PIL import Image | |
import io | |
import requests | |
import os | |
from dotenv import load_dotenv | |
import gradio as gr | |
# Tải biến môi trường từ tệp .env | |
load_dotenv() | |
# Định nghĩa API URL và headers | |
API_URL = "https://api-inference.huggingface.co/models/vieanh/vit-sports-cls" | |
headers = {"Authorization": f"Bearer {os.getenv('HF_API_TOKEN')}"} | |
def predict(image): | |
""" | |
Hàm xử lý ảnh đầu vào và gửi đến API để dự đoán thể loại thể thao. | |
Parameters: | |
image: Ảnh được upload (PIL Image hoặc bytes). | |
Returns: | |
dict: Nhãn dự đoán và xác suất tương ứng. | |
""" | |
try: | |
# Chuyển đổi ảnh về định dạng PIL nếu cần thiết | |
if isinstance(image, bytes): | |
pil_image = Image.open(io.BytesIO(image)).convert("RGB") | |
elif isinstance(image, Image.Image): | |
pil_image = image.convert("RGB") | |
elif isinstance(image, str): | |
pil_image = Image.open(image).convert("RGB") | |
else: | |
return {"error": "Ảnh không hợp lệ. Vui lòng upload ảnh đúng định dạng."} | |
# Lưu ảnh vào buffer dưới định dạng JPEG | |
buffered = io.BytesIO() | |
pil_image.save(buffered, format="JPEG") | |
img_bytes = buffered.getvalue() | |
# Gửi request đến API Hugging Face | |
response = requests.post(API_URL, headers=headers, data=img_bytes) | |
response.raise_for_status() # Kiểm tra nếu có lỗi HTTP | |
# Xử lý kết quả trả về từ API | |
predictions = response.json() | |
highest_prediction = max(predictions, key=lambda x: x["score"]) # Tìm nhãn có xác suất cao nhất | |
label = highest_prediction["label"] | |
score = highest_prediction["score"] | |
# Trả về nhãn và xác suất | |
return {label: score} | |
except requests.exceptions.RequestException as e: | |
# Trường hợp xảy ra lỗi khi gửi request | |
return {"error": str(e)} | |
except Exception as e: | |
# Trường hợp lỗi không xác định | |
return {"error": str(e)} | |
# Tạo giao diện Gradio | |
interface = gr.Interface( | |
fn=predict, # Hàm xử lý dự đoán | |
inputs=gr.Image(type="pil", label="Upload an Image"), # Đầu vào: Ảnh | |
outputs=gr.Label(label="Predicted Sport"), # Đầu ra: Nhãn và xác suất | |
title="Sports Image Classification", # Tiêu đề giao diện | |
description="Upload an image of a sport and get the predicted category with the highest score." # Mô tả giao diện | |
) | |
# Khởi chạy giao diện | |
interface.launch() |