File size: 4,032 Bytes
c16aa6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from dotenv import load_dotenv

# Load .env before anything else so downstream imports see the variables
# (presumably the WML_* credentials used by the inference engine — confirm).
# override=True: values in .env win over the existing process environment.
load_dotenv(override=True)

import re
import os
import json
from typing import List, Dict, Any
import pandas as pd
import gradio as gr


from risk_atlas_nexus.blocks.inference import WMLInferenceEngine
from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
from risk_atlas_nexus.library import RiskAtlasNexus

from functools import lru_cache

# Load the taxonomies once at import time; this shared catalog instance is
# queried by the functions below.
ran = RiskAtlasNexus()


@lru_cache
def risk_identifier(usecase: str,
                    model_name_or_path: str = "ibm/granite-20b-code-instruct",
                    taxonomy: str = "ibm-risk-atlas"):
    """Identify risks for a use case and return Gradio components to show them.

    Args:
        usecase: Free-text description of the AI use case to analyse.
        model_name_or_path: Inference model id. Currently unused because the
            live WML call below is stubbed out.
        taxonomy: Risk taxonomy id to resolve risk names against.

    Returns:
        A tuple of (``gr.State`` holding the matched risk dicts,
        ``gr.Dataset`` listing risk ids with their names as labels).
    """
    # NOTE(review): live LLM inference is stubbed out; the canned (deliberately
    # truncated) response below exercises the regex fallback path. Re-enable
    # this block once WML credentials are available.
    # inference_engine = WMLInferenceEngine(
    #     model_name_or_path= model_name_or_path,
    #     credentials={
    #         "api_key": os.environ["WML_API_KEY"],
    #         "api_url": os.environ["WML_API_URL"],
    #         "project_id": os.environ["WML_PROJECT_ID"],
    #     },
    #     parameters=WMLInferenceEngineParams(
    #         max_new_tokens=100, decoding_method="greedy", repetition_penalty=1
    #     ),  # type: ignore
    # )

    # risks = ran.identify_risks_from_usecase(
    #     usecase=usecase,
    #     inference_engine=inference_engine,
    #     taxonomy=taxonomy,
    # )
    risks = ' ["Harmful code generation", "Hallucination", "Harmful output", "Toxic output", "Spreading toxicity", "Spreading disinformation", "Nonconsensual use", "Non-disclosure", "Data contamination", "Data acquisition restrictions", "Data usage rights restrictions", "Confidential data in prompt", "Confidential information in data", "Personal information in prompt", "Personal information in data", "IP information in prompt",'

    # Bugfix: always define `out` so the return below cannot raise NameError
    # when `risks` is neither a str nor a list.
    out: List[Dict[str, Any]] = []

    if isinstance(risks, str):
        # Translate LLM output to Risk catalog
        try:
            risks = json.loads(risks)
        except json.JSONDecodeError:
            # Fallback to regex - will skip any partial categories
            risks = re.findall(r'"(.*?)"', risks)

        for risk in risks:
            # Resolve each risk name against the loaded ontology.
            matches = [r for r in ran._ontology.risks if r.name == risk]  # type: ignore
            out += [m.model_dump() for m in matches]

    elif isinstance(risks, list):  # bugfix: builtin `list`, not typing.List
        # FIXME: assumes that the output is structured - not sure if that's correct.
        out = risks

    return gr.State(out), gr.Dataset(
        samples=[i['id'] for i in out],
        sample_labels=[i['name'] for i in out],
        samples_per_page=50,
        visible=True,
        label="Estimated by an LLM.",
    )
    

@lru_cache
def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Dataset, gr.Dataset]:
    """Look up related risks and mitigating actions for a risk id.

    Args:
        riskid: Id of the primary risk.
        taxonomy: Taxonomy id. For "ibm-risk-atlas" the actions are gathered
            from the risks *related* to ``riskid``; for any other taxonomy the
            actions attached directly to ``riskid`` are used.

    Returns:
        A tuple of (related-risks ``gr.Dataset``, mitigations ``gr.Dataset``)
        ready to display; empty results get an explanatory label instead of
        samples.
    """
    related_risk_ids = ran.get_related_risk_ids_by_risk_id(riskid)

    action_ids = []
    if taxonomy == "ibm-risk-atlas":
        # Actions hang off the related risks in this taxonomy. `or []`
        # tolerates a None/empty result without a separate branch.
        for rid in related_risk_ids or []:
            actions = ran.get_risk_actions_by_risk_id(rid)
            if actions:
                action_ids += actions
        # NOTE(review): duplicate action ids across related risks are kept,
        # matching previous behavior — dedupe here if the UI should not
        # repeat mitigations.
    else:
        # Other taxonomies attach actions to the primary risk directly.
        action_ids = ran.get_risk_actions_by_risk_id(riskid)

    # Sanitize the related-risks output for display.
    if not related_risk_ids:
        label = "No related risks found."
        samples = None
        sample_labels = None
    else:
        label = "Related risks"
        samples = related_risk_ids
        sample_labels = [r.name for r in ran.get_related_risks_by_risk_id(riskid)]

    # Sanitize the mitigations output for display.
    if not action_ids:
        alabel = "No mitigations found."
        asamples = None
        asample_labels = None
    else:
        alabel = ""
        asamples = action_ids
        asample_labels = [ran.get_action_by_id(a).description for a in action_ids]

    return (gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
            gr.Dataset(samples=asamples, label=alabel, sample_labels=asample_labels, visible=True))