"""Gradio helpers that identify AI risks for a use case and look up mitigations.

Backed by Risk Atlas Nexus; credentials for the (currently stubbed) WML
inference engine are read from the environment via python-dotenv.
"""

import json
import os
import re
from functools import lru_cache
from typing import Any, Dict, List

import gradio as gr
import pandas as pd
from dotenv import load_dotenv

from risk_atlas_nexus.blocks.inference import WMLInferenceEngine
from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
from risk_atlas_nexus.library import RiskAtlasNexus

# Pull WML_API_KEY / WML_API_URL / WML_PROJECT_ID from .env; override=True lets
# the .env file take precedence over variables already set in the shell.
load_dotenv(override=True)

# Load the taxonomies once at import time; shared by both cached helpers below.
ran = RiskAtlasNexus()


@lru_cache
def risk_identifier(
    usecase: str,
    model_name_or_path: str = "ibm/granite-20b-code-instruct",
    taxonomy: str = "ibm-risk-atlas",
):
    """Estimate the risks posed by *usecase* and package them for Gradio.

    Returns a ``(gr.State, gr.Dataset)`` pair: the state carries the matched
    risk records (dicts from ``model_dump()``), the dataset lists one sample
    per risk (id as sample, name as label).

    NOTE(review): the live WML inference call is stubbed out below with a
    hard-coded, deliberately truncated LLM response (exercises the regex
    fallback path). Re-enable the commented block to query the model for real.
    """
    # inference_engine = WMLInferenceEngine(
    #     model_name_or_path=model_name_or_path,
    #     credentials={
    #         "api_key": os.environ["WML_API_KEY"],
    #         "api_url": os.environ["WML_API_URL"],
    #         "project_id": os.environ["WML_PROJECT_ID"],
    #     },
    #     parameters=WMLInferenceEngineParams(
    #         max_new_tokens=100, decoding_method="greedy", repetition_penalty=1
    #     ),  # type: ignore
    # )
    # risks = ran.identify_risks_from_usecase(
    #     usecase=usecase,
    #     inference_engine=inference_engine,
    #     taxonomy=taxonomy,
    # )
    risks = ' ["Harmful code generation", "Hallucination", "Harmful output", "Toxic output", "Spreading toxicity", "Spreading disinformation", "Nonconsensual use", "Non-disclosure", "Data contamination", "Data acquisition restrictions", "Data usage rights restrictions", "Confidential data in prompt", "Confidential information in data", "Personal information in prompt", "Personal information in data", "IP information in prompt",'

    # Fix: initialize up front so `out` is always bound, even if `risks` is
    # neither a str nor a list (previously a potential NameError).
    out: List[Dict[str, Any]] = []
    if isinstance(risks, str):
        # Translate LLM output to Risk catalog
        try:
            risks = json.loads(risks)
        except json.JSONDecodeError:
            # Fallback to regex - will skip any partial categories
            risks = re.findall(r'"(.*?)"', risks)
        for risk in risks:
            # Match the LLM-emitted name against the ontology's risk entries.
            matches = [r for r in ran._ontology.risks if r.name == risk]  # type: ignore
            out += [m.model_dump() for m in matches]
    elif isinstance(risks, list):  # fix: builtin `list`, not typing.List
        # FIXME: assumes that the output is structured - not sure if that's correct.
        out = risks

    return (
        gr.State(out),
        gr.Dataset(
            samples=[i["id"] for i in out],
            sample_labels=[i["name"] for i in out],
            samples_per_page=50,
            visible=True,
            label="Estimated by an LLM.",
        ),
    )


@lru_cache
def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Dataset, gr.Dataset]:
    """Return (related-risks dataset, mitigation-actions dataset) for *riskid*.

    For the "ibm-risk-atlas" taxonomy, actions are collected from the risks
    related to *riskid*; for any other taxonomy, only actions attached
    directly to *riskid* are used. Empty results yield placeholder labels.
    """
    related_risk_ids = ran.get_related_risk_ids_by_risk_id(riskid)

    action_ids: List[str] = []
    if taxonomy == "ibm-risk-atlas":
        # look for actions associated with related risks
        if related_risk_ids:
            for related_id in related_risk_ids:
                related_actions = ran.get_risk_actions_by_risk_id(related_id)
                if related_actions:
                    action_ids += related_actions
    else:
        # Use only actions related to primary risks
        action_ids = ran.get_risk_actions_by_risk_id(riskid)

    # Sanitize outputs
    if not related_risk_ids:
        label = "No related risks found."
        samples = None
        sample_labels = None
    else:
        label = "Related risks"
        samples = related_risk_ids
        sample_labels = [i.name for i in ran.get_related_risks_by_risk_id(riskid)]

    if not action_ids:
        alabel = "No mitigations found."
        asamples = None
        asample_labels = None
    else:
        alabel = ""
        asamples = action_ids
        asample_labels = [ran.get_action_by_id(i).description for i in action_ids]

    return (
        gr.Dataset(
            samples=samples, label=label, sample_labels=sample_labels, visible=True
        ),
        gr.Dataset(
            samples=asamples, label=alabel, sample_labels=asample_labels, visible=True
        ),
    )