# risk-atlas-nexus / executor.py
# Provenance: Hugging Face Space blob view (author rahulnair23,
# "initial commit", revision c16aa6d, 4.03 kB).
from dotenv import load_dotenv
load_dotenv(override=True)
import re
import os
import json
from typing import List, Dict, Any
import pandas as pd
import gradio as gr
from risk_atlas_nexus.blocks.inference import WMLInferenceEngine
from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
from risk_atlas_nexus.library import RiskAtlasNexus
from functools import lru_cache
# Load the taxonomies once at import time; this single shared RiskAtlasNexus
# instance backs every request handled by the functions below.
ran = RiskAtlasNexus()
@lru_cache
def risk_identifier(usecase: str,
                    model_name_or_path: str = "ibm/granite-20b-code-instruct",
                    taxonomy: str = "ibm-risk-atlas"):
    """Identify risks for a use case and map them onto the risk catalog.

    Args:
        usecase: Free-text description of the AI use case.
        model_name_or_path: LLM to use for identification. Currently unused —
            the live inference call is stubbed out (see NOTE below).
        taxonomy: Risk taxonomy identifier. Currently unused by the stub.

    Returns:
        A ``(gr.State, gr.Dataset)`` pair: the matched risk dicts, and a
        Dataset of their ids (labelled by name) for display in the UI.
    """
    # NOTE(review): the real LLM call (WMLInferenceEngine + WML credentials +
    # ran.identify_risks_from_usecase) is stubbed out. The canned, deliberately
    # truncated JSON string below stands in for a raw LLM response so that the
    # JSON-parse / regex-fallback path is still exercised.
    risks = ' ["Harmful code generation", "Hallucination", "Harmful output", "Toxic output", "Spreading toxicity", "Spreading disinformation", "Nonconsensual use", "Non-disclosure", "Data contamination", "Data acquisition restrictions", "Data usage rights restrictions", "Confidential data in prompt", "Confidential information in data", "Personal information in prompt", "Personal information in data", "IP information in prompt",'

    # Initialize up front so `out` is always bound, whatever shape `risks` has.
    out = []
    if isinstance(risks, str):
        # Translate raw LLM output into Risk catalog entries.
        try:
            risks = json.loads(risks)
        except json.JSONDecodeError:
            # Fallback to regex — will skip any partial categories.
            risks = re.findall(r'"(.*?)"', risks)
        for risk in risks:
            # Exact-name match against the ontology's risk list.
            matches = list(filter(lambda r: r.name == risk, ran._ontology.risks))  # type: ignore
            out += [m.model_dump() for m in matches]
    elif isinstance(risks, list):
        # FIXME: assumes that the output is structured - not sure if that's correct.
        out = risks

    return gr.State(out), gr.Dataset(
        samples=[i['id'] for i in out],
        sample_labels=[i['name'] for i in out],
        samples_per_page=50,
        visible=True,
        label="Estimated by an LLM.",
    )
@lru_cache
def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Dataset, gr.Dataset]:
    """Look up related risks and mitigating actions for a given risk id.

    Args:
        riskid: Identifier of the primary risk.
        taxonomy: Taxonomy the risk belongs to. For ``"ibm-risk-atlas"`` the
            actions of all *related* risks are aggregated; otherwise only the
            primary risk's own actions are used.

    Returns:
        A pair of ``gr.Dataset`` components: (related risks, mitigation
        actions). Empty results are rendered as a "not found" label with no
        samples.
    """
    related_risk_ids = ran.get_related_risk_ids_by_risk_id(riskid)

    action_ids = []
    if taxonomy == "ibm-risk-atlas":
        # Aggregate the actions associated with every related risk.
        # (`or []` guards against a None/empty related-risk lookup.)
        for related_id in related_risk_ids or []:
            related_actions = ran.get_risk_actions_by_risk_id(related_id)
            if related_actions:
                action_ids += related_actions
    else:
        # Use only actions related to the primary risk.
        action_ids = ran.get_risk_actions_by_risk_id(riskid)

    # Sanitize outputs for display: empty results become explanatory labels.
    if related_risk_ids:
        label = "Related risks"
        samples = related_risk_ids
        sample_labels = [r.name for r in ran.get_related_risks_by_risk_id(riskid)]
    else:
        label = "No related risks found."
        samples = None
        sample_labels = None

    if action_ids:
        alabel = ""
        asamples = action_ids
        asample_labels = [ran.get_action_by_id(a).description for a in action_ids]
    else:
        alabel = "No mitigations found."
        asamples = None
        asample_labels = None

    return (gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
            gr.Dataset(samples=asamples, label=alabel, sample_labels=asample_labels, visible=True))