import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.models import resnet18
from PIL import Image
import base64
import io
# ---------------- CONFIG ----------------
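# Class names in the index order produced by the classifier head
# (assumed to match the label order used when the checkpoint was trained).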
labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
theme_color = "#6C5B7B"
# ---------------- MODEL ----------------
class Classifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.cnn_layers = resnet18(weights=None)
        self.fc_layers = nn.Sequential(
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, len(labels))
        )

    def forward(self, x):
        x = self.cnn_layers(x)
        x = self.fc_layers(x)
        return x
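
# Preprocessing: resize to the 224x224 input expected by resnet18 and apply
# the standard ImageNet mean/std normalization.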
preprocess = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
model = Classifier()
model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
model.eval()
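# Inference-only setup: eval() disables the Dropout in fc_layers so repeated
# predictions on the same image are deterministic.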
# ---------------- FUNCTION ----------------
def predict(image_input):
    """
    Supports:
    - PIL Image (web UI)
    - base64 string (API)
    """
    try:
        if isinstance(image_input, str):
            if image_input.startswith("data:image"):
                image_input = image_input.split(",", 1)[1]
            img_bytes = base64.b64decode(image_input)
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        else:
            img = image_input.convert("RGB")
        img_tensor = preprocess(img).unsqueeze(0)
        with torch.no_grad():
            logits = model(img_tensor)
            probs = torch.nn.functional.softmax(logits[0], dim=0)
        probs_dict = {labels[i]: float(probs[i]) for i in range(len(labels))}
        max_label = max(probs_dict, key=probs_dict.get)
        return max_label, probs_dict
    except Exception as e:
        return f"Error: {str(e)}", {}
def clear_all():
    # Return None to clear the image component and an empty string for the textbox.
    return None, ""
# ---------------- INTERFACE ----------------
with gr.Blocks(title="NSFW Image Classifier") as demo:
    gr.HTML(f"""
    <div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
        <h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
        <p>Upload an image or paste its base64 string to analyze it.</p>
    </div>
    """)
    with gr.Row():
        with gr.Column(scale=2):
            # UI inputs
            img_input = gr.Image(label="📷 Upload image", type="pil")
            base64_input = gr.Textbox(
                label="📤 Image base64 (API)",
                lines=6,
                placeholder="Paste the base64 string here..."
            )
            with gr.Row():
                submit_btn = gr.Button("✨ Analyze", variant="primary")
                clear_btn = gr.Button("🔄 Clear", variant="secondary")
        with gr.Column(scale=1):
            label_output = gr.Textbox(label="Predicted class", interactive=False)
            result_display = gr.Label(label="Probability distribution", num_top_classes=len(labels))

    # ---------------- UI events ----------------
    submit_btn.click(
        fn=predict,
        inputs=[img_input],
        outputs=[label_output, result_display]
    )
    clear_btn.click(fn=clear_all, inputs=None, outputs=[img_input, base64_input])

    # ---------------- Hidden button for the base64 API ----------------
    api_button = gr.Button(visible=False)
    api_button.click(
        fn=predict,
        inputs=[base64_input],
        outputs=[label_output, result_display],
        api_name="predict"  # exposes /run/predict
    )
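# Client-side sketch for the endpoint exposed above (assumptions: the URL is a
# placeholder and gradio_client is installed); kept as a comment:
#
#     from gradio_client import Client
#     client = Client("http://localhost:7860")
#     label, probs = client.predict(b64_string, api_name="/predict")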
# ---------------- LAUNCH ----------------
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)