gabrielchua commited on
Commit
db96ac5
Β·
unverified Β·
1 Parent(s): c9fdddf

update files

Browse files
Files changed (2) hide show
  1. app.py +181 -8
  2. requirements.txt +3 -1
app.py CHANGED
@@ -6,11 +6,70 @@ import os
6
  import gradio as gr
7
  from safetensors.torch import load_file
8
  from huggingface_hub import hf_hub_download
 
 
 
 
 
9
 
10
  # Local imports
11
  from lionguard2 import LionGuard2, CATEGORIES
12
  from utils import get_embeddings
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
  def download_model(repo_id, filename="LionGuard2.safetensors", token=None):
16
  """
@@ -125,13 +184,18 @@ def analyze_text(text):
125
 
126
  Returns:
127
  binary_score: Overall safety score with styling
128
- category_table: HTML table with category-specific scores and styling
 
 
129
  """
130
  if not text.strip():
131
  empty_html = '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
132
- return '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>', empty_html
133
 
134
  try:
 
 
 
135
  # Get embeddings for the text
136
  embeddings = get_embeddings([text])
137
 
@@ -141,6 +205,49 @@ def analyze_text(text):
141
  # Extract binary score (overall safety)
142
  binary_score = results.get('binary', [0.0])[0]
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  # Prepare category data with max scores and dropdowns
145
  categories_html = []
146
 
@@ -201,13 +308,38 @@ def analyze_text(text):
201
  </div>
