update files
Browse files- app.py +181 -8
- requirements.txt +3 -1
app.py
CHANGED
@@ -6,11 +6,70 @@ import os
|
|
6 |
import gradio as gr
|
7 |
from safetensors.torch import load_file
|
8 |
from huggingface_hub import hf_hub_download
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
# Local imports
|
11 |
from lionguard2 import LionGuard2, CATEGORIES
|
12 |
from utils import get_embeddings
|
13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
|
15 |
def download_model(repo_id, filename="LionGuard2.safetensors", token=None):
|
16 |
"""
|
@@ -125,13 +184,18 @@ def analyze_text(text):
|
|
125 |
|
126 |
Returns:
|
127 |
binary_score: Overall safety score with styling
|
128 |
-
category_table: HTML table with category-specific scores and styling
|
|
|
|
|
129 |
"""
|
130 |
if not text.strip():
|
131 |
empty_html = '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
|
132 |
-
return '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>', empty_html
|
133 |
|
134 |
try:
|
|
|
|
|
|
|
135 |
# Get embeddings for the text
|
136 |
embeddings = get_embeddings([text])
|
137 |
|
@@ -141,6 +205,49 @@ def analyze_text(text):
|
|
141 |
# Extract binary score (overall safety)
|
142 |
binary_score = results.get('binary', [0.0])[0]
|
143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
# Prepare category data with max scores and dropdowns
|
145 |
categories_html = []
|
146 |
|
@@ -201,13 +308,38 @@ def analyze_text(text):
|
|
201 |
</div>
|
202 |
'''
|
203 |
|
204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
205 |
|
206 |
except Exception as e:
|
207 |
error_msg = f"Error analyzing text: {str(e)}"
|
208 |
error_html = f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 20px; border-radius: 12px; text-align: center; border: 2px solid #ef4444; box-shadow: 0 4px 12px rgba(0,0,0,0.3);">β {error_msg}</div>'
|
209 |
-
return f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 16px; border-radius: 8px; text-align: center; border: 1px solid #ef4444;">β {error_msg}</div>', error_html
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
210 |
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
|
212 |
# Create Gradio interface with dark theme
|
213 |
with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
|
@@ -250,6 +382,27 @@ with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
|
|
250 |
category_table = gr.HTML(
|
251 |
value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
|
252 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
253 |
|
254 |
# Add information about the categories
|
255 |
with gr.Row():
|
@@ -275,18 +428,38 @@ with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
|
|
275 |
</div>
|
276 |
""")
|
277 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
# Connect the analyze button to the function
|
279 |
analyze_btn.click(
|
280 |
-
fn=
|
281 |
inputs=[text_input],
|
282 |
-
outputs=[binary_output, category_table]
|
283 |
)
|
284 |
|
285 |
# Allow Enter key to trigger analysis
|
286 |
text_input.submit(
|
287 |
-
fn=
|
288 |
inputs=[text_input],
|
289 |
-
outputs=[binary_output, category_table]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
290 |
)
|
291 |
|
292 |
|
|
|
6 |
import gradio as gr
|
7 |
from safetensors.torch import load_file
|
8 |
from huggingface_hub import hf_hub_download
|
9 |
+
import gspread
|
10 |
+
from google.oauth2 import service_account
|
11 |
+
import json
|
12 |
+
from datetime import datetime
|
13 |
+
import uuid
|
14 |
|
15 |
# Local imports
|
16 |
from lionguard2 import LionGuard2, CATEGORIES
|
17 |
from utils import get_embeddings
|
18 |
|
19 |
+
# Google Sheets configuration
|
20 |
+
GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL")
|
21 |
+
GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT")
|
22 |
+
RESULTS_SHEET_NAME = "results"
|
23 |
+
VOTES_SHEET_NAME = "votes"
|
24 |
+
|
25 |
+
# Helper to save results data
|
26 |
+
def save_results_data(row):
|
27 |
+
try:
|
28 |
+
# Create credentials object
|
29 |
+
credentials = service_account.Credentials.from_service_account_info(
|
30 |
+
json.loads(GOOGLE_CREDENTIALS),
|
31 |
+
scopes=[
|
32 |
+
"https://www.googleapis.com/auth/spreadsheets",
|
33 |
+
"https://www.googleapis.com/auth/drive",
|
34 |
+
],
|
35 |
+
)
|
36 |
+
|
37 |
+
# Create authorized client
|
38 |
+
gc = gspread.authorize(credentials)
|
39 |
+
sheet = gc.open_by_url(GOOGLE_SHEET_URL)
|
40 |
+
ws = sheet.worksheet(RESULTS_SHEET_NAME)
|
41 |
+
ws.append_row(list(row.values()))
|
42 |
+
print(f"Saved results data for text_id: {row['text_id']}")
|
43 |
+
except Exception as e:
|
44 |
+
print(f"Error saving results data: {e}")
|
45 |
+
|
46 |
+
# Helper to save vote data
|
47 |
+
def save_vote_data(text_id, agree):
|
48 |
+
try:
|
49 |
+
# Create credentials object
|
50 |
+
credentials = service_account.Credentials.from_service_account_info(
|
51 |
+
json.loads(GOOGLE_CREDENTIALS),
|
52 |
+
scopes=[
|
53 |
+
"https://www.googleapis.com/auth/spreadsheets",
|
54 |
+
"https://www.googleapis.com/auth/drive",
|
55 |
+
],
|
56 |
+
)
|
57 |
+
|
58 |
+
# Create authorized client
|
59 |
+
gc = gspread.authorize(credentials)
|
60 |
+
sheet = gc.open_by_url(GOOGLE_SHEET_URL)
|
61 |
+
ws = sheet.worksheet(VOTES_SHEET_NAME)
|
62 |
+
|
63 |
+
vote_row = {
|
64 |
+
"datetime": datetime.now().isoformat(),
|
65 |
+
"text_id": text_id,
|
66 |
+
"agree": agree
|
67 |
+
}
|
68 |
+
ws.append_row(list(vote_row.values()))
|
69 |
+
print(f"Saved vote data for text_id: {text_id}, agree: {agree}")
|
70 |
+
except Exception as e:
|
71 |
+
print(f"Error saving vote data: {e}")
|
72 |
+
|
73 |
|
74 |
def download_model(repo_id, filename="LionGuard2.safetensors", token=None):
|
75 |
"""
|
|
|
184 |
|
185 |
Returns:
|
186 |
binary_score: Overall safety score with styling
|
187 |
+
category_table: HTML table with category-specific scores and styling
|
188 |
+
text_id: Unique identifier for this analysis
|
189 |
+
voting_section: HTML for voting buttons
|
190 |
"""
|
191 |
if not text.strip():
|
192 |
empty_html = '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
|
193 |
+
return '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>', empty_html, "", ""
|
194 |
|
195 |
try:
|
196 |
+
# Generate unique text ID
|
197 |
+
text_id = str(uuid.uuid4())
|
198 |
+
|
199 |
# Get embeddings for the text
|
200 |
embeddings = get_embeddings([text])
|
201 |
|
|
|
205 |
# Extract binary score (overall safety)
|
206 |
binary_score = results.get('binary', [0.0])[0]
|
207 |
|
208 |
+
# Extract specific scores for Google Sheets
|
209 |
+
hateful_scores = CATEGORIES['hateful']
|
210 |
+
hateful_l1_score = results.get(hateful_scores[0], [0.0])[0] if len(hateful_scores) > 0 else 0.0
|
211 |
+
hateful_l2_score = results.get(hateful_scores[1], [0.0])[0] if len(hateful_scores) > 1 else 0.0
|
212 |
+
|
213 |
+
insults_scores = CATEGORIES['insults']
|
214 |
+
insults_score = results.get(insults_scores[0], [0.0])[0] if len(insults_scores) > 0 else 0.0
|
215 |
+
|
216 |
+
sexual_scores = CATEGORIES['sexual']
|
217 |
+
sexual_l1_score = results.get(sexual_scores[0], [0.0])[0] if len(sexual_scores) > 0 else 0.0
|
218 |
+
sexual_l2_score = results.get(sexual_scores[1], [0.0])[0] if len(sexual_scores) > 1 else 0.0
|
219 |
+
|
220 |
+
physical_violence_scores = CATEGORIES['physical_violence']
|
221 |
+
physical_violence_score = results.get(physical_violence_scores[0], [0.0])[0] if len(physical_violence_scores) > 0 else 0.0
|
222 |
+
|
223 |
+
self_harm_scores = CATEGORIES['self_harm']
|
224 |
+
self_harm_l1_score = results.get(self_harm_scores[0], [0.0])[0] if len(self_harm_scores) > 0 else 0.0
|
225 |
+
self_harm_l2_score = results.get(self_harm_scores[1], [0.0])[0] if len(self_harm_scores) > 1 else 0.0
|
226 |
+
|
227 |
+
aom_scores = CATEGORIES['all_other_misconduct']
|
228 |
+
aom_l1_score = results.get(aom_scores[0], [0.0])[0] if len(aom_scores) > 0 else 0.0
|
229 |
+
aom_l2_score = results.get(aom_scores[1], [0.0])[0] if len(aom_scores) > 1 else 0.0
|
230 |
+
|
231 |
+
# Save results to Google Sheets
|
232 |
+
if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
|
233 |
+
results_row = {
|
234 |
+
"datetime": datetime.now().isoformat(),
|
235 |
+
"text_id": text_id,
|
236 |
+
"text": text,
|
237 |
+
"binary_score": binary_score,
|
238 |
+
"hateful_l1_score": hateful_l1_score,
|
239 |
+
"hateful_l2_score": hateful_l2_score,
|
240 |
+
"insults_score": insults_score,
|
241 |
+
"sexual_l1_score": sexual_l1_score,
|
242 |
+
"sexual_l2_score": sexual_l2_score,
|
243 |
+
"physical_violence_score": physical_violence_score,
|
244 |
+
"self_harm_l1_score": self_harm_l1_score,
|
245 |
+
"self_harm_l2_score": self_harm_l2_score,
|
246 |
+
"aom_l1_score": aom_l1_score,
|
247 |
+
"aom_l2_score": aom_l2_score
|
248 |
+
}
|
249 |
+
save_results_data(results_row)
|
250 |
+
|
251 |
# Prepare category data with max scores and dropdowns
|
252 |
categories_html = []
|
253 |
|
|
|
308 |
</div>
|
309 |
'''
|
310 |
|
311 |
+
# Create voting section
|
312 |
+
voting_html = f'''
|
313 |
+
<div style="margin: 24px 0; text-align: center;">
|
314 |
+
<div style="background: #1f2937; border-radius: 12px; padding: 20px; box-shadow: 0 4px 12px rgba(0,0,0,0.3); border: 1px solid #374151;">
|
315 |
+
<h3 style="color: #f9fafb; font-size: 18px; font-weight: 600; margin-bottom: 12px;">π How accurate are these results?</h3>
|
316 |
+
<p style="color: #d1d5db; font-size: 14px; margin-bottom: 16px;">Your feedback helps improve the model</p>
|
317 |
+
</div>
|
318 |
+
</div>
|
319 |
+
'''
|
320 |
+
|
321 |
+
return format_binary_score(binary_score), html_table, text_id, voting_html
|
322 |
|
323 |
except Exception as e:
|
324 |
error_msg = f"Error analyzing text: {str(e)}"
|
325 |
error_html = f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 20px; border-radius: 12px; text-align: center; border: 2px solid #ef4444; box-shadow: 0 4px 12px rgba(0,0,0,0.3);">β {error_msg}</div>'
|
326 |
+
return f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 16px; border-radius: 8px; text-align: center; border: 1px solid #ef4444;">β {error_msg}</div>', error_html, "", ""
|
327 |
+
|
328 |
+
|
329 |
+
# Voting functions
|
330 |
+
def vote_thumbs_up(text_id):
|
331 |
+
"""Handle thumbs up vote"""
|
332 |
+
if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
|
333 |
+
save_vote_data(text_id, True)
|
334 |
+
return '<div style="background: linear-gradient(135deg, #065f46 0%, #047857 100%); color: #34d399; padding: 16px; border-radius: 8px; text-align: center; font-weight: 600; border: 1px solid #10b981; margin: 8px 0;">π Thanks for your feedback!</div>'
|
335 |
+
return '<div style="color: #9ca3af; text-align: center; padding: 16px;">Voting not available</div>'
|
336 |
|
337 |
+
def vote_thumbs_down(text_id):
|
338 |
+
"""Handle thumbs down vote"""
|
339 |
+
if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
|
340 |
+
save_vote_data(text_id, False)
|
341 |
+
return '<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 16px; border-radius: 8px; text-align: center; font-weight: 600; border: 1px solid #ef4444; margin: 8px 0;">π Thanks for your feedback!</div>'
|
342 |
+
return '<div style="color: #9ca3af; text-align: center; padding: 16px;">Voting not available</div>'
|
343 |
|
344 |
# Create Gradio interface with dark theme
|
345 |
with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
|
|
|
382 |
category_table = gr.HTML(
|
383 |
value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
|
384 |
)
|
385 |
+
|
386 |
+
# Voting section
|
387 |
+
voting_section = gr.HTML(value="")
|
388 |
+
|
389 |
+
with gr.Row(visible=False) as voting_row:
|
390 |
+
with gr.Column():
|
391 |
+
gr.HTML("""
|
392 |
+
<div style="margin: 16px 0; text-align: center;">
|
393 |
+
<h3 style="color: #f9fafb; font-size: 18px; font-weight: 600; margin-bottom: 8px;">π How accurate are these results?</h3>
|
394 |
+
<p style="color: #d1d5db; font-size: 14px; margin-bottom: 16px;">Your feedback helps improve the model</p>
|
395 |
+
</div>
|
396 |
+
""")
|
397 |
+
|
398 |
+
with gr.Row():
|
399 |
+
thumbs_up_btn = gr.Button("π Accurate", variant="secondary", scale=1)
|
400 |
+
thumbs_down_btn = gr.Button("π Inaccurate", variant="secondary", scale=1)
|
401 |
+
|
402 |
+
vote_feedback = gr.HTML(value="")
|
403 |
+
|
404 |
+
# Hidden text_id to track current analysis
|
405 |
+
current_text_id = gr.Textbox(value="", visible=False)
|
406 |
|
407 |
# Add information about the categories
|
408 |
with gr.Row():
|
|
|
428 |
</div>
|
429 |
""")
|
430 |
|
431 |
+
# Function to handle analysis and show voting buttons
|
432 |
+
def analyze_and_show_voting(text):
|
433 |
+
binary_score, category_table, text_id, voting_html = analyze_text(text)
|
434 |
+
# Show voting row if we have results
|
435 |
+
voting_row_update = gr.update(visible=bool(text_id))
|
436 |
+
return binary_score, category_table, text_id, voting_row_update, ""
|
437 |
+
|
438 |
# Connect the analyze button to the function
|
439 |
analyze_btn.click(
|
440 |
+
fn=analyze_and_show_voting,
|
441 |
inputs=[text_input],
|
442 |
+
outputs=[binary_output, category_table, current_text_id, voting_row, vote_feedback]
|
443 |
)
|
444 |
|
445 |
# Allow Enter key to trigger analysis
|
446 |
text_input.submit(
|
447 |
+
fn=analyze_and_show_voting,
|
448 |
inputs=[text_input],
|
449 |
+
outputs=[binary_output, category_table, current_text_id, voting_row, vote_feedback]
|
450 |
+
)
|
451 |
+
|
452 |
+
# Connect voting buttons
|
453 |
+
thumbs_up_btn.click(
|
454 |
+
fn=vote_thumbs_up,
|
455 |
+
inputs=[current_text_id],
|
456 |
+
outputs=[vote_feedback]
|
457 |
+
)
|
458 |
+
|
459 |
+
thumbs_down_btn.click(
|
460 |
+
fn=vote_thumbs_down,
|
461 |
+
inputs=[current_text_id],
|
462 |
+
outputs=[vote_feedback]
|
463 |
)
|
464 |
|
465 |
|
requirements.txt
CHANGED
@@ -3,4 +3,6 @@ torch>=2.0.0
|
|
3 |
safetensors>=0.4.0
|
4 |
openai>=1.0.0
|
5 |
numpy>=1.24.0
|
6 |
-
huggingface_hub>=0.19.0
|
|
|
|
|
|
3 |
safetensors>=0.4.0
|
4 |
openai>=1.0.0
|
5 |
numpy>=1.24.0
|
6 |
+
huggingface_hub>=0.19.0
|
7 |
+
gspread==6.2.1
|
8 |
+
google-auth==2.40.3
|