# NOTE(review): removed Hugging Face Spaces status banner ("Spaces: Running")
# that was copy/paste residue from the rendered page, not Python code.
# Load environment variables first so libraries that read them at import time
# (e.g. WML credentials) see the .env values; override=True lets .env win
# over any pre-existing shell variables.
from dotenv import load_dotenv

load_dotenv(override=True)

import json
import os
import re
from functools import lru_cache
from typing import Any, Dict, List

import gradio as gr
import pandas as pd

from risk_atlas_nexus.blocks.inference import WMLInferenceEngine
from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
from risk_atlas_nexus.library import RiskAtlasNexus

# Load the taxonomies once at module import; shared by all handlers below.
ran = RiskAtlasNexus()
def risk_identifier(usecase: str,
                    model_name_or_path: str = "ibm/granite-20b-code-instruct",
                    taxonomy: str = "ibm-risk-atlas"):
    """Estimate AI risks for a use case and map them onto the risk catalog.

    Parameters
    ----------
    usecase : str
        Free-text description of the AI use case.
    model_name_or_path : str
        LLM used for risk identification (currently unused — the WML call
        below is stubbed out and a canned response is used instead).
    taxonomy : str
        Risk-taxonomy identifier (currently unused for the same reason).

    Returns
    -------
    tuple[gr.State, gr.Dataset]
        The matched risk dicts wrapped in a State, and a Dataset whose
        samples are risk ids labelled with risk names.
    """
    # NOTE(review): the live WML inference path is stubbed out; re-enable it
    # to replace the canned `risks` string below with a real LLM response.
    # inference_engine = WMLInferenceEngine(
    #     model_name_or_path=model_name_or_path,
    #     credentials={
    #         "api_key": os.environ["WML_API_KEY"],
    #         "api_url": os.environ["WML_API_URL"],
    #         "project_id": os.environ["WML_PROJECT_ID"],
    #     },
    #     parameters=WMLInferenceEngineParams(
    #         max_new_tokens=100, decoding_method="greedy", repetition_penalty=1
    #     ),  # type: ignore
    # )
    # risks = ran.identify_risks_from_usecase(
    #     usecase=usecase,
    #     inference_engine=inference_engine,
    #     taxonomy=taxonomy,
    # )
    risks = ' ["Harmful code generation", "Hallucination", "Harmful output", "Toxic output", "Spreading toxicity", "Spreading disinformation", "Nonconsensual use", "Non-disclosure", "Data contamination", "Data acquisition restrictions", "Data usage rights restrictions", "Confidential data in prompt", "Confidential information in data", "Personal information in prompt", "Personal information in data", "IP information in prompt",'
    # Initialize up front so `out` is always bound, even if `risks` is an
    # unexpected type (the original raised NameError at the return in that case).
    out: List[Dict[str, Any]] = []
    if isinstance(risks, str):
        # Translate LLM output (a JSON-ish list of risk names) to catalog entries.
        try:
            risks = json.loads(risks)
        except json.JSONDecodeError:
            # Fallback to regex - will skip any partial categories
            risks = re.findall(r'"(.*?)"', risks)
        for risk in risks:
            matches = [r for r in ran._ontology.risks if r.name == risk]  # type: ignore
            out += [m.model_dump() for m in matches]
    elif isinstance(risks, list):  # builtin `list`: typing.List in isinstance is deprecated
        # FIXME: assumes that the output is structured - not sure if that's correct.
        out = risks
    return gr.State(out), gr.Dataset(
        samples=[i['id'] for i in out],
        sample_labels=[i['name'] for i in out],
        samples_per_page=50,
        visible=True,
        label="Estimated by an LLM.",
    )
def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Dataset, gr.Dataset]:
    """Look up related risks and mitigating actions for a catalog risk id.

    Parameters
    ----------
    riskid : str
        Id of the selected risk in the catalog.
    taxonomy : str
        Taxonomy the risk came from. For "ibm-risk-atlas" actions are
        gathered from the related risks; otherwise from the risk itself.

    Returns
    -------
    tuple[gr.Dataset, gr.Dataset]
        (related-risks dataset, mitigation-actions dataset), both visible,
        with explanatory labels when nothing was found.
    """
    related_risk_ids = ran.get_related_risk_ids_by_risk_id(riskid)
    action_ids = []
    if taxonomy == "ibm-risk-atlas":
        # Collect actions associated with the related (cross-taxonomy) risks.
        # (`or []` keeps the loop a no-op when there are no related risks.)
        for rid in related_risk_ids or []:
            rai = ran.get_risk_actions_by_risk_id(rid)
            if rai:
                action_ids += rai
    else:
        # Use only actions related to the primary risk.
        action_ids = ran.get_risk_actions_by_risk_id(riskid)
    # Sanitize outputs: gr.Dataset renders an empty list when samples is None.
    if not related_risk_ids:
        label = "No related risks found."
        samples = None
        sample_labels = None
    else:
        label = "Related risks"
        samples = related_risk_ids
        sample_labels = [r.name for r in ran.get_related_risks_by_risk_id(riskid)]
    if not action_ids:
        alabel = "No mitigations found."
        asamples = None
        asample_labels = None
    else:
        alabel = ""
        asamples = action_ids
        # NOTE(review): assumes get_action_by_id resolves every id returned by
        # get_risk_actions_by_risk_id — confirm it cannot return None here.
        asample_labels = [ran.get_action_by_id(a).description for a in action_ids]
    return (gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
            gr.Dataset(samples=asamples, label=alabel, sample_labels=asample_labels, visible=True))