202
  '''
203
 
204
- return format_binary_score(binary_score), html_table
 
 
 
 
 
 
 
 
 
 
205
 
206
  except Exception as e:
207
  error_msg = f"Error analyzing text: {str(e)}"
208
  error_html = f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 20px; border-radius: 12px; text-align: center; border: 2px solid #ef4444; box-shadow: 0 4px 12px rgba(0,0,0,0.3);">❌ {error_msg}</div>'
209
- return f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 16px; border-radius: 8px; text-align: center; border: 1px solid #ef4444;">❌ {error_msg}</div>', error_html
 
 
 
 
 
 
 
 
 
210
 
 
 
 
 
 
 
211
 
212
  # Create Gradio interface with dark theme
213
  with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
@@ -250,6 +382,27 @@ with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
250
  category_table = gr.HTML(
251
  value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
252
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
  # Add information about the categories
255
  with gr.Row():
@@ -275,18 +428,38 @@ with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
275
  </div>
276
  """)
277
 
 
 
 
 
 
 
 
278
  # Connect the analyze button to the function
279
  analyze_btn.click(
280
- fn=analyze_text,
281
  inputs=[text_input],
282
- outputs=[binary_output, category_table]
283
  )
284
 
285
  # Allow Enter key to trigger analysis
286
  text_input.submit(
287
- fn=analyze_text,
288
  inputs=[text_input],
289
- outputs=[binary_output, category_table]
 
 
 
 
 
 
 
 
 
 
 
 
 
290
  )
291
 
292
 
 
6
  import gradio as gr
7
  from safetensors.torch import load_file
8
  from huggingface_hub import hf_hub_download
9
+ import gspread
10
+ from google.oauth2 import service_account
11
+ import json
12
+ from datetime import datetime
13
+ import uuid
14
 
15
  # Local imports
16
  from lionguard2 import LionGuard2, CATEGORIES
17
  from utils import get_embeddings
18
 
19
+ # Google Sheets configuration
20
+ GOOGLE_SHEET_URL = os.environ.get("GOOGLE_SHEET_URL")
21
+ GOOGLE_CREDENTIALS = os.environ.get("GCP_SERVICE_ACCOUNT")
22
+ RESULTS_SHEET_NAME = "results"
23
+ VOTES_SHEET_NAME = "votes"
24
+
25
+ # Helper to save results data
26
+ def save_results_data(row):
27
+ try:
28
+ # Create credentials object
29
+ credentials = service_account.Credentials.from_service_account_info(
30
+ json.loads(GOOGLE_CREDENTIALS),
31
+ scopes=[
32
+ "https://www.googleapis.com/auth/spreadsheets",
33
+ "https://www.googleapis.com/auth/drive",
34
+ ],
35
+ )
36
+
37
+ # Create authorized client
38
+ gc = gspread.authorize(credentials)
39
+ sheet = gc.open_by_url(GOOGLE_SHEET_URL)
40
+ ws = sheet.worksheet(RESULTS_SHEET_NAME)
41
+ ws.append_row(list(row.values()))
42
+ print(f"Saved results data for text_id: {row['text_id']}")
43
+ except Exception as e:
44
+ print(f"Error saving results data: {e}")
45
+
46
+ # Helper to save vote data
47
+ def save_vote_data(text_id, agree):
48
+ try:
49
+ # Create credentials object
50
+ credentials = service_account.Credentials.from_service_account_info(
51
+ json.loads(GOOGLE_CREDENTIALS),
52
+ scopes=[
53
+ "https://www.googleapis.com/auth/spreadsheets",
54
+ "https://www.googleapis.com/auth/drive",
55
+ ],
56
+ )
57
+
58
+ # Create authorized client
59
+ gc = gspread.authorize(credentials)
60
+ sheet = gc.open_by_url(GOOGLE_SHEET_URL)
61
+ ws = sheet.worksheet(VOTES_SHEET_NAME)
62
+
63
+ vote_row = {
64
+ "datetime": datetime.now().isoformat(),
65
+ "text_id": text_id,
66
+ "agree": agree
67
+ }
68
+ ws.append_row(list(vote_row.values()))
69
+ print(f"Saved vote data for text_id: {text_id}, agree: {agree}")
70
+ except Exception as e:
71
+ print(f"Error saving vote data: {e}")
72
+
73
 
74
  def download_model(repo_id, filename="LionGuard2.safetensors", token=None):
75
  """
 
184
 
185
  Returns:
186
  binary_score: Overall safety score with styling
187
+ category_table: HTML table with category-specific scores and styling
188
+ text_id: Unique identifier for this analysis
189
+ voting_section: HTML for voting buttons
190
  """
191
  if not text.strip():
192
  empty_html = '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
193
+ return '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>', empty_html, "", ""
194
 
195
  try:
196
+ # Generate unique text ID
197
+ text_id = str(uuid.uuid4())
198
+
199
  # Get embeddings for the text
200
  embeddings = get_embeddings([text])
201
 
 
205
  # Extract binary score (overall safety)
206
  binary_score = results.get('binary', [0.0])[0]
207
 
208
+ # Extract specific scores for Google Sheets
209
+ hateful_scores = CATEGORIES['hateful']
210
+ hateful_l1_score = results.get(hateful_scores[0], [0.0])[0] if len(hateful_scores) > 0 else 0.0
211
+ hateful_l2_score = results.get(hateful_scores[1], [0.0])[0] if len(hateful_scores) > 1 else 0.0
212
+
213
+ insults_scores = CATEGORIES['insults']
214
+ insults_score = results.get(insults_scores[0], [0.0])[0] if len(insults_scores) > 0 else 0.0
215
+
216
+ sexual_scores = CATEGORIES['sexual']
217
+ sexual_l1_score = results.get(sexual_scores[0], [0.0])[0] if len(sexual_scores) > 0 else 0.0
218
+ sexual_l2_score = results.get(sexual_scores[1], [0.0])[0] if len(sexual_scores) > 1 else 0.0
219
+
220
+ physical_violence_scores = CATEGORIES['physical_violence']
221
+ physical_violence_score = results.get(physical_violence_scores[0], [0.0])[0] if len(physical_violence_scores) > 0 else 0.0
222
+
223
+ self_harm_scores = CATEGORIES['self_harm']
224
+ self_harm_l1_score = results.get(self_harm_scores[0], [0.0])[0] if len(self_harm_scores) > 0 else 0.0
225
+ self_harm_l2_score = results.get(self_harm_scores[1], [0.0])[0] if len(self_harm_scores) > 1 else 0.0
226
+
227
+ aom_scores = CATEGORIES['all_other_misconduct']
228
+ aom_l1_score = results.get(aom_scores[0], [0.0])[0] if len(aom_scores) > 0 else 0.0
229
+ aom_l2_score = results.get(aom_scores[1], [0.0])[0] if len(aom_scores) > 1 else 0.0
230
+
231
+ # Save results to Google Sheets
232
+ if GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
233
+ results_row = {
234
+ "datetime": datetime.now().isoformat(),
235
+ "text_id": text_id,
236
+ "text": text,
237
+ "binary_score": binary_score,
238
+ "hateful_l1_score": hateful_l1_score,
239
+ "hateful_l2_score": hateful_l2_score,
240
+ "insults_score": insults_score,
241
+ "sexual_l1_score": sexual_l1_score,
242
+ "sexual_l2_score": sexual_l2_score,
243
+ "physical_violence_score": physical_violence_score,
244
+ "self_harm_l1_score": self_harm_l1_score,
245
+ "self_harm_l2_score": self_harm_l2_score,
246
+ "aom_l1_score": aom_l1_score,
247
+ "aom_l2_score": aom_l2_score
248
+ }
249
+ save_results_data(results_row)
250
+
251
  # Prepare category data with max scores and dropdowns
252
  categories_html = []
253
 
 
308
  </div>
309
  '''
310
 
311
+ # Create voting section
312
+ voting_html = f'''
313
+ <div style="margin: 24px 0; text-align: center;">
314
+ <div style="background: #1f2937; border-radius: 12px; padding: 20px; box-shadow: 0 4px 12px rgba(0,0,0,0.3); border: 1px solid #374151;">
315
+ <h3 style="color: #f9fafb; font-size: 18px; font-weight: 600; margin-bottom: 12px;">πŸ“Š How accurate are these results?</h3>
316
+ <p style="color: #d1d5db; font-size: 14px; margin-bottom: 16px;">Your feedback helps improve the model</p>
317
+ </div>
318
+ </div>
319
+ '''
320
+
321
+ return format_binary_score(binary_score), html_table, text_id, voting_html
322
 
