ALSv commited on
Commit
d35e71b
·
verified ·
1 Parent(s): c9424af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -103
app.py CHANGED
@@ -1,115 +1,46 @@
1
  import gradio as gr
2
- import torch
3
- import torch.nn as nn
4
- from torchvision import transforms
5
- from torchvision.models import resnet18
6
- from PIL import Image
7
  import base64
8
- import io
9
-
10
- # ---------------- CONFIG ----------------
11
- labels = ["Drawings", "Hentai", "Neutral", "Porn", "Sexy"]
12
- theme_color = "#6C5B7B"
13
-
14
- # ---------------- MODEL ----------------
15
- class Classifier(nn.Module):
16
- def __init__(self):
17
- super().__init__()
18
- self.cnn_layers = resnet18(weights=None)
19
- self.fc_layers = nn.Sequential(
20
- nn.Linear(1000, 512),
21
- nn.Dropout(0.3),
22
- nn.Linear(512, 128),
23
- nn.ReLU(),
24
- nn.Linear(128, len(labels))
25
- )
26
-
27
- def forward(self, x):
28
- x = self.cnn_layers(x)
29
- x = self.fc_layers(x)
30
- return x
31
-
32
- preprocess = transforms.Compose([
33
- transforms.Resize((224,224)),
34
- transforms.ToTensor(),
35
- transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
36
- ])
37
-
38
- model = Classifier()
39
- model.load_state_dict(torch.load("classify_nsfw_v3.0.pth", map_location="cpu"))
40
- model.eval()
41
 
42
- # ---------------- FUNZIONI ----------------
43
- def predict(base64_input: str):
44
- """
45
- Unico input: stringa base64 (da API o da UI).
46
- """
47
  try:
48
- if base64_input.startswith("data:image"):
49
- base64_input = base64_input.split(",", 1)[1]
50
-
51
- img_bytes = base64.b64decode(base64_input)
52
- img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
53
-
54
- img_tensor = preprocess(img).unsqueeze(0)
55
- with torch.no_grad():
56
- logits = model(img_tensor)
57
- probs = torch.nn.functional.softmax(logits[0], dim=0)
58
-
59
- probs_dict = {labels[i]: float(probs[i]) for i in range(len(labels))}
60
- max_label = max(probs_dict, key=probs_dict.get)
61
- return max_label, probs_dict
62
-
63
  except Exception as e:
64
- return f"Error: {str(e)}", {}
65
 
66
- def image_to_base64(img: Image.Image):
67
- """
68
- Converte immagine caricata in base64 e la ritorna
69
- così finisce nella Textbox e poi viene analizzata.
70
- """
71
- buffered = io.BytesIO()
72
- img.save(buffered, format="JPEG")
73
- img_b64 = base64.b64encode(buffered.getvalue()).decode()
74
- return "data:image/jpeg;base64," + img_b64
75
 
76
- def clear_all():
77
- return ""
78
-
79
- # ---------------- INTERFACCIA ----------------
80
- with gr.Blocks(title="NSFW Image Classifier") as demo:
81
-
82
- gr.HTML(f"""
83
- <div style="padding:10px; background:linear-gradient(135deg,#f8f9fa 0%,#e9ecef 100%); border-radius:10px;">
84
- <h2 style="color:{theme_color};">🎨 NSFW Image Classifier</h2>
85
- <p>Carica un'immagine o incolla la stringa base64.<br>
86
- L'API espone <code>/run/predict</code> e accetta <b>solo base64</b>.</p>
87
- </div>
88
- """)
89
 
90
  with gr.Row():
91
- with gr.Column(scale=2):
92
- img_input = gr.Image(label="📷 Carica immagine", type="pil")
93
- base64_input = gr.Textbox(
94
- label="📤 Base64 dell'immagine (API)",
95
- lines=6,
96
- placeholder="Incolla qui la stringa base64..."
97
- )
98
- with gr.Row():
99
- submit_btn = gr.Button("✨ Analizza", variant="primary")
100
- clear_btn = gr.Button("🔄 Pulisci", variant="secondary")
101
-
102
- with gr.Column(scale=1):
103
- label_output = gr.Textbox(label="Classe predetta", interactive=False)
104
- result_display = gr.Label(label="Distribuzione probabilità", num_top_classes=len(labels))
105
 
106
- # Se carico immagine → la converto in base64 → la inserisco nella Textbox
107
- img_input.change(fn=image_to_base64, inputs=img_input, outputs=base64_input)
108
 
109
- # Submit → usa sempre base64 come input
110
- submit_btn.click(fn=predict, inputs=base64_input, outputs=[label_output, result_display])
111
- clear_btn.click(fn=clear_all, inputs=None, outputs=base64_input)
112
 
113
- # ---------------- LAUNCH ----------------
114
- if __name__ == "__main__":
115
- demo.launch(server_name="0.0.0.0", server_port=7860, show_api=True)
 
1
  import gradio as gr
 
 
 
 
 
2
  import base64
3
+ from io import BytesIO
4
+ from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
+ # Funzione di analisi immagine base64
7
+ def analyze_base64(b64_string: str):
 
 
 
8
  try:
9
+ # Decodifica
10
+ if b64_string.startswith("data:image"):
11
+ b64_string = b64_string.split(",")[1]
12
+
13
+ image_data = base64.b64decode(b64_string)
14
+ img = Image.open(BytesIO(image_data))
15
+
16
+ # Qui metti il tuo modello / logica di analisi
17
+ # Per esempio: restituisco dimensioni e formato
18
+ result = {
19
+ "width": img.width,
20
+ "height": img.height,
21
+ "format": img.format
22
+ }
23
+ return f"Analisi completata: {result}"
24
  except Exception as e:
25
+ return f"Errore: {str(e)}"
26
 
27
+ # Se carichi da web → converto subito in base64
28
+ def file_to_base64(img: Image.Image):
29
+ buffered = BytesIO()
30
+ img.save(buffered, format="PNG")
31
+ return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode()
 
 
 
 
32
 
33
+ with gr.Blocks() as demo:
34
+ gr.Markdown("## Analisi Immagini via Base64")
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  with gr.Row():
37
+ img_input = gr.Image(type="pil", label="Carica immagine (verrà convertita in base64)")
38
+ b64_input = gr.Textbox(label="📤 Base64 dell'immagine (API)", lines=6)
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ output = gr.Textbox(label="Risultato analisi")
 
41
 
42
+ img_input.change(fn=file_to_base64, inputs=img_input, outputs=b64_input)
43
+ b64_input.submit(fn=analyze_base64, inputs=b64_input, outputs=output)
 
44
 
45
+ # Avvio server
46
+ demo.launch(server_name="0.0.0.0", server_port=7860)