File size: 3,573 Bytes
c16aa6d
 
 
 
 
 
9c64352
c16aa6d
 
 
 
 
 
 
 
 
 
 
 
 
9c64352
c16aa6d
 
 
 
 
 
 
9c64352
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c16aa6d
 
9c64352
 
 
c16aa6d
 
 
9c64352
 
 
 
 
 
 
c16aa6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c64352
c16aa6d
9c64352
c16aa6d
 
 
 
 
9c64352
c16aa6d
9c64352
c16aa6d
9c64352
 
 
c16aa6d
 
9c64352
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from dotenv import load_dotenv

# Populate os.environ from .env before anything reads WML credentials.
load_dotenv(override=True)

import json
import os
import re
from functools import lru_cache
from typing import Any, Dict, List

import gradio as gr
import pandas as pd

from risk_atlas_nexus.blocks.inference import WMLInferenceEngine
from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
from risk_atlas_nexus.library import RiskAtlasNexus

# Shared Risk Atlas Nexus instance; loads the risk taxonomies once at import.
ran = RiskAtlasNexus()  # type: ignore


# Bounded cache: use-case strings are free-form user input, so an unbounded
# cache would grow without limit.
@lru_cache(maxsize=128)
def risk_identifier(usecase: str,
                    model_name_or_path: str = "ibm/granite-20b-code-instruct",
                    taxonomy: str = "ibm-risk-atlas"):
    """Ask an LLM (via watsonx.ai) which taxonomy risks apply to a use case.

    Args:
        usecase: Free-text description of the AI use case.
        model_name_or_path: WML model id used for risk identification.
        taxonomy: Taxonomy to draw candidate risks from.

    Returns:
        Tuple of (section-header ``gr.Markdown``, ``gr.State`` holding the
        identified risk objects, ``gr.Dataset`` of risk ids labelled with
        their names).

    Raises:
        KeyError: if WML_API_KEY / WML_API_URL / WML_PROJECT_ID are not set.
    """
    inference_engine = WMLInferenceEngine(
        model_name_or_path=model_name_or_path,
        credentials={
            "api_key": os.environ["WML_API_KEY"],
            "api_url": os.environ["WML_API_URL"],
            "project_id": os.environ["WML_PROJECT_ID"],
        },
        parameters=WMLInferenceEngineParams(
            max_new_tokens=150, decoding_method="greedy", repetition_penalty=1
        ),  # type: ignore
    )

    risks = ran.identify_risks_from_usecases(  # type: ignore
        usecases=[usecase],
        inference_engine=inference_engine,
        taxonomy=taxonomy,
    )[0]

    # Prefer the human-readable name, falling back to the id when the name is
    # empty. (The previous `r.name if r else r.id` tested the risk object
    # itself, so the fallback branch could only ever raise AttributeError.)
    sample_labels = [r.name if r.name else r.id for r in risks]

    out_sec = gr.Markdown("""<h2> Potential Risks </h2> """)

    return (
        out_sec,
        gr.State(risks),
        gr.Dataset(
            samples=[r.id for r in risks],
            sample_labels=sample_labels,
            samples_per_page=50,
            visible=True,
            label="Estimated by an LLM.",
        ),
    )
    

@lru_cache
def mitigations(riskid: str, taxonomy: str) -> tuple[gr.Dataset, gr.DataFrame]:
    """Collect related risks and mitigation actions for one risk.

    Args:
        riskid: Id of the risk selected by the user.
        taxonomy: Taxonomy the risk belongs to. For "ibm-risk-atlas" the
            actions are gathered from the risk's related (cross-taxonomy)
            risks; for any other taxonomy they come from the risk itself.

    Returns:
        Tuple of (``gr.Dataset`` of related risks, ``gr.DataFrame`` of
        mitigation actions). Both are rendered with "not found" labels and
        empty contents when nothing is available.
    """
    related_risk_ids = ran.get_related_risk_ids_by_risk_id(riskid)

    if taxonomy == "ibm-risk-atlas":
        # Atlas risks carry no actions directly; aggregate the actions
        # attached to their related risks instead.
        action_ids = []
        for rid in related_risk_ids or []:
            rid_actions = ran.get_risk_actions_by_risk_id(rid)
            if rid_actions:
                action_ids += rid_actions
    else:
        # Other taxonomies attach actions to the primary risk itself.
        action_ids = ran.get_risk_actions_by_risk_id(riskid)

    # Related-risks dataset (empty-safe).
    if related_risk_ids:
        label = f"Risks from other taxonomies related to {riskid}"
        samples = related_risk_ids
        sample_labels = [r.name for r in ran.get_related_risks_by_risk_id(riskid)]  # type: ignore
    else:
        label = "No related risks found."
        samples = None
        sample_labels = None

    # Mitigations table (empty-safe). Fetch each action object once rather
    # than once per column.
    if action_ids:
        alabel = f"Mitigation actions related to risk {riskid}."
        actions = [ran.get_action_by_id(i) for i in action_ids]  # type: ignore
        mitdf = pd.DataFrame(
            {
                "Mitigation": [a.name for a in actions],
                "Description": [a.description for a in actions],
            }
        )
    else:
        alabel = "No mitigations found."
        mitdf = pd.DataFrame()

    return (
        gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
        gr.DataFrame(mitdf, wrap=True, show_copy_button=True, show_search="search",
                     label=alabel, visible=True),
    )