ALSv commited on
Commit
c9424af
·
verified ·
1 Parent(s): 4e4c67b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -24
app.py CHANGED
@@ -9,6 +9,7 @@ import io
9
 
10
  # ---------------- CONFIG ----------------
11
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
 
12
 
13
  # ---------------- MODEL ----------------
14
  class Classifier(nn.Module):
@@ -38,18 +39,17 @@ model = Classifier()
38
  model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
39
  model.eval()
40
 
41
- # ---------------- FUNZIONE ----------------
42
- def predict(input_data):
 
 
 
43
  try:
44
- if isinstance(input_data, str):
45
- if input_data.startswith("data:image"):
46
- input_data = input_data.split(",",1)[1]
47
- img_bytes = base64.b64decode(input_data)
48
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
49
- elif isinstance(input_data, Image.Image):
50
- img = input_data.convert("RGB")
51
- else:
52
- return "Input non valido", {}
53
 
54
  img_tensor = preprocess(img).unsqueeze(0)
55
  with torch.no_grad():
@@ -63,20 +63,52 @@ def predict(input_data):
63
  except Exception as e:
64
  return f"Error: {str(e)}", {}
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # ---------------- INTERFACCIA ----------------
67
- demo = gr.Interface(
68
- fn=predict,
69
- inputs=[
70
- gr.Image(type="pil", label="📷 Carica immagine"),
71
- gr.Textbox(label="📤 Oppure Base64", lines=6, placeholder="Incolla base64 qui...")
72
- ],
73
- outputs=[
74
- gr.Textbox(label="Classe predetta"),
75
- gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
76
- ],
77
- live=False,
78
- allow_flagging="never"
79
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  # ---------------- LAUNCH ----------------
82
  if __name__ == "__main__":
 
9
 
10
  # ---------------- CONFIG ----------------
11
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
12
+ theme_color = "#6C5B7B"
13
 
14
  # ---------------- MODEL ----------------
15
  class Classifier(nn.Module):
 
39
  model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
40
  model.eval()
41
 
42
+ # ---------------- FUNZIONI ----------------
43
+ def predict(base64_input: str):
44
+ """
45
+ Unico input: stringa base64 (da API o da UI).
46
+ """
47
  try:
48
+ if base64_input.startswith("data:image"):
49
+ base64_input = base64_input.split(",", 1)[1]
50
+
51
+ img_bytes = base64.b64decode(base64_input)
52
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
 
 
 
53
 
54
  img_tensor = preprocess(img).unsqueeze(0)
55
  with torch.no_grad():
 
63
  except Exception as e:
64
  return f"Error: {str(e)}", {}
65
 
66
+ def image_to_base64(img: Image.Image):
67
+ """
68
+ Converte immagine caricata in base64 e la ritorna
69
+ così finisce nella Textbox e poi viene analizzata.
70
+ """
71
+ buffered = io.BytesIO()
72
+ img.save(buffered, format="JPEG")
73
+ img_b64 = base64.b64encode(buffered.getvalue()).decode()
74
+ return "data:image/jpeg;base64," + img_b64
75
+
76
+ def clear_all():
77
+ return ""
78
+
79
  # ---------------- INTERFACCIA ----------------
80
+ with gr.Blocks(title="NSFW Image Classifier") as demo:
81
+
82
+ gr.HTML(f"""
83
+ <div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
84
+ <h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
85
+ <p>Carica un'immagine o incolla la stringa base64.<br>
86
+ L'API espone <code>/run/predict</code> e accetta <b>solo base64</b>.</p>
87
+ </div>
88
+ """)
89
+
90
+ with gr.Row():
91
+ with gr.Column(scale=2):
92
+ img_input = gr.Image(label="📷 Carica immagine", type="pil")
93
+ base64_input = gr.Textbox(
94
+ label="📤 Base64 dell'immagine (API)",
95
+ lines=6,
96
+ placeholder="Incolla qui la stringa base64..."
97
+ )
98
+ with gr.Row():
99
+ submit_btn = gr.Button("✨ Analizza", variant="primary")
100
+ clear_btn = gr.Button("🔄 Pulisci", variant="secondary")
101
+
102
+ with gr.Column(scale=1):
103
+ label_output = gr.Textbox(label="Classe predetta", interactive=False)
104
+ result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
105
+
106
+ # Se carico immagine → la converto in base64 → la inserisco nella Textbox
107
+ img_input.change(fn=image_to_base64, inputs=img_input, outputs=base64_input)
108
+
109
+ # Submit → usa sempre base64 come input
110
+ submit_btn.click(fn=predict, inputs=base64_input, outputs=[label_output, result_display])
111
+ clear_btn.click(fn=clear_all, inputs=None, outputs=base64_input)
112
 
113
  # ---------------- LAUNCH ----------------
114
  if __name__ == "__main__":