asofter commited on
Commit
0604028
β€’
1 Parent(s): 62b6050

* add AWS comprehend

Browse files
Files changed (3) hide show
  1. README.md +2 -1
  2. app.py +38 -3
  3. requirements.txt +7 -6
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸ“
4
  colorFrom: yellow
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.9.0
8
  pinned: true
9
  license: apache-2.0
10
  ---
@@ -35,3 +35,4 @@ gradio app.py
35
  - [Rebuff](https://rebuff.ai/)
36
  - [Azure Content Safety AI](https://learn.microsoft.com/en-us/azure/ai-services/content-safety/studio-quickstart)
37
  - [AWS Bedrock Guardrails](https://aws.amazon.com/bedrock/guardrails/) (coming soon)
 
 
4
  colorFrom: yellow
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 4.19.1
8
  pinned: true
9
  license: apache-2.0
10
  ---
 
35
  - [Rebuff](https://rebuff.ai/)
36
  - [Azure Content Safety AI](https://learn.microsoft.com/en-us/azure/ai-services/content-safety/studio-quickstart)
37
  - [AWS Bedrock Guardrails](https://aws.amazon.com/bedrock/guardrails/) (coming soon)
38
+ - [AWS Comprehend](https://docs.aws.amazon.com/comprehend/latest/dg/trust-safety.html)
app.py CHANGED
@@ -11,6 +11,7 @@ from functools import lru_cache
11
  from typing import List, Union
12
 
13
  import aegis
 
14
  import gradio as gr
15
  import requests
16
  from huggingface_hub import HfApi
@@ -29,6 +30,7 @@ automorphic_api_key = os.getenv("AUTOMORPHIC_API_KEY")
29
  rebuff_api_key = os.getenv("REBUFF_API_KEY")
30
  azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
31
  azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
 
32
 
33
 
34
  @lru_cache(maxsize=2)
@@ -61,7 +63,9 @@ def convert_elapsed_time(diff_time) -> float:
61
  deepset_classifier = init_prompt_injection_model(
62
  "ProtectAI/deberta-v3-base-injection-onnx"
63
  ) # ONNX version of deepset/deberta-v3-base-injection
64
- protectai_classifier = init_prompt_injection_model("ProtectAI/deberta-v3-base-prompt-injection", "onnx")
 
 
65
  fmops_classifier = init_prompt_injection_model(
66
  "ProtectAI/fmops-distilbert-prompt-injection-onnx"
67
  ) # ONNX version of fmops/distilbert-prompt-injection
@@ -155,6 +159,36 @@ def detect_azure(prompt: str) -> (bool, bool):
155
  return False, False
156
 
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  detection_providers = {
159
  "ProtectAI (HF model)": detect_hf_protectai,
160
  "Deepset (HF model)": detect_hf_deepset,
@@ -163,6 +197,7 @@ detection_providers = {
163
  "Automorphic Aegis": detect_automorphic,
164
  # "Rebuff": detect_rebuff,
165
  "Azure Content Safety": detect_azure,
 
166
  }
167
 
168
 
@@ -235,8 +270,8 @@ if __name__ == "__main__":
235
  "The results are <strong>stored in the private dataset</strong> for further analysis and improvements. This interface is for research purposes only."
236
  "<br /><br />"
237
  "HuggingFace (HF) models are hosted on Spaces while other providers are called as APIs.<br /><br />"
238
- "<a href=\"https://join.slack.com/t/laiyerai/shared_invite/zt-28jv3ci39-sVxXrLs3rQdaN3mIl9IT~w\">Join our Slack community to discuss LLM Security</a><br />"
239
- "<a href=\"https://github.com/protectai/llm-guard\">Secure your LLM interactions with LLM Guard</a>",
240
  examples=[
241
  [
242
  example,
 
11
  from typing import List, Union
12
 
13
  import aegis
14
+ import boto3
15
  import gradio as gr
16
  import requests
17
  from huggingface_hub import HfApi
 
30
  rebuff_api_key = os.getenv("REBUFF_API_KEY")
31
  azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
32
  azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
33
+ aws_comprehend_client = boto3.client(service_name="comprehend", region_name="us-east-1")
34
 
35
 
36
  @lru_cache(maxsize=2)
 
63
  deepset_classifier = init_prompt_injection_model(
64
  "ProtectAI/deberta-v3-base-injection-onnx"
65
  ) # ONNX version of deepset/deberta-v3-base-injection
66
+ protectai_classifier = init_prompt_injection_model(
67
+ "ProtectAI/deberta-v3-base-prompt-injection", "onnx"
68
+ )
69
  fmops_classifier = init_prompt_injection_model(
70
  "ProtectAI/fmops-distilbert-prompt-injection-onnx"
71
  ) # ONNX version of fmops/distilbert-prompt-injection
 
159
  return False, False
160
 
161
 
162
+ def detect_aws_comprehend(prompt: str) -> (bool, bool):
163
+ response = aws_comprehend_client.classify_document(
164
+ EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
165
+ Text=prompt,
166
+ )
167
+ response = {
168
+ "Classes": [
169
+ {"Name": "SAFE_PROMPT", "Score": 0.9010000228881836},
170
+ {"Name": "UNSAFE_PROMPT", "Score": 0.0989999994635582},
171
+ ],
172
+ "ResponseMetadata": {
173
+ "RequestId": "e8900fe1-3346-45c0-bad3-007b2840865a",
174
+ "HTTPStatusCode": 200,
175
+ "HTTPHeaders": {
176
+ "x-amzn-requestid": "e8900fe1-3346-45c0-bad3-007b2840865a",
177
+ "content-type": "application/x-amz-json-1.1",
178
+ "content-length": "115",
179
+ "date": "Mon, 19 Feb 2024 08:34:43 GMT",
180
+ },
181
+ "RetryAttempts": 0,
182
+ },
183
+ }
184
+ logger.info(f"Prompt injection result from AWS Comprehend: {response}")
185
+ if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
186
+ logger.error(f"Failed to call AWS Comprehend API: {response}")
187
+ return False, False
188
+
189
+ return True, response["Classes"][0] == "UNSAFE_PROMPT"
190
+
191
+
192
  detection_providers = {
193
  "ProtectAI (HF model)": detect_hf_protectai,
194
  "Deepset (HF model)": detect_hf_deepset,
 
197
  "Automorphic Aegis": detect_automorphic,
198
  # "Rebuff": detect_rebuff,
199
  "Azure Content Safety": detect_azure,
200
+ "AWS Comprehend": detect_aws_comprehend,
201
  }
202
 
203
 
 
270
  "The results are <strong>stored in the private dataset</strong> for further analysis and improvements. This interface is for research purposes only."
271
  "<br /><br />"
272
  "HuggingFace (HF) models are hosted on Spaces while other providers are called as APIs.<br /><br />"
273
+ '<a href="https://join.slack.com/t/laiyerai/shared_invite/zt-28jv3ci39-sVxXrLs3rQdaN3mIl9IT~w">Join our Slack community to discuss LLM Security</a><br />'
274
+ '<a href="https://github.com/protectai/llm-guard">Secure your LLM interactions with LLM Guard</a>',
275
  examples=[
276
  [
277
  example,
requirements.txt CHANGED
@@ -1,8 +1,9 @@
 
1
  git+https://github.com/automorphic-ai/aegis.git
2
- gradio==4.9.0
3
- huggingface_hub==0.19.4
4
- onnxruntime==1.16.3
5
- optimum[onnxruntime]==1.15.0
6
- rebuff==0.0.5
7
  requests==2.31.0
8
- transformers==4.36.0
 
1
+ boto3==1.34.44
2
  git+https://github.com/automorphic-ai/aegis.git
3
+ gradio==4.19.1
4
+ huggingface_hub==0.20.3
5
+ onnxruntime==1.17.0
6
+ optimum[onnxruntime]==1.17.1
7
+ rebuff==0.1.1
8
  requests==2.31.0
9
+ transformers==4.37.2