ALSv commited on
Commit
a42bcbf
·
verified ·
1 Parent(s): 82ea080

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -49
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # nsfw_app.py
2
  import gradio as gr
3
  import torch
4
  import torch.nn as nn
@@ -13,7 +13,7 @@ import traceback
13
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
14
  theme_color = "#6C5B7B"
15
 
16
- # ---------------- MODEL (stesso file pesi) ----------------
17
  class Classifier(nn.Module):
18
  def __init__(self):
19
  super().__init__()
@@ -38,39 +38,23 @@ preprocess = transforms.Compose([
38
  std =[0.229,0.224,0.225])
39
  ])
40
 
41
- # Carica pesi (stesso file che usavi)
42
  model = Classifier()
43
  model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
44
  model.eval()
45
 
46
- # ---------------- FUNZIONE UNICA predict (accetta SOLO base64) ----------------
47
  def predict(base64_input: str):
48
- """
49
- Unico input dell'API: stringa base64 (es. "data:image/jpeg;base64,...")
50
- Ritorna: (label_str, {label:prob})
51
- """
52
  try:
53
  if not base64_input or not isinstance(base64_input, str):
54
  return "Input base64 mancante o non valido", {}
55
 
56
- # rimuovi eventuale prefisso data:image...
57
  if base64_input.startswith("data:image"):
58
  base64_input = base64_input.split(",", 1)[1]
59
 
60
- # decodifica base64
61
- try:
62
- img_bytes = base64.b64decode(base64_input)
63
- except Exception as e:
64
- return f"Errore decodifica base64: {e}", {}
65
 
66
- # apri immagine
67
- try:
68
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
69
- except Exception as e:
70
- return f"Errore apertura immagine: {e}", {}
71
-
72
- # preprocess + inferenza
73
- img_tensor = preprocess(img).unsqueeze(0) # 1x3x224x224
74
  with torch.no_grad():
75
  logits = model(img_tensor)
76
  probs = torch.nn.functional.softmax(logits[0], dim=0)
@@ -82,56 +66,45 @@ def predict(base64_input: str):
82
  except Exception:
83
  return f"Unhandled error:\n{traceback.format_exc()}", {}
84
 
85
- # ---------------- Helper: convert image upload -> base64 ----------------
86
  def image_to_base64(img: Image.Image):
87
- """
88
- Converte PIL image in data:image/jpeg;base64,...
89
- (usato dall'UI: caricamento immagine -> si popola la textbox base64)
90
- """
91
  if img is None:
92
  return ""
93
- buffer = io.BytesIO()
94
- img.save(buffer, format="JPEG", quality=90)
95
- b64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
96
- return "data:image/jpeg;base64," + b64
97
 
98
  def clear_box():
99
  return ""
100
 
101
- # ---------------- UI (Blocks) ----------------
102
- with gr.Blocks(title="NSFW Image Classifier (base64 single-input)"):
103
  gr.HTML(f"""
104
  <div style="padding:12px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:8px;">
105
  <h2 style="color:{theme_color}; margin:0;">🎨 NSFW Image Classifier</h2>
106
- <p style="margin:6px 0 0 0;">Carica un'immagine oppure incolla la base64. L'API accetta solo base64.</p>
107
  </div>
108
  """)
109
  with gr.Row():
110
  with gr.Column(scale=2):
111
- image_input = gr.Image(label="📷 Carica immagine (verrà convertita in base64)", type="pil")
112
- base64_input = gr.Textbox(label="📤 Base64 (API) — unico input", lines=6,
113
- placeholder="Incolla qui la stringa base64 (data:image/..;base64,...)")
114
  with gr.Row():
115
- analyze_btn = gr.Button("✨ Analizza (usa la base64 sopra)")
116
  clear_btn = gr.Button("🔄 Pulisci")
117
  with gr.Column(scale=1):
118
  label_output = gr.Textbox(label="Classe predetta", interactive=False)
119
  result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
120
 
121
- # quando carichi immagine -> converto e popolo la textbox con la base64
122
  image_input.change(fn=image_to_base64, inputs=image_input, outputs=base64_input)
123
 
