File size: 3,905 Bytes
1fad6d0
5cc8910
 
 
 
d35e71b
5cc8910
 
 
 
 
a04766c
5cc8910
a42bcbf
5cc8910
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a04766c
5cc8910
 
a04766c
5cc8910
 
 
 
 
e34292b
4fc8fe7
a04766c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82ea080
4fc8fe7
a04766c
 
 
 
 
 
 
 
a7b98c5
c9424af
5cc8910
a04766c
a7b98c5
a04766c
 
 
 
 
5cc8910
a04766c
 
a7b98c5
5cc8910
a04766c
5cc8910
c9424af
a04766c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5cc8910
a04766c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import base64
import io
import logging

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms
from torchvision.models import resnet18

# ---------------- CONFIG ----------------
# Class names, in output-unit order: the classifier head is sized with
# len(labels), and position i of this list names logit/probability i.
labels = [
    "Drawings",
    "Hentai",
    "Neutral",
    "Porn",
    "Sexy",
]
# Accent colour used in the page header.
theme_color = "#6C5B7B"

# ---------------- MODEL ----------------
class Classifier(nn.Module):
    """ResNet-18 backbone followed by a small fully connected head.

    The head maps the backbone's 1000-dim output to ``len(labels)`` logits.
    The attribute names (``cnn_layers``, ``fc_layers``) and the position of
    each layer inside ``fc_layers`` determine the state_dict keys, so they
    must stay exactly as they are for a saved checkpoint to load.
    """

    def __init__(self):
        super().__init__()
        # Backbone built without pretrained weights (weights=None).
        self.cnn_layers = resnet18(weights=None)
        n_classes = len(labels)
        # NOTE(review): there is no activation between the first Linear and
        # the Dropout. This presumably matches how the checkpoint was trained;
        # inserting one would shift the Sequential indices (and state_dict
        # keys), so do not change it without retraining.
        head = [
            nn.Linear(1000, 512),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, n_classes),
        ]
        self.fc_layers = nn.Sequential(*head)

    def forward(self, x):
        """Run the backbone, then the head, and return raw (unnormalised) logits.

        Presumably ``x`` is an NCHW image batch such as the module's
        preprocessing produces (3x224x224 per image) — TODO confirm.
        """
        features = self.cnn_layers(x)
        return self.fc_layers(features)

# Input pipeline: squash to a fixed 224x224 (aspect ratio is not preserved),
# convert the PIL image to a float tensor in [0, 1], then normalise each
# channel with the standard ImageNet mean/std.
# NOTE(review): presumably mirrors the checkpoint's training transforms —
# confirm before changing any of these values.
preprocess = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Build the network and restore the trained weights onto the CPU.
# NOTE(review): torch.load unpickles the file; only load checkpoints from a
# trusted source.
model = Classifier()
model.load_state_dict(
    torch.load("classify_nsfw_v3.0.pth", map_location="cpu")
)
# Switch to inference mode (equivalent to model.eval()) so Dropout is disabled.
model.train(False)

# ---------------- FUNCTION ----------------
def predict(image_input):
    """Classify an image and return the top label plus per-class probabilities.

    Accepts either:
    - a PIL Image (from the web UI image component), or
    - a base64 string (from the API / textbox), optionally prefixed with a
      ``data:<mime>;base64,`` URI header. Surrounding whitespace is ignored.

    Returns:
        ``(label, probabilities)``. ``label`` is the entry of ``labels`` with
        the highest softmax score. ``probabilities`` maps every label to its
        score as a float. On any failure this returns
        ``("Error: <message>", {})`` instead of raising, so the UI can show
        the problem in the output box.
    """
    # Missing input (no upload / empty textbox) gets a clear message instead
    # of surfacing as an AttributeError on None or a PIL decode error.
    if image_input is None or (
        isinstance(image_input, str) and not image_input.strip()
    ):
        return "Error: no image provided", {}

    try:
        if isinstance(image_input, str):
            payload = image_input.strip()
            if payload.startswith("data:"):
                # Drop the "data:<mime>;base64," header; the payload follows
                # the first comma.
                _, sep, payload = payload.partition(",")
                if not sep:
                    return (
                        "Error: malformed data URI "
                        "(missing ',' before the base64 payload)"
                    ), {}
            img_bytes = base64.b64decode(payload)
            img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        else:
            img = image_input.convert("RGB")

        # Add a leading batch dimension for the model.
        img_tensor = preprocess(img).unsqueeze(0)

        with torch.no_grad():
            logits = model(img_tensor)
            probs = torch.nn.functional.softmax(logits[0], dim=0)

        probs_dict = {name: float(p) for name, p in zip(labels, probs)}
        max_label = max(probs_dict, key=probs_dict.get)

        return max_label, probs_dict

    except Exception as e:
        # UI/API boundary: keep the traceback in the server log, but return
        # an error tuple so the interface stays responsive.
        logging.getLogger(__name__).exception("Prediction failed")
        return f"Error: {e}", {}

def clear_all():
    """Return empty values for the image input and the base64 textbox.

    Intended as the clear-button handler. ``None`` is the empty value for an
    image component and ``""`` empties a textbox.
    """
    # A string value for an image component may be interpreted as an image
    # path/URL rather than as "no image", so use None for the image slot.
    return None, ""

# ---------------- INTERFACE ----------------
with gr.Blocks(title="NSFW Image Classifier") as demo:

    gr.HTML(f"""
    <div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
        <h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
        <p>Carica un'immagine o incolla la stringa base64 per analizzarla.</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=2):
            # Input widgets: an image upload and a base64 text field.
            img_input = gr.Image(label="📷 Carica immagine", type="pil")
            base64_input = gr.Textbox(
                label="📤 Base64 dell'immagine (API)",
                lines=6,
                placeholder="Incolla qui la stringa base64..."
            )
            with gr.Row():
                submit_btn = gr.Button("✨ Analizza", variant="primary")
                clear_btn = gr.Button("🔄 Pulisci", variant="secondary")

        with gr.Column(scale=1):
            label_output = gr.Textbox(label="Classe predetta", interactive=False)
            result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))

    # ---------------- UI events ----------------
    def _predict_from_ui(image, base64_text):
        """Classify the uploaded image, or the pasted base64 string if no image is present.

        The page header invites users to either upload an image or paste a
        base64 string, so the single visible button serves both inputs.
        """
        return predict(image if image is not None else base64_text)

    # api_name=False keeps this UI-only handler out of the API, so it cannot
    # take (or collide with) the "predict" endpoint name registered below.
    submit_btn.click(
        fn=_predict_from_ui,
        inputs=[img_input, base64_input],
        outputs=[label_output, result_display],
        api_name=False,
    )
    clear_btn.click(fn=clear_all, inputs=None, outputs=[img_input, base64_input])

    # ---------------- Hidden button backing the base64 API ----------------
    api_button = gr.Button(visible=False)
    api_button.click(
        fn=predict,
        inputs=[base64_input],
        outputs=[label_output, result_display],
        # Named API endpoint "predict". The exact HTTP route depends on the
        # Gradio version (originally noted as /run/predict) — TODO confirm.
        api_name="predict",
    )

# ---------------- LAUNCH ----------------
if __name__ == "__main__":
    # Bind on all network interfaces at port 7860 and keep the API docs page
    # enabled. NOTE(review): 0.0.0.0 exposes the app beyond localhost.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_api": True,
    }
    demo.launch(**launch_options)