asofter commited on
Commit
d886d80
β€’
1 Parent(s): 40d2b4b

* introduce new model

Browse files
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +14 -6
  3. requirements.txt +6 -6
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸ“
4
  colorFrom: yellow
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 4.19.1
8
  pinned: true
9
  license: apache-2.0
10
  ---
 
4
  colorFrom: yellow
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 4.27.0
8
  pinned: true
9
  license: apache-2.0
10
  ---
app.py CHANGED
@@ -64,16 +64,19 @@ def convert_elapsed_time(diff_time) -> float:
64
  deepset_classifier = init_prompt_injection_model(
65
  "ProtectAI/deberta-v3-base-injection-onnx"
66
  ) # ONNX version of deepset/deberta-v3-base-injection
67
- protectai_classifier = init_prompt_injection_model(
68
  "ProtectAI/deberta-v3-base-prompt-injection", "onnx"
69
  )
 
 
 
70
  fmops_classifier = init_prompt_injection_model(
71
  "ProtectAI/fmops-distilbert-prompt-injection-onnx"
72
  ) # ONNX version of fmops/distilbert-prompt-injection
73
 
74
 
75
  def detect_hf(
76
- prompt: str, threshold: float = 0.5, classifier=protectai_classifier, label: str = "INJECTION"
77
  ) -> (bool, bool):
78
  try:
79
  pi_result = classifier(prompt)
@@ -90,8 +93,12 @@ def detect_hf(
90
  return False, False
91
 
92
 
93
- def detect_hf_protectai(prompt: str) -> (bool, bool):
94
- return detect_hf(prompt, classifier=protectai_classifier)
 
 
 
 
95
 
96
 
97
  def detect_hf_deepset(prompt: str) -> (bool, bool):
@@ -191,14 +198,15 @@ def detect_aws_comprehend(prompt: str) -> (bool, bool):
191
 
192
 
193
  detection_providers = {
194
- "ProtectAI (HF model)": detect_hf_protectai,
 
195
  "Deepset (HF model)": detect_hf_deepset,
196
  "FMOps (HF model)": detect_hf_fmops,
197
  "Lakera Guard": detect_lakera,
198
  "Automorphic Aegis": detect_automorphic,
199
  # "Rebuff": detect_rebuff,
200
  "Azure Content Safety": detect_azure,
201
- "AWS Comprehend": detect_aws_comprehend,
202
  }
203
 
204
 
 
64
  deepset_classifier = init_prompt_injection_model(
65
  "ProtectAI/deberta-v3-base-injection-onnx"
66
  ) # ONNX version of deepset/deberta-v3-base-injection
67
+ protectai_v1_classifier = init_prompt_injection_model(
68
  "ProtectAI/deberta-v3-base-prompt-injection", "onnx"
69
  )
70
+ protectai_v2_classifier = init_prompt_injection_model(
71
+ "ProtectAI/deberta-v3-base-prompt-injection-v2", "onnx"
72
+ )
73
  fmops_classifier = init_prompt_injection_model(
74
  "ProtectAI/fmops-distilbert-prompt-injection-onnx"
75
  ) # ONNX version of fmops/distilbert-prompt-injection
76
 
77
 
78
  def detect_hf(
79
+ prompt: str, threshold: float = 0.5, classifier=protectai_v1_classifier, label: str = "INJECTION"
80
  ) -> (bool, bool):
81
  try:
82
  pi_result = classifier(prompt)
 
93
  return False, False
94
 
95
 
96
+ def detect_hf_protectai_v1(prompt: str) -> (bool, bool):
97
+ return detect_hf(prompt, classifier=protectai_v1_classifier)
98
+
99
+
100
+ def detect_hf_protectai_v2(prompt: str) -> (bool, bool):
101
+ return detect_hf(prompt, classifier=protectai_v2_classifier)
102
 
103
 
104
  def detect_hf_deepset(prompt: str) -> (bool, bool):
 
198
 
199
 
200
  detection_providers = {
201
+ "ProtectAI v1 (HF model)": detect_hf_protectai_v1,
202
+ "ProtectAI v2 (HF model)": detect_hf_protectai_v2,
203
  "Deepset (HF model)": detect_hf_deepset,
204
  "FMOps (HF model)": detect_hf_fmops,
205
  "Lakera Guard": detect_lakera,
206
  "Automorphic Aegis": detect_automorphic,
207
  # "Rebuff": detect_rebuff,
208
  "Azure Content Safety": detect_azure,
209
+ #"AWS Comprehend": detect_aws_comprehend,
210
  }
211
 
212
 
requirements.txt CHANGED
@@ -1,9 +1,9 @@
1
- boto3==1.34.44
2
  git+https://github.com/automorphic-ai/aegis.git
3
- gradio==4.19.1
4
- huggingface_hub==0.20.3
5
- onnxruntime==1.17.0
6
- optimum[onnxruntime]==1.17.1
7
  rebuff==0.1.1
8
  requests==2.31.0
9
- transformers==4.37.2
 
1
+ boto3==1.34.88
2
  git+https://github.com/automorphic-ai/aegis.git
3
+ gradio==4.27.0
4
+ huggingface_hub==0.22.2
5
+ onnxruntime==1.17.3
6
+ optimum[onnxruntime]==1.19.0
7
  rebuff==0.1.1
8
  requests==2.31.0
9
+ transformers==4.40.0