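"""Prompt Injection Solutions Benchmark.

Gradio app that runs a prompt through several prompt injection detection
providers (HuggingFace ONNX models hosted on the Space plus external APIs) and
reports, per provider, whether the call succeeded, whether an injection was
flagged, and the latency. Each submitted prompt and its results are stored in a
private HuggingFace dataset for further analysis.
"""
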
import argparse
import glob
import json
import logging
import multiprocessing as mp
import os
import time
import uuid
from datetime import timedelta
from typing import List, Union

import aegis
import boto3
import gradio as gr
import requests
from huggingface_hub import HfApi
from optimum.onnxruntime import ORTModelForSequenceClassification
from rebuff import Rebuff
from transformers import AutoTokenizer, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

hf_api = HfApi(token=os.getenv("HF_TOKEN"))
num_processes = 2  # mp.cpu_count()

lakera_api_key = os.getenv("LAKERA_API_KEY")
automorphic_api_key = os.getenv("AUTOMORPHIC_API_KEY")
rebuff_api_key = os.getenv("REBUFF_API_KEY")
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
aws_comprehend_client = boto3.client(service_name="comprehend", region_name="us-east-1")
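# The Comprehend client relies on AWS credentials supplied via the environment or an attached IAM role.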


def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> pipeline:
    hf_model = ORTModelForSequenceClassification.from_pretrained(
        prompt_injection_ort_model,
        export=False,
        subfolder=subfolder,
        file_name="model.onnx",
    )
    hf_tokenizer = AutoTokenizer.from_pretrained(prompt_injection_ort_model, subfolder=subfolder)
    hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
    logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")

    return pipeline(
        "text-classification",
        model=hf_model,
        tokenizer=hf_tokenizer,
        device="cpu",
        batch_size=1,
        truncation=True,
        max_length=512,
    )


def convert_elapsed_time(diff_time) -> float:
    return round(timedelta(seconds=diff_time).total_seconds(), 2)


deepset_classifier = init_prompt_injection_model(
    "ProtectAI/deberta-v3-base-injection-onnx"
)  # ONNX version of deepset/deberta-v3-base-injection
protectai_v1_classifier = init_prompt_injection_model(
    "ProtectAI/deberta-v3-base-prompt-injection", "onnx"
)
protectai_v2_classifier = init_prompt_injection_model(
    "ProtectAI/deberta-v3-base-prompt-injection-v2", "onnx"
)
fmops_classifier = init_prompt_injection_model(
    "ProtectAI/fmops-distilbert-prompt-injection-onnx"
)  # ONNX version of fmops/distilbert-prompt-injection
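

# Each detect_* function returns a (request_succeeded, injection_detected) tuple;
# a failed provider call is reported as (False, False) instead of raising.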
def detect_hf(
    prompt: str, threshold: float = 0.5, classifier=protectai_v1_classifier, label: str = "INJECTION"
) -> (bool, bool):
    try:
        pi_result = classifier(prompt)
        injection_score = round(
            pi_result[0]["score"] if pi_result[0]["label"] == label else 1 - pi_result[0]["score"],
            2,
        )
        logger.info(f"Prompt injection result from the HF model: {pi_result}")

        return True, injection_score > threshold
    except Exception as err:
        logger.error(f"Failed to call HF model: {err}")
        return False, False


def detect_hf_protectai_v1(prompt: str) -> (bool, bool):
    return detect_hf(prompt, classifier=protectai_v1_classifier)


def detect_hf_protectai_v2(prompt: str) -> (bool, bool):
    return detect_hf(prompt, classifier=protectai_v2_classifier)


def detect_hf_deepset(prompt: str) -> (bool, bool):
    return detect_hf(prompt, classifier=deepset_classifier)


def detect_hf_fmops(prompt: str) -> (bool, bool):
    return detect_hf(prompt, classifier=fmops_classifier, label="LABEL_1")


def detect_lakera(prompt: str) -> (bool, bool):
    try:
        response = requests.post(
            "https://api.lakera.ai/v1/prompt_injection",
            json={"input": prompt},
            headers={"Authorization": f"Bearer {lakera_api_key}"},
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Lakera: {response_json}")

        return True, response_json["results"][0]["flagged"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Lakera API: {err}")
        return False, False


def detect_automorphic(prompt: str) -> (bool, bool):
    ag = aegis.Aegis(automorphic_api_key)
    try:
        ingress_attack_detected = ag.ingress(prompt, "")
        logger.info(f"Prompt injection result from Automorphic: {ingress_attack_detected}")

        return True, ingress_attack_detected["detected"]
    except Exception as err:
        logger.error(f"Failed to call Automorphic API: {err}")
        return False, False  # Assume it's not an attack


def detect_rebuff(prompt: str) -> (bool, bool):
    try:
        rb = Rebuff(api_token=rebuff_api_key, api_url="https://www.rebuff.ai")
        result = rb.detect_injection(prompt)
        logger.info(f"Prompt injection result from Rebuff: {result}")

        return True, result.injectionDetected
    except Exception as err:
        logger.error(f"Failed to call Rebuff API: {err}")
        return False, False


def detect_azure(prompt: str) -> (bool, bool):
    try:
        response = requests.post(
            f"{azure_content_safety_endpoint}contentsafety/text:detectJailbreak?api-version=2023-10-15-preview",
            json={"text": prompt},
            headers={"Ocp-Apim-Subscription-Key": azure_content_safety_key},
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Azure: {response_json}")

        if "jailbreakAnalysis" not in response_json:
            return False, False

        return True, response_json["jailbreakAnalysis"]["detected"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Azure API: {err}")
        return False, False


def detect_aws_comprehend(prompt: str) -> (bool, bool):
    response = aws_comprehend_client.classify_document(
        EndpointArn="arn:aws:comprehend:us-east-1:aws:document-classifier-endpoint/prompt-safety",
        Text=prompt,
    )
    # NOTE: the live API result is overwritten with a hard-coded example response below;
    # this provider is also disabled in detection_providers.
    response = {
        "Classes": [
            {"Name": "SAFE_PROMPT", "Score": 0.9010000228881836},
            {"Name": "UNSAFE_PROMPT", "Score": 0.0989999994635582},
        ],
        "ResponseMetadata": {
            "RequestId": "e8900fe1-3346-45c0-bad3-007b2840865a",
            "HTTPStatusCode": 200,
            "HTTPHeaders": {
                "x-amzn-requestid": "e8900fe1-3346-45c0-bad3-007b2840865a",
                "content-type": "application/x-amz-json-1.1",
                "content-length": "115",
                "date": "Mon, 19 Feb 2024 08:34:43 GMT",
            },
            "RetryAttempts": 0,
        },
    }
    logger.info(f"Prompt injection result from AWS Comprehend: {response}")

    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
        logger.error(f"Failed to call AWS Comprehend API: {response}")
        return False, False

    return True, response["Classes"][0]["Name"] == "UNSAFE_PROMPT"


detection_providers = {
    "ProtectAI v1 (HF model)": detect_hf_protectai_v1,
    "ProtectAI v2 (HF model)": detect_hf_protectai_v2,
    "Deepset (HF model)": detect_hf_deepset,
    "FMOps (HF model)": detect_hf_fmops,
    "Lakera Guard": detect_lakera,
    "Automorphic Aegis": detect_automorphic,
    # "Rebuff": detect_rebuff,
    "Azure Content Safety": detect_azure,
    # "AWS Comprehend": detect_aws_comprehend,
}


def is_detected(provider: str, prompt: str) -> (str, bool, bool, float):
    if provider not in detection_providers:
        logger.warning(f"Provider {provider} is not supported")
        return provider, False, False, 0.0

    start_time = time.monotonic()
    request_result, is_injection = detection_providers[provider](prompt)
    end_time = time.monotonic()

    return provider, request_result, is_injection, convert_elapsed_time(end_time - start_time)
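

# Fan the prompt out to all providers in a small process pool, collect the
# (provider, success, detected, latency) rows, and store the prompt with its
# results as a JSON file in the private benchmark dataset before returning.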
def execute(prompt: str) -> List[Union[str, bool, float]]:
    results = []

    with mp.Pool(processes=num_processes) as pool:
        for result in pool.starmap(
            is_detected, [(provider, prompt) for provider in detection_providers.keys()]
        ):
            results.append(result)

    # Save the prompt and per-provider results for later analysis
    fileobj = json.dumps(
        {"prompt": prompt, "results": results}, indent=2, ensure_ascii=False
    ).encode("utf-8")
    result_path = f"/prompts/train/{str(uuid.uuid4())}.json"
    hf_api.upload_file(
        path_or_fileobj=fileobj,
        path_in_repo=result_path,
        repo_id="ProtectAI/prompt-injection-benchmark",
        repo_type="dataset",
    )
    logger.info(f"Stored prompt: {prompt}")

    return results
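

# CLI entry point: parse host/port, load example prompts from examples/*.txt,
# and serve the benchmark UI with Gradio.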
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--url", type=str, default="0.0.0.0")
    args, left_argv = parser.parse_known_args()

    example_files = glob.glob(os.path.join(os.path.dirname(__file__), "examples", "*.txt"))
    examples = [open(file).read() for file in example_files]

    gr.Interface(
        fn=execute,
        inputs=[
            gr.Textbox(label="Prompt"),
        ],
        outputs=[
            gr.Dataframe(
                headers=[
                    "Provider",
                    "Is processed successfully?",
                    "Is prompt injection?",
                    "Latency (seconds)",
                ],
                datatype=["str", "bool", "bool", "number"],
                label="Results",
            ),
        ],
        title="Prompt Injection Solutions Benchmark",
        description="This interface aims to benchmark well-known prompt injection detection providers. "
        "The results are <strong>stored in a private dataset</strong> for further analysis and improvement. This interface is for research purposes only."
        "<br /><br />"
        "HuggingFace (HF) models are hosted on Spaces, while the other providers are called as external APIs.<br /><br />"
        '<a href="https://join.slack.com/t/laiyerai/shared_invite/zt-28jv3ci39-sVxXrLs3rQdaN3mIl9IT~w">Join our Slack community to discuss LLM Security</a><br />'
        '<a href="https://github.com/protectai/llm-guard">Secure your LLM interactions with LLM Guard</a>',
        examples=[
            [
                example,
            ]
            for example in examples
        ],
        cache_examples=True,
        allow_flagging="never",
        concurrency_limit=1,
    ).launch(server_name=args.url, server_port=args.port)