import gradio as gr import torch import torch.nn as nn from torchvision import transforms from torchvision.models import resnet18 from PIL import Image import base64 import io # ---------------- CONFIG ---------------- labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"] theme_color = "#6C5B7B" # ---------------- MODEL ---------------- class Classifier(nn.Module): def __init__(self): super().__init__() self.cnn_layers = resnet18(weights=None) self.fc_layers = nn.Sequential( nn.Linear(1000, 512), nn.Dropout(0.3), nn.Linear(512, 128), nn.ReLU(), nn.Linear(128, len(labels)) ) def forward(self, x): x = self.cnn_layers(x) x = self.fc_layers(x) return x preprocess = transforms.Compose([ transforms.Resize((224,224)), transforms.ToTensor(), transforms.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]) ]) model = Classifier() model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu")) model.eval() # ---------------- FUNZIONE ---------------- def predict(image_input): """ Supporta: - PIL Image (UI web) - stringa base64 (API) """ try: if isinstance(image_input, str): if image_input.startswith("data:image"): image_input = image_input.split(",",1)[1] img_bytes = base64.b64decode(image_input) img = Image.open(io.BytesIO(img_bytes)).convert("RGB") else: img = image_input.convert("RGB") img_tensor = preprocess(img).unsqueeze(0) with torch.no_grad(): logits = model(img_tensor) probs = torch.nn.functional.softmax(logits[0], dim=0) probs_dict = {labels[i]: float(probs[i]) for i in range(len(labels))} max_label = max(probs_dict, key=probs_dict.get) return max_label, probs_dict except Exception as e: return f"Error: {str(e)}", {} def clear_all(): return "", "" # ---------------- INTERFACCIA ---------------- with gr.Blocks(title="NSFW Image Classifier") as demo: gr.HTML(f"""
Carica un'immagine o incolla la stringa base64 per analizzarla.