gabrielchua commited on
Commit
27a346a
·
unverified ·
1 Parent(s): c657583

update repo

Browse files
Files changed (8) hide show
  1. .gitignore +167 -0
  2. README.md +0 -12
  3. app.py +261 -135
  4. download_model.py +75 -0
  5. lionguard2.py +170 -0
  6. model.joblib +0 -3
  7. requirements.txt +6 -8
  8. utils.py +44 -0
.gitignore ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ *.egg-info/
23
+ .installed.cfg
24
+ *.egg
25
+
26
+ # PyInstaller
27
+ # Usually these files are written by a python script from a template
28
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
29
+ *.manifest
30
+ *.spec
31
+
32
+ # Installer logs
33
+ pip-log.txt
34
+ pip-delete-this-directory.txt
35
+
36
+ # Unit test / coverage reports
37
+ htmlcov/
38
+ .tox/
39
+ .nox/
40
+ .coverage
41
+ .coverage.*
42
+ .cache
43
+ nosetests.xml
44
+ coverage.xml
45
+ *.cover
46
+ .hypothesis/
47
+ .pytest_cache/
48
+
49
+ # Translations
50
+ *.mo
51
+ *.pot
52
+
53
+ # Django stuff:
54
+ *.log
55
+ local_settings.py
56
+ db.sqlite3
57
+ db.sqlite3-journal
58
+
59
+ # Flask stuff:
60
+ instance/
61
+ .webassets-cache
62
+
63
+ # Scrapy stuff:
64
+ .scrapy
65
+
66
+ # Sphinx documentation
67
+ docs/_build/
68
+
69
+ # PyBuilder
70
+ target/
71
+
72
+ # Jupyter Notebook
73
+ .ipynb_checkpoints
74
+
75
+ # IPython
76
+ profile_default/
77
+ ipython_config.py
78
+
79
+ # pyenv
80
+ .python-version
81
+
82
+ # pipenv
83
+ Pipfile.lock
84
+
85
+ # poetry
86
+ poetry.lock
87
+
88
+ # PDM
89
+ pdm.lock
90
+ __pypackages__/
91
+
92
+ # mypy
93
+ .mypy_cache/
94
+ .dmypy.json
95
+ dmypy.json
96
+
97
+ # Pyre type checker
98
+ .pyre/
99
+
100
+ # pytype
101
+ .pytype/
102
+
103
+ # Cython debug symbols
104
+ cython_debug/
105
+
106
+ # VS Code
107
+ .vscode/
108
+
109
+ # Mac
110
+ .DS_Store
111
+
112
+ # Model files and large data
113
+ *.safetensors
114
+ *.pt
115
+ *.pth
116
+ *.ckpt
117
+ *.onnx
118
+ *.h5
119
+ *.bin
120
+ *.npy
121
+ *.npz
122
+ *.tar
123
+ *.tar.*
124
+ *.zip
125
+ *.gz
126
+ *.bz2
127
+ *.xz
128
+ *.zst
129
+ *.joblib
130
+ *.pickle
131
+ *.pkl
132
+ *.msgpack
133
+ *.arrow
134
+ *.parquet
135
+ *.tflite
136
+ *.wasm
137
+ *.mlmodel
138
+ *.ftz
139
+ *.rar
140
+ *.7z
141
+
142
+ # LFS cache and pointers
143
+ *.lfs.*
144
+ saved_model/**/*
145
+ *tfevents*
146
+
147
+ # Cache
148
+ .cache/
149
+ .cache/*
150
+
151
+ # Environment
152
+ .env
153
+ .env.*
154
+ .venv/
155
+ venv/
156
+ ENV/
157
+ env/
158
+ env.bak/
159
+ venv.bak/
160
+
161
+ # Gradio
162
+ gradio_cached_examples/
163
+ gradio_cache/
164
+ .gradio/
165
+
166
+ # Cache
167
+ cache/
README.md DELETED
@@ -1,12 +0,0 @@
1
- ---
2
- title: Refactored Guacamole
3
- emoji: 📚
4
- colorFrom: yellow
5
- colorTo: pink
6
- sdk: gradio
7
- sdk_version: 5.20.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -1,170 +1,296 @@
 
 
 
 
1
  import os
2
  import gradio as gr
3
- import joblib
4
- import numpy as np
5
- import pandas as pd
6
- from openai import OpenAI
7
- from typing import List, Dict, Any
8
-
9
- # --- New Inference Code Components ---
10
-
11
- # Define categories with sub-level information
12
- CATEGORIES = {
13
- 'hateful': ['hateful_lvl_1_discriminatory', 'hateful_lvl_2_hate_speech'],
14
- 'insults': ['insults'],
15
- 'sexual': ['sexual_lvl_1_not_appropriate_for_minors', 'sexual_lvl_2_not_appropriate_for_all'],
16
- 'physical_violence': ['physical_violence'],
17
- 'self_harm': ['self_harm_lvl_1_intent', 'self_harm_lvl_2_action'],
18
- 'all_other_misconduct': ['all_other_misconduct_lvl_1_not_socially_accepted', 'all_other_misconduct_lvl_2_illegal']
19
- }
20
-
21
- def get_embeddings(texts: List[str], model: str = "text-embedding-3-large") -> np.ndarray:
22
  """
23
- Generate embeddings for a list of texts using the OpenAI API synchronously.
24
 
25
  Args:
26
- texts: List of strings to embed.
27
- model: The OpenAI embedding model to use.
28
-
 
29
  Returns:
30
- A numpy array of embeddings.
31
  """
32
- client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
33
- MAX_TOKENS = 8191 # Maximum tokens for the embedding model
34
- truncated_texts = [text[:MAX_TOKENS] for text in texts]
35
 
36
- response = client.embeddings.create(
37
- input=truncated_texts,
38
- model=model
 
 
 
39
  )
40
-
41
- embeddings = np.array([data.embedding for data in response.data])
42
- return embeddings
43
 
44
- def run_model(model_file: str, embeddings: np.ndarray):
 
45
  """
46
- Run the model on the embeddings.
47
 
48
  Args:
49
- model_file: Path to the model file.
50
- embeddings: Numpy array of embeddings.
51
-
52
- Returns:
53
- expanded_predictions, expanded_probabilities, expanded_label_names
54
  """
55
- print("Loading model...")
56
- model_data = joblib.load(model_file)
57
- model = model_data['model']
58
- label_names = model_data['label_names']
59
 
60
- print("Predicting...")
61
- # raw_predictions is a list of arrays with shape (n_samples, 2)
62
- raw_predictions = model.predict(embeddings)
63
 
64
- print("Processing predictions...")
65
- predictions = []
66
- probabilities = []
67
- # Process each category's raw predictions
68
- for i, pred in enumerate(raw_predictions):
69
- # Convert raw predictions (P(y>0), P(y>1)) into a class from {0, 1, 2}
70
- pred_class = np.zeros(len(pred))
71
- pred_class += (pred[:, 0] > 0.5).astype(int) # y > 0
72
- pred_class += (pred[:, 1] > 0.5).astype(int) # y > 1
73
- predictions.append(pred_class)
74
-
75
- # Calculate probabilities for each class:
76
- # P(y=0) = 1 - P(y>0), P(y=1) = P(y>0) - P(y>1), P(y=2) = P(y>1)
77
- prob = np.zeros((len(pred), 3))
78
- prob[:, 0] = 1 - pred[:, 0]
79
- prob[:, 1] = pred[:, 0] - pred[:, 1]
80
- prob[:, 2] = pred[:, 1]
81
- probabilities.append(prob)
82
 
83
- predictions = np.array(predictions).T
84
- probabilities = np.array(probabilities).transpose(1, 0, 2)
85
-
86
- # Expand predictions to sub-levels
87
- expanded_predictions = []
88
- expanded_probabilities = []
89
- expanded_label_names = []
90
- for i, cat in enumerate(label_names):
91
- # Level 1 binary
92
- y_pred_l1 = (predictions[:, i] > 0).astype(int) # y == 1 or y == 2
93
- y_proba_l1 = 1 - probabilities[:, i, 0] # 1 - P(class 0)
94
-
95
- # Level 2 binary
96
- y_pred_l2 = (predictions[:, i] == 2).astype(int) # only y == 2
97
- y_proba_l2 = probabilities[:, i, 2] # Probability of class 2
98
-
99
- if cat in ['binary', 'insults', 'physical_violence']:
100
- expanded_predictions.append(y_pred_l1)
101
- expanded_probabilities.append(y_proba_l1)
102
- expanded_label_names.append(cat)
103
- else:
104
- expanded_predictions.append(y_pred_l1)
105
- expanded_probabilities.append(y_proba_l1)
106
- expanded_label_names.append(CATEGORIES[cat][0])
107
-
108
- expanded_predictions.append(y_pred_l2)
109
- expanded_probabilities.append(y_proba_l2)
110
- expanded_label_names.append(CATEGORIES[cat][1])
111
-
112
- expanded_predictions = np.array(expanded_predictions).T
113
- expanded_probabilities = np.array(expanded_probabilities).T
114
-
115
- return expanded_predictions, expanded_probabilities, expanded_label_names
116
 
117
- def format_output(predictions: np.ndarray, probabilities: np.ndarray, label_names: List[str]) -> pd.DataFrame:
 
 
 
 
 
118
  """
119
- Format the output predictions into a DataFrame.
120
 
121
  Args:
122
- predictions: Binary predictions.
123
- probabilities: Associated prediction scores.
124
- label_names: List of label names.
125
-
126
  Returns:
127
- DataFrame with columns "Label", "Prediction", and "Score".
128
  """
129
- # As our Gradio interface processes one text at a time, we use the first (and only) sample.
130
- data = {
131
- "Label": label_names,
132
- "Prediction": predictions[0].tolist(),
133
- "Score": np.round(probabilities[0], 4).tolist()
134
- }
135
- return pd.DataFrame(data)
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
- # --- Gradio App Integration ---
 
 
 
 
 
 
 
 
 
138
 
139
- # Define model file path (adjust as necessary)
140
- MODEL_FILE = "model.joblib"
141
 
142
- def classify_text(text: str):
143
  """
144
- Given an input text, generates embeddings, runs the model inference,
145
- and returns a DataFrame of classification results.
 
 
 
 
 
 
146
  """
147
  if not text.strip():
148
- # Return an empty DataFrame if no text provided
149
- empty_df = pd.DataFrame({"Label": [], "Prediction": [], "Score": []})
150
- return gr.update(value=empty_df, visible=True)
151
 
152
- # Obtain embeddings (input must be a list)
153
- embeddings = get_embeddings([text])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
- # Run inference on the embeddings using the new model file
156
- predictions, probabilities, label_names = run_model(MODEL_FILE, embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
- # Format the results to a DataFrame that Gradio can display
159
- df = format_output(predictions, probabilities, label_names)
160
- return gr.update(value=df, visible=True)
161
-
162
- with gr.Blocks(title="Zoo Entry 001 - Updated Inference") as iface:
163
- input_text = gr.Textbox(lines=5, label="Input Text")
164
- submit_btn = gr.Button("Submit")
165
- output_table = gr.DataFrame(label="Classification Results", visible=False)
166
 
167
- submit_btn.click(fn=classify_text, inputs=input_text, outputs=output_table)
 
 
 
 
 
 
168
 
169
  if __name__ == "__main__":
170
- iface.launch()
 
1
+ """
2
+ simple_demo.py - Gradio Web App for LionGuard2 Content Moderation
3
+ """
4
+
5
  import os
6
  import gradio as gr
7
+ from safetensors.torch import load_file
8
+ from huggingface_hub import hf_hub_download
9
+
10
+ # Local imports
11
+ from lionguard2 import LionGuard2, CATEGORIES
12
+ from utils import get_embeddings
13
+
14
+
15
+ def download_model(repo_id, filename="LionGuard2.safetensors", token=None):
 
 
 
 
 
 
 
 
 
 
16
  """
17
+ Download the LionGuard2 model from a Hugging Face private repository.
18
 
19
  Args:
20
+ repo_id: The Hugging Face repository ID (e.g., "username/repo-name")
21
+ filename: The filename to download (default: "LionGuard2.safetensors")
22
+ token: Hugging Face access token for private repositories
23
+
24
  Returns:
25
+ Path to the downloaded file
26
  """
27
+ if token is None:
28
+ token = os.environ.get("HF_API_KEY")
 
29
 
30
+ # Download the model file
31
+ model_path = hf_hub_download(
32
+ repo_id=repo_id,
33
+ filename=filename,
34
+ token=token,
35
+ cache_dir="./cache"
36
  )
37
+ return model_path
 
 
38
 
39
+
40
+ def load_model(repo_id=None, use_local=True):
41
  """
42
+ Load the LionGuard2 model from either local file or Hugging Face repository.
43
 
44
  Args:
45
+ repo_id: The Hugging Face repository ID (optional)
46
+ use_local: Whether to use local file first (default: True)
 
 
 
47
  """
48
+ model = LionGuard2()
49
+ model.eval()
 
 
50
 
51
+ model_path = "LionGuard2.safetensors"
 
 
52
 
53
+ # Try to download from HF repo if specified and local file doesn't exist or use_local is False
54
+ if repo_id and (not use_local or not os.path.exists(model_path)):
55
+ try:
56
+ print(f"Downloading LionGuard2.safetensors from {repo_id}...")
57
+ model_path = download_model(repo_id)
58
+ print(f"Model downloaded to: {model_path}")
59
+ except Exception as e:
60
+ print(f"Failed to download from HF repo: {e}")
61
+ if not os.path.exists("LionGuard2.safetensors"):
62
+ raise Exception("No local model file found and failed to download from HF repo")
63
+ print("Falling back to local file...")
 
 
 
 
 
 
 
64
 
65
+ state_dict = load_file(model_path)
66
+ model.load_state_dict(state_dict)
67
+ return model
68
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
+ # Load model once at startup
71
+ HF_REPO_ID = "gabrielchua/refactored-guacamole" # Update this with the actual repo ID
72
+ model = load_model(repo_id=HF_REPO_ID)
73
+
74
+
75
+ def format_score_with_style(score_str, compact=False):
76
  """
77
+ Format score with color and emoji based on value.
78
 
79
  Args:
80
+ score_str: Score as string (could be "-" for missing values)
81
+ compact: Whether to use compact styling
82
+
 
83
  Returns:
84
+ HTML formatted string with styling
85
  """
86
+ if score_str == "-":
87
+ return '<span style="color: #9ca3af;">-</span>'
88
+
89
+ try:
90
+ score = float(score_str)
91
+ percentage = int(score * 100)
92
+ padding = "6px 12px" if compact else "8px 16px"
93
+
94
+ if score < 0.4:
95
+ # Safe - refined green
96
+ return f'<span style="background: linear-gradient(135deg, #065f46 0%, #047857 100%); color: #34d399; padding: {padding}; border-radius: 8px; font-size: 13px; font-weight: 600; border: 1px solid #10b981; box-shadow: 0 2px 4px rgba(0,0,0,0.2);">👌 {percentage}%</span>'
97
+ elif 0.4 <= score < 0.7:
98
+ # Warning - refined amber
99
+ return f'<span style="background: linear-gradient(135deg, #92400e 0%, #b45309 100%); color: #fbbf24; padding: {padding}; border-radius: 8px; font-size: 13px; font-weight: 600; border: 1px solid #f59e0b; box-shadow: 0 2px 4px rgba(0,0,0,0.2);">⚠️ {percentage}%</span>'
100
+ else:
101
+ # High risk - refined red
102
+ return f'<span style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: {padding}; border-radius: 8px; font-size: 13px; font-weight: 600; border: 1px solid #ef4444; box-shadow: 0 2px 4px rgba(0,0,0,0.2);">🚨 {percentage}%</span>'
103
+ except:
104
+ return score_str
105
+
106
 
107
+ def format_binary_score(score):
108
+ """Format the binary score with appropriate styling for dark mode."""
109
+ percentage = int(score * 100)
110
+
111
+ if score < 0.4:
112
+ return f'<div style="background: linear-gradient(135deg, #065f46 0%, #047857 100%); color: #34d399; padding: 24px; border-radius: 12px; text-align: center; font-weight: 700; border: 2px solid #10b981; font-size: 20px; box-shadow: 0 4px 12px rgba(0,0,0,0.3); margin: 16px 0;">✅ Pass ({percentage}/100)</div>'
113
+ elif 0.4 <= score < 0.7:
114
+ return f'<div style="background: linear-gradient(135deg, #92400e 0%, #b45309 100%); color: #fbbf24; padding: 24px; border-radius: 12px; text-align: center; font-weight: 700; border: 2px solid #f59e0b; font-size: 20px; box-shadow: 0 4px 12px rgba(0,0,0,0.3); margin: 16px 0;">⚠️ Warning ({percentage}/100)</div>'
115
+ else:
116
+ return f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 24px; border-radius: 12px; text-align: center; font-weight: 700; border: 2px solid #ef4444; font-size: 20px; box-shadow: 0 4px 12px rgba(0,0,0,0.3); margin: 16px 0;">🚨 Fail ({percentage}/100)</div>'
117
 
 
 
118
 
119
+ def analyze_text(text):
120
  """
121
+ Analyze text for content moderation violations.
122
+
123
+ Args:
124
+ text: Input text to analyze
125
+
126
+ Returns:
127
+ binary_score: Overall safety score with styling
128
+ category_table: HTML table with category-specific scores and styling
129
  """
130
  if not text.strip():
131
+ empty_html = '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
132
+ return '<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>', empty_html
 
133
 
134
+ try:
135
+ # Get embeddings for the text
136
+ embeddings = get_embeddings([text])
137
+
138
+ # Run inference
139
+ results = model.predict(embeddings)
140
+
141
+ # Extract binary score (overall safety)
142
+ binary_score = results.get('binary', [0.0])[0]
143
+
144
+ # Prepare category data with max scores and dropdowns
145
+ categories_html = []
146
+
147
+ # Define the main categories (excluding binary)
148
+ main_categories = ['hateful', 'insults', 'sexual', 'physical_violence', 'self_harm', 'all_other_misconduct']
149
+
150
+ for category in main_categories:
151
+ subcategories = CATEGORIES[category]
152
+ category_name = category.replace('_', ' ').title()
153
+
154
+ # Add emoji to category name based on type
155
+ category_emojis = {
156
+ 'Hateful': '🤬',
157
+ 'Insults': '💢',
158
+ 'Sexual': '🔞',
159
+ 'Physical Violence': '⚔️',
160
+ 'Self Harm': '☹️',
161
+ 'All Other Misconduct': '🙅‍♀️'
162
+ }
163
+ category_display = f"{category_emojis.get(category_name, '📝')} {category_name}"
164
+
165
+ # Get scores for all levels
166
+ level_scores = []
167
+ for i, subcategory_key in enumerate(subcategories):
168
+ score = results.get(subcategory_key, [0.0])[0]
169
+ level_scores.append((f"Level {i+1}", score))
170
+
171
+ # Find max score
172
+ max_score = max([score for _, score in level_scores]) if level_scores else 0.0
173
+
174
+ # Create the row HTML - just show max score
175
+ categories_html.append(f'''
176
+ <tr style="border-bottom: 1px solid #374151; transition: background-color 0.2s ease;">
177
+ <td style="padding: 16px; font-weight: 500; color: #f9fafb; font-size: 15px;">{category_display}</td>
178
+ <td style="padding: 16px; text-align: center;">{format_score_with_style(f"{max_score:.4f}")}</td>
179
+ </tr>
180
+ ''')
181
+
182
+ # Create refined HTML table for dark mode
183
+ html_table = f'''
184
+ <div style="margin: 24px 0;">
185
+ <div style="margin-bottom: 20px; text-align: center;">
186
+ <h2 style="color: #f9fafb; font-size: 20px; font-weight: 600; margin-bottom: 6px;">📊 Category-Specific Scores</h2>
187
+ </div>
188
+ <div style="background: #1f2937; border-radius: 12px; overflow: hidden; box-shadow: 0 4px 12px rgba(0,0,0,0.3); border: 1px solid #374151;">
189
+ <table style="width: 100%; border-collapse: collapse;">
190
+ <thead>
191
+ <tr style="background: linear-gradient(135deg, #374151 0%, #4b5563 100%);">
192
+ <th style="padding: 16px; text-align: left; font-weight: 600; font-size: 15px; color: #f9fafb;">Category</th>
193
+ <th style="padding: 16px; text-align: center; font-weight: 600; font-size: 15px; color: #f9fafb;">Score</th>
194
+ </tr>
195
+ </thead>
196
+ <tbody>
197
+ {"".join(categories_html)}
198
+ </tbody>
199
+ </table>
200
+ </div>
201
+ </div>
202
+ '''
203
+
204
+ return format_binary_score(binary_score), html_table
205
+
206
+ except Exception as e:
207
+ error_msg = f"Error analyzing text: {str(e)}"
208
+ error_html = f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 20px; border-radius: 12px; text-align: center; border: 2px solid #ef4444; box-shadow: 0 4px 12px rgba(0,0,0,0.3);">❌ {error_msg}</div>'
209
+ return f'<div style="background: linear-gradient(135deg, #991b1b 0%, #b91c1c 100%); color: #fca5a5; padding: 16px; border-radius: 8px; text-align: center; border: 1px solid #ef4444;">❌ {error_msg}</div>', error_html
210
+
211
+
212
+ # Create Gradio interface with dark theme
213
+ with gr.Blocks(title="LionGuard2", theme=gr.themes.Base().set(
214
+ body_background_fill="*neutral_950",
215
+ background_fill_primary="*neutral_900",
216
+ background_fill_secondary="*neutral_800",
217
+ border_color_primary="*neutral_700",
218
+ color_accent_soft="*blue_500"
219
+ )) as demo:
220
+ gr.HTML("""
221
+ <div style="text-align: center; margin-bottom: 40px; padding: 20px;">
222
+ <h1 style="color: #f9fafb; font-size: 36px; font-weight: 700; margin-bottom: 12px; text-shadow: 0 2px 4px rgba(0,0,0,0.3);">🦁 LionGuard2</h1>
223
+ <p style="color: #d1d5db; font-size: 16px; font-weight: 400; margin: 0;">Detect safety violations, and localised to Singapore</p>
224
+
225
+ </div>
226
+ """)
227
 
228
+ with gr.Row():
229
+ with gr.Column(scale=1, min_width=400):
230
+ text_input = gr.Textbox(
231
+ label="Enter text to analyze:",
232
+ placeholder="Type your text here...",
233
+ lines=12,
234
+ max_lines=20,
235
+ container=True
236
+ )
237
+
238
+ analyze_btn = gr.Button("🔍 Analyze Text", variant="primary")
239
+
240
+ with gr.Column(scale=1, min_width=400):
241
+ gr.HTML("""
242
+ <div style="margin-bottom: 24px; text-align: center;">
243
+ <h2 style="color: #f9fafb; font-size: 22px; font-weight: 600; margin-bottom: 8px;">Overall Safety Score</h2>
244
+ <p style="color: #d1d5db; font-size: 14px; margin: 0; opacity: 0.8;">Higher percentages indicate higher likelihood of harmful content</p>
245
+ </div>
246
+ """)
247
+
248
+ binary_output = gr.HTML(
249
+ value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Enter text to analyze</div>'
250
+ )
251
+
252
+ category_table = gr.HTML(
253
+ value='<div style="text-align: center; color: #9ca3af; padding: 30px; font-style: italic;">Category scores will appear here after analysis</div>'
254
+ )
255
+
256
+ # Add information about the categories
257
+ with gr.Row():
258
+ with gr.Accordion("ℹ️ About the Scoring System", open=False):
259
+ gr.HTML("""
260
+ <div style="font-size: 14px; line-height: 1.6; color: #f3f4f6; padding: 10px;">
261
+ <h3 style="color: #f9fafb; margin-bottom: 16px;">How Scoring Works:</h3>
262
+ <ul style="color: #d1d5db; margin-bottom: 24px;">
263
+ <li><b>Percentages represent likelihood of harmful content</b> - Higher % = More likely to be harmful</li>
264
+ <li><b>0-40%:</b> Content appears safe</li>
265
+ <li><b>40-70%:</b> Potentially concerning content that warrants review</li>
266
+ <li><b>70-100%:</b> High likelihood of policy violation</li>
267
+ </ul>
268
+ <h3 style="color: #f9fafb; margin-bottom: 16px;">Content Categories (Singapore Context):</h3>
269
+ <ul style="color: #d1d5db;">
270
+ <li><b>🤬 Hateful:</b> Content targeting Singapore's protected traits (e.g., race, religion), including discriminatory remarks and explicit calls for harm/violence.</li>
271
+ <li><b>💢 Insults:</b> Personal attacks on non-protected attributes (e.g., appearance). Note: Sexuality attacks are classified as insults, not hateful, in Singapore.</li>
272
+ <li><b>🔞 Sexual:</b> Sexual content or adult themes, ranging from mild content inappropriate for minors to explicit content inappropriate for general audiences.</li>
273
+ <li><b>⚔️ Physical Violence:</b> Threats, descriptions, or glorification of physical harm against individuals or groups (not property damage).</li>
274
+ <li><b>☹️ Self Harm:</b> Content about self-harm or suicide, including ideation, encouragement, or descriptions of ongoing actions.</li>
275
+ <li><b>🙅‍♀️ All Other Misconduct:</b> Unethical/criminal conduct not covered above, from socially condemned behavior to clearly illegal activities under Singapore law.</li>
276
+ </ul>
277
+ </div>
278
+ """)
279
 
280
+ # Connect the analyze button to the function
281
+ analyze_btn.click(
282
+ fn=analyze_text,
283
+ inputs=[text_input],
284
+ outputs=[binary_output, category_table]
285
+ )
 
 
286
 
287
+ # Allow Enter key to trigger analysis
288
+ text_input.submit(
289
+ fn=analyze_text,
290
+ inputs=[text_input],
291
+ outputs=[binary_output, category_table]
292
+ )
293
+
294
 
295
  if __name__ == "__main__":
296
+ demo.launch(share=True, server_name="0.0.0.0", server_port=7860)
download_model.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ download_model.py - Utility script to download LionGuard2 model from Hugging Face
4
+ """
5
+
6
+ import os
7
+ import argparse
8
+ from huggingface_hub import hf_hub_download
9
+
10
+
11
+ def download_lionguard2(repo_id, filename="LionGuard2.safetensors", token=None, output_dir="./"):
12
+ """
13
+ Download LionGuard2 model from Hugging Face private repository.
14
+
15
+ Args:
16
+ repo_id: The Hugging Face repository ID (e.g., "username/repo-name")
17
+ filename: The filename to download (default: "LionGuard2.safetensors")
18
+ token: Hugging Face access token for private repositories
19
+ output_dir: Directory to save the downloaded file
20
+ """
21
+ if token is None:
22
+ token = os.environ.get("HF_API_KEY")
23
+ if not token:
24
+ print("Error: No HF_API_KEY found in environment variables.")
25
+ print("Please set your Hugging Face token:")
26
+ print("export HF_API_KEY=your_token_here")
27
+ return False
28
+
29
+ try:
30
+ print(f"Downloading {filename} from {repo_id}...")
31
+
32
+ # Download the model file
33
+ model_path = hf_hub_download(
34
+ repo_id=repo_id,
35
+ filename=filename,
36
+ token=token,
37
+ local_dir=output_dir,
38
+ local_dir_use_symlinks=False # Download actual file, not symlink
39
+ )
40
+
41
+ print(f"✅ Model successfully downloaded to: {model_path}")
42
+ return True
43
+
44
+ except Exception as e:
45
+ print(f"❌ Failed to download model: {e}")
46
+ return False
47
+
48
+
49
+ def main():
50
+ parser = argparse.ArgumentParser(description="Download LionGuard2 model from Hugging Face")
51
+ parser.add_argument("repo_id", help="Hugging Face repository ID (e.g., username/repo-name)")
52
+ parser.add_argument("--filename", default="LionGuard2.safetensors", help="Filename to download")
53
+ parser.add_argument("--token", help="Hugging Face access token (optional if HF_API_KEY env var is set)")
54
+ parser.add_argument("--output-dir", default="./", help="Output directory for downloaded file")
55
+
56
+ args = parser.parse_args()
57
+
58
+ success = download_lionguard2(
59
+ repo_id=args.repo_id,
60
+ filename=args.filename,
61
+ token=args.token,
62
+ output_dir=args.output_dir
63
+ )
64
+
65
+ if success:
66
+ print(f"\n🎉 Ready to use! The model has been downloaded and can now be used by the application.")
67
+ else:
68
+ print(f"\n💡 Make sure you have:")
69
+ print(f" 1. Valid Hugging Face token with access to the private repository")
70
+ print(f" 2. Correct repository ID: {args.repo_id}")
71
+ print(f" 3. The model file exists in the repository")
72
+
73
+
74
+ if __name__ == "__main__":
75
+ main()
lionguard2.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ lionguard2.py
3
+ """
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ CATEGORIES = {
9
+ "binary": ["binary"],
10
+ "hateful": ["hateful_l1", "hateful_l2"],
11
+ "insults": ["insults"],
12
+ "sexual": [
13
+ "sexual_l1",
14
+ "sexual_l2",
15
+ ],
16
+ "physical_violence": ["physical_violence"],
17
+ "self_harm": ["self_harm_l1", "self_harm_l2"],
18
+ "all_other_misconduct": [
19
+ "all_other_misconduct_l1",
20
+ "all_other_misconduct_l2",
21
+ ],
22
+ }
23
+
24
+ INPUT_DIMENSION = 3072 # length of OpenAI embeddings
25
+
26
+
27
+ class LionGuard2(nn.Module):
28
+ def __init__(
29
+ self,
30
+ input_dim=INPUT_DIMENSION,
31
+ label_names=CATEGORIES.keys(),
32
+ categories=CATEGORIES,
33
+ ):
34
+ """
35
+ LionGuard2 is a localised content moderation model that flags whether text violates the following categories:
36
+
37
+ 1. `hateful`: Text that discriminates, criticizes, insults, denounces, or dehumanizes a person or group on the basis of a protected identity.
38
+
39
+ There are two sub-categories for the `hateful` category:
40
+ a. `level_1_discriminatory`: Text that contains derogatory or generalized negative statements targeting a protected group.
41
+ b. `level_2_hate_speech`: Text that explicitly calls for harm or violence against a protected group; or language praising or justifying violence against them.
42
+
43
+ 2. `insults`: Text that insults demeans, humiliates, mocks, or belittles a person or group **without** referencing a legally protected trait.
44
+ For example, this includes personal attacks on attributes such as someone’s appearance, intellect, behavior, or other non-protected characteristics.
45
+
46
+ 3. `sexual`: Text that depicts or indicates sexual interest, activity, or arousal, using direct or indirect references to body parts, sexual acts, or physical traits.
47
+ This includes sexual content that may be inappropriate for certain audiences.
48
+
49
+ There are two sub-categories for the `sexual` category:
50
+ a. `level_1_not_appropriate_for_minors`: Text that contains mild-to-moderate sexual content that is generally adult-oriented or potentially unsuitable for those under 16.
51
+ May include matter-of-fact discussions about sex, sexuality, or sexual preferences.
52
+ b. `level_2_not_appropriate_for_all_ages`: Text that contains content aimed at adults and considered explicit, graphic, or otherwise inappropriate for a broad audience.
53
+ May include explicit descriptions of sexual acts, detailed sexual fantasies, or highly sexualized content.
54
+
55
+ 4. `physical_violence`: Text that includes glorification of violence or threats to inflict physical harm or injury on a person, group, or entity.
56
+
57
+ 5. `self_harm`: Text that promotes, suggests, or expresses intent to self-harm or commit suicide.
58
+
59
+ There are two sub-categories for the `self_harm` category:
60
+ a. `level_1_self_harm_intent`: Text that expresses suicidal thoughts or self-harm intention; or content encouraging someone to self-harm.
61
+ b. `level_2_self_harm_action`: Text that describes or indicates ongoing or imminent self-harm behavior.
62
+
63
+ 6. `all_other_misconduct`: This is a catch-all category for any other unsafe text that does not fit into the other categories.
64
+ It includes text that seeks or provides information about engaging in misconduct, wrongdoing, or criminal activity, or that threatens to harm,
65
+ defraud, or exploit others. This includes facilitating illegal acts (under Singapore law) or other forms of socially harmful activity.
66
+
67
+ There are two sub-categories for the `all_other_misconduct` category:
68
+ a. `level_1_not_socially_accepted`: Text that advocates or instructs on unethical/immoral activities that may not necessarily be illegal but are socially condemned.
69
+ b. `level_2_illegal_activities`: Text that seeks or provides instructions to carry out clearly illegal activities or serious wrongdoing; includes credible threats of severe harm.
70
+
71
+ Lastly, there is an additional `binary` category (#7) which flags whether the text is unsafe in general.
72
+
73
+ The model takes in as input text, after it has been encoded with OpenAI's `text-embedding-3-small` model.
74
+
75
+ The model outputs the probabilities of each category being true.
76
+
77
+ ================================
78
+
79
+ Args:
80
+ input_dim: The dimension of the input embeddings. This defaults to 3072, which is the dimension of the embeddings from OpenAI's `text-embedding-3-small` model. This should not be changed.
81
+ label_names: The names of the labels. This defaults to the keys of the CATEGORIES dictionary. This should not be changed.
82
+ categories: The categories of the labels. This defaults to the CATEGORIES dictionary. This should not be changed.
83
+
84
+ Returns:
85
+ A LionGuard2 model.
86
+ """
87
+ super(LionGuard2, self).__init__()
88
+ self.label_names = label_names
89
+ self.n_outputs = len(label_names)
90
+ self.categories = categories
91
+
92
+ # Shared layers
93
+ self.shared_layers = nn.Sequential(
94
+ nn.Linear(input_dim, 256),
95
+ nn.ReLU(),
96
+ nn.Dropout(0.2),
97
+ nn.Linear(256, 128),
98
+ nn.ReLU(),
99
+ nn.Dropout(0.2),
100
+ )
101
+
102
+ # Output heads for each label
103
+ self.output_heads = nn.ModuleList(
104
+ [
105
+ nn.Sequential(
106
+ nn.Linear(128, 32),
107
+ nn.ReLU(),
108
+ nn.Linear(32, 2), # 2 thresholds for ordinal classification
109
+ nn.Sigmoid(),
110
+ )
111
+ for _ in range(self.n_outputs)
112
+ ]
113
+ )
114
+
115
+ def forward(self, x):
116
+ # Pass through shared layers
117
+ h = self.shared_layers(x)
118
+ # Pass through each output head
119
+ return [head(h) for head in self.output_heads]
120
+
121
+ def predict(self, embeddings):
122
+ """
123
+ Predict the probabilities of each label being true.
124
+
125
+ Args:
126
+ embeddings: A numpy array of embeddings (N * INPUT_DIMENSION)
127
+
128
+ Returns:
129
+ A dictionary of probabilities.
130
+ """
131
+ # Convert input to PyTorch tensor if not already
132
+ if not isinstance(embeddings, torch.Tensor):
133
+ x = torch.tensor(embeddings, dtype=torch.float32)
134
+ else:
135
+ x = embeddings
136
+
137
+ # Pass through model
138
+ with torch.no_grad():
139
+ outputs = self.forward(x)
140
+
141
+ # Stack outputs into a single tensor
142
+ raw_predictions = torch.stack(outputs) # SIZE:
143
+
144
+ # Extract and format probabilities from raw predictions
145
+ output = {}
146
+ for i, main_cat in enumerate(self.label_names):
147
+ sub_categories = self.categories[main_cat]
148
+ for j, sub_cat in enumerate(sub_categories):
149
+ # j=0 uses P(y>0)
150
+ # j=1 uses P(y>1) if L2 category exists
151
+ output[sub_cat] = raw_predictions[i, :, j]
152
+
153
+ # Post processing step:
154
+ # If L2 category exists, and P(L2) > P(L1),
155
+ # Set both P(L1) and P(L2) to their average to maintain ordinal consistency
156
+ if len(sub_categories) > 1:
157
+ l1 = output[sub_categories[0]]
158
+ l2 = output[sub_categories[1]]
159
+
160
+ # Update probabilities on samples where P(L2) > P(L1)
161
+ mask = l2 > l1
162
+ mean_prob = (l1 + l2) / 2
163
+ l1[mask] = mean_prob[mask]
164
+ l2[mask] = mean_prob[mask]
165
+ output[sub_categories[0]] = l1
166
+ output[sub_categories[1]] = l2
167
+
168
+ for key, value in output.items():
169
+ output[key] = value.numpy().tolist()
170
+ return output
model.joblib DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6a2f7769a11fe468b2499d08e01ce7522d08a18b0103838404b69210c9b2616c
3
- size 20552060
 
 
 
 
requirements.txt CHANGED
@@ -1,8 +1,6 @@
1
- keras
2
- openai
3
- tensorflow
4
- joblib
5
- logfire
6
- scikit-learn
7
- pandas
8
- numpy
 
1
+ gradio>=4.0.0
2
+ torch>=2.0.0
3
+ safetensors>=0.4.0
4
+ openai>=1.0.0
5
+ numpy>=1.24.0
6
+ huggingface_hub>=0.19.0
 
 
utils.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ utils.py
3
+ """
4
+
5
+ # Standard imports
6
+ import os
7
+ from typing import List
8
+
9
+ # Third party imports
10
+ import numpy as np
11
+ from openai import OpenAI
12
+
13
+ client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
14
+
15
+ # Maximum tokens for text-embedding-3-large
16
+ MAX_TOKENS = 8191 # We don't have access to the tokenizer for text-embedding-3-large, and just assume 1 character = 1 token here
17
+
18
+
19
+ def get_embeddings(
20
+ texts: List[str], model: str = "text-embedding-3-large"
21
+ ) -> List[List[float]]:
22
+ """
23
+ Generate embeddings for a list of texts using OpenAI API synchronously.
24
+
25
+ Args:
26
+ texts: List of strings to embed.
27
+ model: OpenAI embedding model to use (default: text-embedding-3-large).
28
+
29
+ Returns:
30
+ A list of embeddings (each embedding is a list of floats).
31
+
32
+ Raises:
33
+ Exception: If the OpenAI API call fails.
34
+ """
35
+
36
+ # Truncate texts to max token limit
37
+ truncated_texts = [text[:MAX_TOKENS] for text in texts]
38
+
39
+ # Make the API call
40
+ response = client.embeddings.create(input=truncated_texts, model=model)
41
+
42
+ # Extract embeddings from response
43
+ embeddings = np.array([data.embedding for data in response.data])
44
+ return embeddings