import base64
import io

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18

# ---------------- CONFIG ----------------
# Class names, in the order of the model's output logits.
labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
# NOTE(review): not referenced anywhere in this file.
theme_color = "#6C5B7B"


# ---------------- MODEL ----------------
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small MLP head with len(labels) outputs.

    The layer layout must match the checkpoint exactly: ``nn.Sequential``
    state-dict keys are derived from layer indices, so inserting or removing
    a layer (even a parameter-free one such as ReLU) would break
    ``load_state_dict``.
    """

    def __init__(self):
        super().__init__()
        # No pretrained weights here; everything is loaded from the checkpoint.
        self.cnn_layers = resnet18(weights=None)
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),  # resnet18's default head emits 1000 features
            # NOTE(review): there is no activation between the first two Linear
            # layers. Kept as-is because the checkpoint presumably was trained
            # with this exact layout.
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, len(labels)),
        )

    def forward(self, x):
        x = self.cnn_layers(x)
        x = self.fc_layers(x)
        return x


# Resize to 224x224 and normalise with the ImageNet channel mean/std.
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

model = Classifier()
# NOTE(review): torch.load unpickles the file. Only load trusted checkpoints
# (recent torch versions also accept weights_only=True).
model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
model.eval()


# ---------------- FUNCTIONS ----------------
def _decode_base64_image(data):
    """Decode a base64 string (optionally a ``data:image/...;base64,`` URL) to an RGB PIL image."""
    # Strip surrounding whitespace so a padded data URL is still recognised.
    data = data.strip()
    if data.startswith("data:image"):
        # Keep only the payload after the "data:image/...;base64," header.
        data = data.split(",", 1)[1]
    return Image.open(io.BytesIO(base64.b64decode(data))).convert("RGB")


def predict(image_input):
    """Classify one image and return ``(top_label, {label: probability})``.

    Accepts:
      - a PIL Image (web UI upload)
      - a base64 string, optionally a data URL (API)

    Errors are reported rather than raised: on failure the first element is
    ``"Error: <message>"`` and the probability dict is empty, so the UI shows
    the message instead of crashing.
    """
    if image_input is None or (isinstance(image_input, str) and not image_input.strip()):
        return "Error: no image provided", {}
    try:
        if isinstance(image_input, str):
            img = _decode_base64_image(image_input)
        else:
            img = image_input.convert("RGB")

        img_tensor = preprocess(img).unsqueeze(0)  # add batch dimension -> (1, 3, 224, 224)
        with torch.no_grad():
            logits = model(img_tensor)
        probs = torch.nn.functional.softmax(logits[0], dim=0)

        probs_dict = {label: float(p) for label, p in zip(labels, probs)}
        max_label = max(probs_dict, key=probs_dict.get)
        return max_label, probs_dict
    except Exception as e:
        # Deliberate best-effort boundary: report the error to the caller/UI.
        return f"Error: {e}", {}


def predict_ui(image, base64_text):
    """Web-UI handler: classify the uploaded image, or the pasted base64 text if no image is set.

    A separate function (instead of reusing ``predict``) also keeps the
    endpoint name that some Gradio versions derive from the handler's name
    from colliding with the explicit ``api_name="predict"`` below.
    """
    if image is not None:
        return predict(image)
    return predict(base64_text)


def clear_all():
    """Reset the image, the base64 text, the predicted label and the probability display.

    ``None`` is Gradio's empty value for Image and Label components. A string
    would be interpreted as an image file path.
    """
    return None, "", "", None


# ---------------- INTERFACE ----------------
with gr.Blocks(title="NSFW Image Classifier") as demo:
    gr.HTML("""

🎨 NSFW Image Classifier

Carica un'immagine o incolla la stringa base64 per analizzarla.

""")
    with gr.Row():
        with gr.Column(scale=2):
            # Inputs: an uploaded image, or a pasted base64 string.
            img_input = gr.Image(label="📷 Carica immagine", type="pil")
            base64_input = gr.Textbox(
                label="📤 Base64 dell'immagine (API)",
                lines=6,
                placeholder="Incolla qui la stringa base64...",
            )
            with gr.Row():
                submit_btn = gr.Button("✨ Analizza", variant="primary")
                clear_btn = gr.Button("🔄 Pulisci", variant="secondary")
        with gr.Column(scale=1):
            label_output = gr.Textbox(label="Classe predetta", interactive=False)
            result_display = gr.Label(
                label="Distribuzione probabilità",
                num_top_classes=len(labels),
            )

    # ---------------- UI events ----------------
    # The UI button uses the image if present, otherwise the pasted base64 text.
    submit_btn.click(
        fn=predict_ui,
        inputs=[img_input, base64_input],
        outputs=[label_output, result_display],
    )
    clear_btn.click(
        fn=clear_all,
        inputs=None,
        outputs=[img_input, base64_input, label_output, result_display],
    )

    # ---------------- Hidden button for the base64 API ----------------
    # Never shown in the UI; it only registers an API endpoint named "predict"
    # whose single input is the base64 string.
    api_button = gr.Button(visible=False)
    api_button.click(
        fn=predict,
        inputs=[base64_input],
        outputs=[label_output, result_display],
        api_name="predict",  # URL prefix depends on the Gradio version (e.g. /run/predict)
    )

# ---------------- LAUNCH ----------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)