ALSv commited on
Commit
5cc8910
·
verified ·
1 Parent(s): d35e71b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +110 -34
app.py CHANGED
@@ -1,46 +1,122 @@
1
  import gradio as gr
2
- import base64
3
- from io import BytesIO
 
 
4
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- # Funzione di analisi immagine base64
7
- def analyze_base64(b64_string: str):
 
 
 
 
 
8
  try:
9
- # Decodifica
10
- if b64_string.startswith("data:image"):
11
- b64_string = b64_string.split(",")[1]
12
-
13
- image_data = base64.b64decode(b64_string)
14
- img = Image.open(BytesIO(image_data))
15
-
16
- # Qui metti il tuo modello / logica di analisi
17
- # Per esempio: restituisco dimensioni e formato
18
- result = {
19
- "width": img.width,
20
- "height": img.height,
21
- "format": img.format
22
- }
23
- return f"Analisi completata: {result}"
 
 
 
 
24
  except Exception as e:
25
- return f"Errore: {str(e)}"
 
 
 
26
 
27
- # Se carichi da web → converto subito in base64
28
- def file_to_base64(img: Image.Image):
29
- buffered = BytesIO()
30
- img.save(buffered, format="PNG")
31
- return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode()
32
 
33
- with gr.Blocks() as demo:
34
- gr.Markdown("## Analisi Immagini via Base64")
 
 
 
 
35
 
36
  with gr.Row():
37
- img_input = gr.Image(type="pil", label="Carica immagine (verrà convertita in base64)")
38
- b64_input = gr.Textbox(label="📤 Base64 dell'immagine (API)", lines=6)
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
- output = gr.Textbox(label="Risultato analisi")
 
 
 
 
 
 
41
 
42
- img_input.change(fn=file_to_base64, inputs=img_input, outputs=b64_input)
43
- b64_input.submit(fn=analyze_base64, inputs=b64_input, outputs=output)
 
 
 
 
 
 
44
 
45
- # Avvio server
46
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision import transforms
5
+ from torchvision.models import resnet18
6
  from PIL import Image
7
+ import base64
8
+ import io
9
+
10
+ # ---------------- CONFIG ----------------
11
+ labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
12
+ theme_color = "#6C5B7B"
13
+
14
+ # ---------------- MODEL ----------------
15
+ class Classifier(nn.Module):
16
+ def __init__(self):
17
+ super().__init__()
18
+ self.cnn_layers = resnet18(weights=None)
19
+ self.fc_layers = nn.Sequential(
20
+ nn.Linear(1000, 512),
21
+ nn.Dropout(0.3),
22
+ nn.Linear(512, 128),
23
+ nn.ReLU(),
24
+ nn.Linear(128, len(labels))
25
+ )
26
+
27
+ def forward(self, x):
28
+ x = self.cnn_layers(x)
29
+ x = self.fc_layers(x)
30
+ return x
31
+
32
+ preprocess = transforms.Compose([
33
+ transforms.Resize((224,224)),
34
+ transforms.ToTensor(),
35
+ transforms.Normalize(mean=[0.485,0.456,0.406],
36
+ std=[0.229,0.224,0.225])
37
+ ])
38
+
39
+ model = Classifier()
40
+ model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
41
+ model.eval()
42
 
43
+ # ---------------- FUNZIONE ----------------
44
+ def predict(image_input):
45
+ """
46
+ Supporta:
47
+ - PIL Image (UI web)
48
+ - stringa base64 (API)
49
+ """
50
  try:
51
+ if isinstance(image_input, str):
52
+ if image_input.startswith("data:image"):
53
+ image_input = image_input.split(",",1)[1]
54
+ img_bytes = base64.b64decode(image_input)
55
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
56
+ else:
57
+ img = image_input.convert("RGB")
58
+
59
+ img_tensor = preprocess(img).unsqueeze(0)
60
+
61
+ with torch.no_grad():
62
+ logits = model(img_tensor)
63
+ probs = torch.nn.functional.softmax(logits[0], dim=0)
64
+
65
+ probs_dict = {labels[i]: float(probs[i]) for i in range(len(labels))}
66
+ max_label = max(probs_dict, key=probs_dict.get)
67
+
68
+ return max_label, probs_dict
69
+
70
  except Exception as e:
71
+ return f"Error: {str(e)}", {}
72
+
73
+ def clear_all():
74
+ return "", ""
75
 
76
+ # ---------------- INTERFACCIA ----------------
77
+ with gr.Blocks(title="NSFW Image Classifier") as demo:
 
 
 
78
 
79
+ gr.HTML(f"""
80
+ <div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
81
+ <h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
82
+ <p>Carica un'immagine o incolla la stringa base64 per analizzarla.</p>
83
+ </div>
84
+ """)
85
 
86
  with gr.Row():
87
+ with gr.Column(scale=2):
88
+ # Input UI
89
+ img_input = gr.Image(label="📷 Carica immagine", type="pil")
90
+ base64_input = gr.Textbox(
91
+ label="📤 Base64 dell'immagine (API)",
92
+ lines=6,
93
+ placeholder="Incolla qui la stringa base64..."
94
+ )
95
+ with gr.Row():
96
+ submit_btn = gr.Button("✨ Analizza", variant="primary")
97
+ clear_btn = gr.Button("🔄 Pulisci", variant="secondary")
98
+
99
+ with gr.Column(scale=1):
100
+ label_output = gr.Textbox(label="Classe predetta", interactive=False)
101
+ result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
102
 
103
+ # ---------------- Eventi UI ----------------
104
+ submit_btn.click(
105
+ fn=predict,
106
+ inputs=[img_input],
107
+ outputs=[label_output, result_display]
108
+ )
109
+ clear_btn.click(fn=clear_all, inputs=None, outputs=[img_input, base64_input])
110
 
111
+ # ---------------- Pulsante invisibile per API base64 ----------------
112
+ api_button = gr.Button(visible=False)
113
+ api_button.click(
114
+ fn=predict,
115
+ inputs=[base64_input],
116
+ outputs=[label_output, result_display],
117
+ api_name="predict" # espone /run/predict
118
+ )
119
 
120
+ # ---------------- LAUNCH ----------------
121
+ if __name__ == "__main__":
122
+ demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)