323
  except Exception as e:
324
  error_msg = f"Error analyzing text: {str(e)}"
325
  error_html = f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 20px; border-radius: 12px; text-align: center; border: 2px solid #ef4444; box-shadow: 0 4px 12px rgba(0,0,0,0.3);">❌ {error_msg}</div>'
326
+ return f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 16px; border-radius: 8px; text-align: center; border: 1px solid #ef4444;">❌ {error_msg}</div>', error_html, "", ""
327
+
328
+
329
+ # Voting functions
330
+ def vote_thumbs_up(text_id):
331
+ """Handle thumbs up vote"""
332
+ if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
333
+ save_vote_data(text_id, True)
334
+ return '<div style="background: linear-gradient(135deg, #065f46 0%, #047857 100%); color: #34d399; padding: 16px; border-radius: 8px; text-align: center; font-weight: 600; border: 1px solid #10b981; margin: 8px 0;">πŸ‘ Thanks for your feedback!</div>'
335
+ return '<div style="color: #9ca3af; text-align: center; padding: 16px;">Voting not available</div>'
336
 
337
+ def vote_thumbs_down(text_id):
338
+ """Handle thumbs down vote"""
339
+ if text_id and GOOGLE_SHEET_URL and GOOGLE_CREDENTIALS:
340
+ save_vote_data(text_id, False)
341
+ return '<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 16px; border-radius: 8px; text-align: center; font-weight: 600; border: 1px solid #ef4444; margin: 8px 0;">πŸ‘Ž Thanks for your feedback!</div>'
342
+ return '<div style="color: #9ca3af; text-align: center; padding: 16px;">Voting not available</div>'
343
 
344
  # Create Gradio interface with dark theme
345
  with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
 
382
  category_table = gr.HTML(
383
  value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
384
  )
385
+
386
+ # Voting section
387
+ voting_section = gr.HTML(value="")
388
+
389
+ with gr.Row(visible=False) as voting_row:
390
+ with gr.Column():
391
+ gr.HTML("""
392
+ <div style="margin: 16px 0; text-align: center;">
393
+ <h3 style="color: #f9fafb; font-size: 18px; font-weight: 600; margin-bottom: 8px;">πŸ“Š How accurate are these results?</h3>
394
+ <p style="color: #d1d5db; font-size: 14px; margin-bottom: 16px;">Your feedback helps improve the model</p>
395
+ </div>
396
+ """)
397
+
398
+ with gr.Row():
399
+ thumbs_up_btn = gr.Button("πŸ‘ Accurate", variant="secondary", scale=1)
400
+ thumbs_down_btn = gr.Button("πŸ‘Ž Inaccurate", variant="secondary", scale=1)
401
+
402
+ vote_feedback = gr.HTML(value="")
403
+
404
+ # Hidden text_id to track current analysis
405
+ current_text_id = gr.Textbox(value="", visible=False)
406
 
407
  # Add information about the categories
408
  with gr.Row():
 
428
  </div>
429
  """)
430
 
431
+ # Function to handle analysis and show voting buttons
432
+ def analyze_and_show_voting(text):
433
+ binary_score, category_table, text_id, voting_html = analyze_text(text)
434
+ # Show voting row if we have results
435
+ voting_row_update = gr.update(visible=bool(text_id))
436
+ return binary_score, category_table, text_id, voting_row_update, ""
437
+
438
  # Connect the analyze button to the function
439
  analyze_btn.click(
440
+ fn=analyze_and_show_voting,
441
  inputs=[text_input],
442
+ outputs=[binary_output, category_table, current_text_id, voting_row, vote_feedback]
443
  )
444
 
445
  # Allow Enter key to trigger analysis
446
  text_input.submit(
447
+ fn=analyze_and_show_voting,
448
  inputs=[text_input],
449
+ outputs=[binary_output, category_table, current_text_id, voting_row, vote_feedback]
450
+ )
451
+
452
+ # Connect voting buttons
453
+ thumbs_up_btn.click(
454
+ fn=vote_thumbs_up,
455
+ inputs=[current_text_id],
456
+ outputs=[vote_feedback]
457
+ )
458
+
459
+ thumbs_down_btn.click(
460
+ fn=vote_thumbs_down,
461
+ inputs=[current_text_id],
462
+ outputs=[vote_feedback]
463
  )
464
 
465
 
requirements.txt CHANGED
@@ -3,4 +3,6 @@ torch>=2.0.0
3
  safetensors>=0.4.0
4
  openai>=1.0.0
5
  numpy>=1.24.0
6
- huggingface_hub>=0.19.0
 
 
 
3
  safetensors>=0.4.0
4
  openai>=1.0.0
5
  numpy>=1.24.0
6
+ huggingface_hub>=0.19.0
7
+ gspread==6.2.1
8
+ google-auth==2.40.3