risk-atlas-nexus / executor.py
from dotenv import load_dotenv

# Load WML credentials and other settings from a local .env file, if present.
load_dotenv(override=True)

import os
from functools import lru_cache

import gradio as gr
import pandas as pd

from risk_atlas_nexus.blocks.inference import WMLInferenceEngine
from risk_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
from risk_atlas_nexus.library import RiskAtlasNexus

# Load the taxonomies
ran = RiskAtlasNexus()  # type: ignore


@lru_cache
def risk_identifier(
    usecase: str,
    model_name_or_path: str = "ibm/granite-20b-code-instruct",
    taxonomy: str = "ibm-risk-atlas",
):
    """Identify potential risks in a use-case description with an LLM on watsonx.ai.

    Returns a Markdown header, the raw risk objects wrapped in Gradio state,
    and a Dataset listing the identified risks.
    """
    inference_engine = WMLInferenceEngine(
        model_name_or_path=model_name_or_path,
        credentials={
            "api_key": os.environ["WML_API_KEY"],
            "api_url": os.environ["WML_API_URL"],
            "project_id": os.environ["WML_PROJECT_ID"],
        },
        parameters=WMLInferenceEngineParams(
            max_new_tokens=150, decoding_method="greedy", repetition_penalty=1
        ),  # type: ignore
    )

    risks = ran.identify_risks_from_usecases(  # type: ignore
        usecases=[usecase],
        inference_engine=inference_engine,
        taxonomy=taxonomy,
    )[0]

    # Fall back to the risk id when a risk has no human-readable name.
    sample_labels = [r.name if r.name else r.id for r in risks]

    out_sec = gr.Markdown("<h2> Potential Risks </h2>")
    return (
        out_sec,
        gr.State(risks),
        gr.Dataset(
            samples=[r.id for r in risks],
            sample_labels=sample_labels,
            samples_per_page=50,
            visible=True,
            label="Estimated by an LLM.",
        ),
    )
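
# Illustrative usage (a sketch, not part of the Space's real UI wiring): with
# WML_API_KEY / WML_API_URL / WML_PROJECT_ID set, the handler can be called
# directly; the use-case text below is only an example.
#
#   header, risk_state, risk_dataset = risk_identifier(
#       "A chatbot that answers customer billing questions."
#   )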


@lru_cache
def mitigations(
    riskid: str, taxonomy: str
) -> tuple[gr.Markdown, gr.Dataset, gr.DataFrame, gr.Markdown]:
    """For a specific risk (riskid), return:

    (a) the risk description,
    (b) related risks from other taxonomies, as a Dataset,
    (c) mitigation actions, as a DataFrame, and
    (d) a status message.
    """
    try:
        risk_desc = ran.get_risk(id=riskid).description  # type: ignore
        risk_sec = f"<h3>Description: </h3> {risk_desc}"
    except AttributeError:
        risk_sec = ""

    related_risks = ran.get_related_risks(id=riskid)
    related_risk_ids = [r.id for r in related_risks]

    action_ids = []
    if taxonomy == "ibm-risk-atlas":
        # look for actions associated with related risks
        for i in related_risk_ids:
            rai = ran.get_related_actions(id=i)
            if rai:
                action_ids += rai
    else:
        # Use only actions related to primary risks
        action_ids = ran.get_related_actions(id=riskid)

    # Sanitize outputs
    if not related_risk_ids:
        label = "No related risks found."
        samples = None
        sample_labels = None
    else:
        label = f"Risks from other taxonomies related to {riskid}"
        samples = related_risk_ids
        sample_labels = [r.name for r in related_risks]  # type: ignore

    if not action_ids:
        alabel = "No mitigations found."
        asamples = None
        asample_labels = None
        mitdf = pd.DataFrame()
    else:
        alabel = f"Mitigation actions related to risk {riskid}."
        asamples = action_ids
        asample_labels = [ran.get_action_by_id(i).description for i in asamples]  # type: ignore
        asample_name = [ran.get_action_by_id(i).name for i in asamples]  # type: ignore
        mitdf = pd.DataFrame({"Mitigation": asample_name, "Description": asample_labels})

    status = gr.Markdown(" ") if len(mitdf) > 0 else gr.Markdown("No mitigations found.")

    return (
        gr.Markdown(risk_sec),
        gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
        gr.DataFrame(
            mitdf,
            wrap=True,
            show_copy_button=True,
            show_search="search",
            label=alabel,
            visible=True,
        ),
        status,
    )
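

if __name__ == "__main__":
    # Minimal, illustrative wiring of the two handlers above into a Gradio UI.
    # This is a sketch only: the Space's actual layout and event wiring live in
    # its app module (not shown here), and the component names below are
    # assumptions. Running it requires valid WML_API_KEY / WML_API_URL /
    # WML_PROJECT_ID credentials.
    with gr.Blocks() as demo:
        usecase_box = gr.Textbox(label="Describe the AI use case")
        run_btn = gr.Button("Identify risks")

        header_md = gr.Markdown()
        risks_state = gr.State()
        risks_ds = gr.Dataset(
            components=[gr.Textbox(visible=False)], samples=[], visible=False
        )

        run_btn.click(
            risk_identifier,
            inputs=[usecase_box],
            outputs=[header_md, risks_state, risks_ds],
        )
        # In the full app, selecting an entry in `risks_ds` would call
        # `mitigations(<selected risk id>, <taxonomy>)` and render its outputs
        # into Markdown / Dataset / DataFrame components.

    demo.launch()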