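"""Gradio app that benchmarks prompt injection detection providers.

Each submitted prompt is run through several detectors (local ONNX models from
Hugging Face plus third-party APIs), and the per-provider verdicts and latencies
are uploaded to a private Hugging Face dataset for later analysis.
"""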
import argparse
import glob
import json
import logging
import multiprocessing as mp
import os
import time
import uuid
from datetime import timedelta
from typing import List, Tuple

import boto3
import gradio as gr
import requests
from huggingface_hub import HfApi
from optimum.onnxruntime import ORTModelForSequenceClassification
from rebuff import Rebuff
from transformers import AutoTokenizer, Pipeline, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

hf_token = os.getenv("HF_TOKEN")
hf_api = HfApi(token=hf_token)
num_processes = 2  # mp.cpu_count()

lakera_api_key = os.getenv("LAKERA_API_KEY")
sydelabs_api_key = os.getenv("SYDELABS_API_KEY")
rebuff_api_key = os.getenv("REBUFF_API_KEY")
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
bedrock_runtime_client = boto3.client("bedrock-runtime", region_name="us-east-1")

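# init_prompt_injection_model loads an ONNX-exported sequence classifier through
# Optimum's ONNX Runtime backend and wraps it in a transformers text-classification
# pipeline pinned to CPU, with inputs truncated to 512 tokens.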
def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> Pipeline:
    hf_model = ORTModelForSequenceClassification.from_pretrained(
        prompt_injection_ort_model,
        export=False,
        subfolder=subfolder,
        file_name="model.onnx",
        token=hf_token,
    )
    hf_tokenizer = AutoTokenizer.from_pretrained(
        prompt_injection_ort_model, subfolder=subfolder, token=hf_token
    )
    hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]
    logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")

    return pipeline(
        "text-classification",
        model=hf_model,
        tokenizer=hf_tokenizer,
        device="cpu",
        batch_size=1,
        truncation=True,
        max_length=512,
    )

def convert_elapsed_time(diff_time) -> float:
    return round(timedelta(seconds=diff_time).total_seconds(), 2)

deepset_classifier = init_prompt_injection_model(
    "protectai/deberta-v3-base-injection-onnx"
)  # ONNX version of deepset/deberta-v3-base-injection
protectai_v2_classifier = init_prompt_injection_model(
    "protectai/deberta-v3-base-prompt-injection-v2", "onnx"
)
fmops_classifier = init_prompt_injection_model(
    "protectai/fmops-distilbert-prompt-injection-onnx"
)  # ONNX version of fmops/distilbert-prompt-injection
protectai_v2_small_classifier = init_prompt_injection_model(
    "protectai/deberta-v3-small-prompt-injection-v2", "onnx"
)  # ONNX weights in the "onnx" subfolder of protectai/deberta-v3-small-prompt-injection-v2

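# All detectors below share the same contract: they return a
# (request_succeeded, is_injection) pair, so a provider outage is distinguishable
# from a "no injection" verdict.
# For the local HF models, the pipeline returns e.g. [{"label": "INJECTION", "score": 0.98}]
# (illustrative values); detect_hf normalizes the score toward the injection label and
# flags the prompt when it exceeds the threshold.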
def detect_hf(
    prompt: str,
    threshold: float = 0.5,
    classifier=protectai_v2_classifier,
    label: str = "INJECTION",
) -> Tuple[bool, bool]:
    try:
        pi_result = classifier(prompt)
        injection_score = round(
            pi_result[0]["score"] if pi_result[0]["label"] == label else 1 - pi_result[0]["score"],
            2,
        )

        logger.info(f"Prompt injection result from the HF model: {pi_result}")

        return True, injection_score > threshold
    except Exception as err:
        logger.error(f"Failed to call HF model: {err}")
        return False, False


def detect_hf_protectai_v2(prompt: str) -> Tuple[bool, bool]:
    return detect_hf(prompt, classifier=protectai_v2_classifier)


def detect_hf_protectai_v2_small(prompt: str) -> Tuple[bool, bool]:
    return detect_hf(prompt, classifier=protectai_v2_small_classifier)


def detect_hf_deepset(prompt: str) -> Tuple[bool, bool]:
    return detect_hf(prompt, classifier=deepset_classifier)


def detect_hf_fmops(prompt: str) -> Tuple[bool, bool]:
    return detect_hf(prompt, classifier=fmops_classifier, label="LABEL_1")

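# Lakera Guard: POST the prompt to the hosted prompt_injection endpoint and read
# the flagged verdict from results[0] in the JSON response.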
def detect_lakera(prompt: str) -> Tuple[bool, bool]:
    try:
        response = requests.post(
            "https://api.lakera.ai/v1/prompt_injection",
            json={"input": prompt},
            headers={"Authorization": f"Bearer {lakera_api_key}"},
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Lakera: {response_json}")

        return True, response_json["results"][0]["flagged"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Lakera API: {err}")
        return False, False

def detect_rebuff(prompt: str) -> Tuple[bool, bool]:
    try:
        rb = Rebuff(api_token=rebuff_api_key, api_url="https://www.rebuff.ai")
        result = rb.detect_injection(prompt)
        logger.info(f"Prompt injection result from Rebuff: {result}")

        return True, result.injectionDetected
    except Exception as err:
        logger.error(f"Failed to call Rebuff API: {err}")
        return False, False

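# Azure AI Content Safety (Prompt Shields): calls the text:shieldPrompt endpoint and
# reads userPromptAnalysis.attackDetected from the JSON response.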
def detect_azure(prompt: str) -> Tuple[bool, bool]:
    try:
        response = requests.post(
            f"{azure_content_safety_endpoint}contentsafety/text:shieldPrompt?api-version=2024-02-15-preview",
            json={"userPrompt": prompt},
            headers={"Ocp-Apim-Subscription-Key": azure_content_safety_key},
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Azure: {response_json}")

        if "userPromptAnalysis" not in response_json:
            return False, False

        return True, response_json["userPromptAnalysis"]["attackDetected"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Azure API: {err}")
        return False, False

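# AWS Bedrock Guardrails: applies a pre-configured guardrail to the prompt via the
# bedrock-runtime apply_guardrail API and treats any action other than NONE as a detection.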
def detect_aws_bedrock(prompt: str) -> Tuple[bool, bool]:
    try:
        response = bedrock_runtime_client.apply_guardrail(
            guardrailIdentifier="tx8t6psx14ho",
            guardrailVersion="1",
            source="INPUT",
            content=[{"text": {"text": prompt}}],
        )
    except Exception as err:
        # Match the other detectors: report a failed request instead of crashing the worker.
        logger.error(f"Failed to call AWS Bedrock Guardrails API: {err}")
        return False, False

    logger.info(f"Prompt injection result from AWS Bedrock Guardrails: {response}")

    if response["ResponseMetadata"]["HTTPStatusCode"] != 200:
        logger.error(f"Failed to call AWS Bedrock Guardrails API: {response}")
        return False, False

    return True, response["action"] != "NONE"

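# SydeLabs: asks the guard endpoint to score the prompt and extracts the risk flag of
# the PROMPT_INJECT category from category_scores.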
def detect_sydelabs(prompt: str) -> Tuple[bool, bool]:
    try:
        response = requests.post(
            "https://guard.sydelabs.ai/api/v1/guard/generate-score",
            json={"prompt": prompt},
            headers={
                "Authorization": f"Bearer {sydelabs_api_key}",
                "X-Api-Key": sydelabs_api_key,
            },
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from SydeLabs: {response_json}")

        prompt_injection_risk = next(
            (
                category["risk"]
                for category in response_json["category_scores"]
                if category["category"] == "PROMPT_INJECT"
            ),
            False,
        )

        return True, prompt_injection_risk
    except requests.RequestException as err:
        logger.error(f"Failed to call SydeLabs API: {err}")
        return False, False

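# Registry of benchmarked providers: display name -> detector callable.
# The Rebuff entry is kept but currently disabled.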
detection_providers = {
    "ProtectAI v2 (HF model)": detect_hf_protectai_v2,
    "ProtectAI v2 Small (HF model)": detect_hf_protectai_v2_small,
    "Deepset (HF model)": detect_hf_deepset,
    "FMOps (HF model)": detect_hf_fmops,
    "Lakera Guard": detect_lakera,
    # "Rebuff": detect_rebuff,
    "Azure Content Safety": detect_azure,
    "SydeLabs": detect_sydelabs,
    "AWS Bedrock Guardrails": detect_aws_bedrock,
}

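# Runs a single provider against the prompt and returns one result row:
# (provider, request_succeeded, is_injection, latency_seconds).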
def is_detected(provider: str, prompt: str) -> Tuple[str, bool, bool, float]:
    if provider not in detection_providers:
        logger.warning(f"Provider {provider} is not supported")
        return provider, False, False, 0.0

    start_time = time.monotonic()
    request_result, is_injection = detection_providers[provider](prompt)
    end_time = time.monotonic()

    return provider, request_result, is_injection, convert_elapsed_time(end_time - start_time)

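# execute fans the prompt out to every provider with a small multiprocessing pool,
# then uploads the prompt and per-provider results as a JSON file to a private
# Hugging Face dataset.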
def execute(prompt: str) -> List[Tuple[str, bool, bool, float]]:
    results = []

    with mp.Pool(processes=num_processes) as pool:
        for result in pool.starmap(
            is_detected, [(provider, prompt) for provider in detection_providers.keys()]
        ):
            results.append(result)

    # Store the prompt and the detection results for later analysis
    fileobj = json.dumps(
        {"prompt": prompt, "results": results}, indent=2, ensure_ascii=False
    ).encode("utf-8")
    result_path = f"/prompts/train/{uuid.uuid4()}.json"
    hf_api.upload_file(
        path_or_fileobj=fileobj,
        path_in_repo=result_path,
        repo_id="protectai/prompt-injection-benchmark",
        repo_type="dataset",
    )
    logger.info(f"Stored prompt: {prompt}")

    return results

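# Entry point: parses --port/--url, loads example prompts from examples/*.txt, and
# launches the Gradio interface. Local run (illustrative, matching the argparse
# defaults below): `python app.py --port 7860 --url 0.0.0.0`.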
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--url", type=str, default="0.0.0.0")
    args, left_argv = parser.parse_known_args()

    example_files = glob.glob(os.path.join(os.path.dirname(__file__), "examples", "*.txt"))
    examples = [open(file).read() for file in example_files]

    gr.Interface(
        fn=execute,
        inputs=[
            gr.Textbox(label="Prompt"),
        ],
        outputs=[
            gr.Dataframe(
                headers=[
                    "Provider",
                    "Is processed successfully?",
                    "Is prompt injection?",
                    "Latency (seconds)",
                ],
                datatype=["str", "bool", "bool", "number"],
                label="Results",
            ),
        ],
        title="Prompt Injection Solutions Benchmark",
        description="This interface benchmarks well-known prompt injection detection providers. "
        "The results are <strong>stored in a private dataset</strong> for further analysis and improvements. "
        "This interface is for research purposes only."
        "<br /><br />"
        "Hugging Face (HF) models run locally in this Space, while the other providers are called via their APIs.<br /><br />"
        '<a href="https://join.slack.com/t/laiyerai/shared_invite/zt-28jv3ci39-sVxXrLs3rQdaN3mIl9IT~w">Join our Slack community to discuss LLM Security</a><br />'
        '<a href="https://github.com/protectai/llm-guard">Secure your LLM interactions with LLM Guard</a>',
        examples=[[example] for example in examples],
        cache_examples=True,
        allow_flagging="never",
        concurrency_limit=1,
    ).launch(server_name=args.url, server_port=args.port)