124
- # quando la base64 cambia -> chiamo predict (questo espone automaticamente l'endpoint API
125
- # che accetta solo la textbox base64; gradio mappa l'endpoint in /run/predict in locale)
126
- base64_input.change(fn=predict, inputs=base64_input, outputs=[label_output, result_display], api_name="predict")
127
-
128
- # pulsante per analizzare manualmente (usa la base64 contenuta nella textbox)
129
- analyze_btn.click(fn=predict, inputs=base64_input, outputs=[label_output, result_display])
130
 
131
  clear_btn.click(fn=clear_box, inputs=None, outputs=base64_input)
132
 
133
- # ---------------- LAUNCH ----------------
134
  if __name__ == "__main__":
135
- # show_api=True per vedere il link "View API" nella UI (opzionale)
136
- demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)
137
-
 
1
+ # nsfw_app_api_only.py
2
  import gradio as gr
3
  import torch
4
  import torch.nn as nn
 
13
  labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
14
  theme_color = "#6C5B7B"
15
 
16
+ # ---------------- MODEL ----------------
17
  class Classifier(nn.Module):
18
  def __init__(self):
19
  super().__init__()
 
38
  std =[0.229,0.224,0.225])
39
  ])
40
 
 
41
  model = Classifier()
42
  model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
43
  model.eval()
44
 
45
+ # ---------------- FUNZIONE UNICA ----------------
46
  def predict(base64_input: str):
 
 
 
 
47
  try:
48
  if not base64_input or not isinstance(base64_input, str):
49
  return "Input base64 mancante o non valido", {}
50
 
 
51
  if base64_input.startswith("data:image"):
52
  base64_input = base64_input.split(",", 1)[1]
53
 
54
+ img_bytes = base64.b64decode(base64_input)
55
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
 
 
 
56
 
57
+ img_tensor = preprocess(img).unsqueeze(0)
 
 
 
 
 
 
 
58
  with torch.no_grad():
59
  logits = model(img_tensor)
60
  probs = torch.nn.functional.softmax(logits[0], dim=0)
 
66
  except Exception:
67
  return f"Unhandled error:\n{traceback.format_exc()}", {}
68
 
69
+ # ---------------- Helpers ----------------
70
  def image_to_base64(img: Image.Image):
 
 
 
 
71
  if img is None:
72
  return ""
73
+ buf = io.BytesIO()
74
+ img.save(buf, format="JPEG", quality=90)
75
+ return "data:image/jpeg;base64," + base64.b64encode(buf.getvalue()).decode("utf-8")
 
76
 
77
  def clear_box():
78
  return ""
79
 
80
+ # ---------------- UI ----------------
81
+ with gr.Blocks(title="NSFW Image Classifier (API standard)") as demo:
82
  gr.HTML(f"""
83
  <div style="padding:12px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:8px;">
84
  <h2 style="color:{theme_color}; margin:0;">🎨 NSFW Image Classifier</h2>
85
+ <p style="margin:6px 0 0 0;">Carica un'immagine oppure incolla la base64. L'API espone solo <b>/api/predict</b>.</p>
86
  </div>
87
  """)
88
  with gr.Row():
89
  with gr.Column(scale=2):
90
+ image_input = gr.Image(label="📷 Carica immagine (convertita in base64)", type="pil")
91
+ base64_input = gr.Textbox(label="📤 Base64 (API)", lines=6,
92
+ placeholder="Incolla qui la stringa base64...")
93
  with gr.Row():
94
+ analyze_btn = gr.Button("✨ Analizza")
95
  clear_btn = gr.Button("🔄 Pulisci")
96
  with gr.Column(scale=1):
97
  label_output = gr.Textbox(label="Classe predetta", interactive=False)
98
  result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
99
 
100
+ # Carica immagine -> converte in base64 e riempie textbox
101
  image_input.change(fn=image_to_base64, inputs=image_input, outputs=base64_input)
102
 
103
+ # Analizza manualmente (API unica)
104
+ analyze_btn.click(fn=predict, inputs=base64_input, outputs=[label_output, result_display], api_name="predict")
 
 
 
 
105
 
106
  clear_btn.click(fn=clear_box, inputs=None, outputs=base64_input)
107
 
108
+ # ---------------- Launch ----------------
109
  if __name__ == "__main__":
110
+ demo.launch(server_name="0.0.0.0", server_port=7860)