import datetime
import json
import os
from functools import lru_cache, wraps
from pathlib import Path

import gradio as gr
import pandas as pd
from dotenv import load_dotenv

from ai_atlas_nexus.blocks.inference import WMLInferenceEngine
from ai_atlas_nexus.blocks.inference.params import WMLInferenceEngineParams
from ai_atlas_nexus.library import AIAtlasNexus

load_dotenv(override=True)

# Load the taxonomies
ran = AIAtlasNexus()  # type: ignore
def clear_previous_risks():
    """Reset the risk panel components to an empty, hidden state."""
    return (
        gr.Markdown("""<h2> Potential Risks </h2> """),
        [],
        gr.Dataset(samples=[], sample_labels=[], samples_per_page=50, visible=False),
        gr.DownloadButton("Download JSON", visible=False),
        "",
        gr.Dataset(samples=[], sample_labels=[], visible=False),
        gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False),
        gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False),
        gr.Markdown(" "),
        gr.Markdown(" "),
    )


def clear_previous_mitigations():
    """Reset the mitigation panel components to an empty, hidden state."""
    return (
        "",
        gr.Dataset(samples=[], sample_labels=[], visible=False),
        gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False),
        gr.DataFrame([], wrap=True, show_copy_button=True, show_search="search", visible=False),
        gr.Markdown(" "),
        gr.Markdown(" "),
    )
def generate_subgraph(usecase, risk):
    """Build a Mermaid flowchart linking the use case to the selected risk and its
    related risks, evaluations, and mitigations."""
    lines = [
        '```mermaid\n',
        '---\n'
        'config:\n'
        ' theme: mc\n'
        ' layout: dagre\n'
        ' look: classic\n'
        '---\n'
        'flowchart TB\n',
    ]
    lines.append(f'uc_173@{{ label: "{usecase}" }} -- subClassOf --> AISystem["AISystem"]\n')
    lines.append(f'uc_173 -- hasRisk --> Risk2["{risk.name}"]\n')
    lines.append(f'Risk2 -- isPartOf --> {risk.isPartOf}\n')
    lines.append(f'Risk2 -- isDefinedByTaxonomy --> {risk.isDefinedByTaxonomy}\n')

    # Add related risks
    rrs = ran.get_related_risks(id=risk.id)
    if len(rrs) > 0:
        r_risks = ', '.join(rr.name for rr in rrs)
        lines.append(f'Risk2 -- hasRelatedRisks --> Risk3["{r_risks}"]\n')

    # Add related evaluations
    revals = ran.get_related_evaluations(risk_id=risk.id)
    if len(revals) > 0:
        r_evals = ', '.join(reval.name for reval in revals)
        lines.append(f'Risk2 -- hasAiEvaluations --> Risk4["{r_evals[:100]}"]\n')

    # Add related mitigations
    rmits = get_controls_and_actions(risk.id, risk.isDefinedByTaxonomy)
    if len(rmits) > 0:
        r_mits = ', '.join(rmits)
        lines.append(f'Risk2 -- hasMitigations --> Risk5["{r_mits[:100]}"]\n')

    lines.append("```")
    diagram_string = "".join(lines)
    return gr.Markdown(value=diagram_string)
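# Illustrative output (assumption: a hypothetical use case "HR query chatbot" and a
# risk named "Hallucination"); the Markdown string returned above would look roughly like:
#
#   ```mermaid
#   ---
#   config:
#    theme: mc
#    layout: dagre
#    look: classic
#   ---
#   flowchart TB
#   uc_173@{ label: "HR query chatbot" } -- subClassOf --> AISystem["AISystem"]
#   uc_173 -- hasRisk --> Risk2["Hallucination"]
#   ...
#   ```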
def custom_lru_cache(maxsize=128, exclude_values=(None, [], [[]])):
    """LRU cache wrapper that recomputes the result (instead of returning the cached
    value) when the cached result contains an empty risks Dataset."""
    def decorator(func):
        cached_func = lru_cache(maxsize=maxsize)(func)

        @wraps(func)
        def wrapper(*args, **kwargs):
            result = cached_func(*args, **kwargs)
            # Check for an empty risks Dataset (third element of the result tuple)
            if result[2].constructor_args["samples"] in exclude_values:
                return func(*args, **kwargs)
            return result

        return wrapper

    return decorator
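# Usage sketch (assumption, not in the original file): the decorator expects the wrapped
# callable to return a tuple whose third element is the risks gr.Dataset (result[2]),
# i.e. a function shaped like risk_identifier below, which could be decorated with
# @custom_lru_cache() to memoize non-empty results.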
def risk_identifier(usecase: str,
                    model_name_or_path: str = "meta-llama/llama-3-3-70b-instruct",
                    taxonomy: str = "ibm-risk-atlas"):
    """Identify potential risks for a use case with an LLM and return updates for the
    header, risk state, risk Dataset, and JSON download button."""
    downloadable = False
    inference_engine = WMLInferenceEngine(
        model_name_or_path=model_name_or_path,
        credentials={
            "api_key": os.environ["WML_API_KEY"],
            "api_url": os.environ["WML_API_URL"],
            "project_id": os.environ["WML_PROJECT_ID"],
        },
        parameters=WMLInferenceEngineParams(
            max_new_tokens=1000, decoding_method="greedy", repetition_penalty=1
        ),  # type: ignore
    )

    risks_a = ran.identify_risks_from_usecases(  # type: ignore
        usecases=[usecase],
        inference_engine=inference_engine,
        taxonomy=taxonomy,
        zero_shot_only=True,
        max_risk=5,
    )
    risks = risks_a[0]
    # Fall back to the risk id when a risk has no name
    sample_labels = [r.name if r.name else r.id for r in risks]
    out_sec = gr.Markdown("""<h2> Potential Risks </h2> """)

    # Write out a JSON snapshot of the identified risks for download
    data = {
        'time': str(datetime.datetime.now(datetime.timezone.utc)),
        'intent': usecase,
        'model': model_name_or_path,
        'taxonomy': taxonomy,
        'risks': [json.loads(r.json()) for r in risks],
    }
    file_path = Path("static/download.json")
    with open(file_path, mode='w') as f:
        f.write(json.dumps(data, indent=4))
    downloadable = True

    return (
        out_sec,
        gr.State(risks),
        gr.Dataset(
            samples=[r.id for r in risks],
            sample_labels=sample_labels,
            samples_per_page=50,
            visible=True,
            label="Estimated by an LLM.",
        ),
        gr.DownloadButton("Download JSON", "static/download.json",
                          visible=(downloadable and len(risks) > 0)),
    )
def get_controls_and_actions(riskid, taxonomy):
    """Collect the names of actions, controls, and intrinsics that mitigate the given risk."""
    related_risk_ids = [r.id for r in ran.get_related_risks(id=riskid)]
    action_ids = []
    control_ids = []
    intrinsic_ids = []
    if taxonomy == "ibm-risk-atlas":
        # Look for actions associated with related risks
        if related_risk_ids:
            for i in related_risk_ids:
                rai = ran.get_related_actions(id=i)
                if rai:
                    action_ids += rai
                rac = ran.get_related_risk_controls(id=i)
                if rac:
                    control_ids += rac
                ran_intrinsics = ran.get_related_intrinsics(risk_id=i)
                if ran_intrinsics:
                    intrinsic_ids += ran_intrinsics
    else:
        # Use only actions related to the primary risk
        action_ids = ran.get_related_actions(id=riskid)
        control_ids = ran.get_related_risk_controls(id=riskid)
        intrinsic_ids = ran.get_related_intrinsics(risk_id=riskid)
    return (
        [ran.get_action_by_id(i).name for i in action_ids]
        + [ran.get_risk_control(i.id).name for i in control_ids]
        + [ran.get_intrinsic(i.id).name for i in intrinsic_ids]
    )  # type: ignore
def mitigations(usecase: str, riskid: str, taxonomy: str) -> tuple[gr.Markdown, gr.Dataset, gr.DataFrame, gr.DataFrame, gr.Markdown, gr.Markdown]:
    """
    For a specific risk (riskid), returns:
      (a) a risk description
      (b) related risks, as a Dataset
      (c) mitigations
      (d) related AI evaluations
      (e) a subgraph linking the risk to its mitigations
    """
    selected_risk = None
    try:
        selected_risk = ran.get_risk(id=riskid)
        risk_desc = selected_risk.description  # type: ignore
        risk_sec = f"<h3>Description: </h3> {risk_desc}"
    except AttributeError:
        risk_sec = ""

    related_risk_ids = [r.id for r in ran.get_related_risks(id=riskid)]
    related_ai_eval_ids = [ai_eval.id for ai_eval in ran.get_related_evaluations(risk_id=riskid)]

    action_ids = []
    control_ids = []
    intrinsic_ids = []
    if taxonomy == "ibm-risk-atlas":
        # Look for actions associated with related risks
        if related_risk_ids:
            for i in related_risk_ids:
                ran_actions = ran.get_related_actions(id=i)
                if ran_actions:
                    action_ids += ran_actions
                ran_controls = ran.get_related_risk_controls(id=i)
                if ran_controls:
                    control_ids += ran_controls
                ran_intrinsics = ran.get_related_intrinsics(risk_id=i)
                if ran_intrinsics:
                    intrinsic_ids += ran_intrinsics
    else:
        # Use only actions related to the primary risk
        action_ids = ran.get_related_actions(id=riskid)
        control_ids = ran.get_related_risk_controls(id=riskid)
        intrinsic_ids = ran.get_related_intrinsics(risk_id=riskid)

    # Sanitize outputs
    if not related_risk_ids:
        label = "No related risks found."
        samples = None
        sample_labels = None
    else:
        label = f"Risks from other taxonomies related to {riskid}"
        samples = related_risk_ids
        sample_labels = [i.name for i in ran.get_related_risks(id=riskid)]  # type: ignore

    if not action_ids and not control_ids and not intrinsic_ids:
        alabel = "No mitigations found."
        mitdf = pd.DataFrame()
    else:
        alabel = f"Mitigation actions and controls related to risk {riskid}."
        asamples = action_ids
        asamples_ctl = control_ids
        asamples_int = intrinsic_ids
        asample_labels = ([ran.get_action_by_id(i).description for i in asamples]
                          + [ran.get_risk_control(i.id).name for i in asamples_ctl]
                          + [ran.get_intrinsic(i.id).description for i in asamples_int])  # type: ignore
        asample_name = ([ran.get_action_by_id(i).name for i in asamples]
                        + [ran.get_risk_control(i.id).name for i in asamples_ctl]
                        + [ran.get_intrinsic(i.id).name for i in asamples_int])  # type: ignore
        asample_types = (["Action" for i in asamples]
                         + ["Control" for i in asamples_ctl]
                         + ["Intrinsic" for i in asamples_int])
        mitdf = pd.DataFrame({"Type": asample_types, "Mitigation": asample_name, "Description": asample_labels})

    if not related_ai_eval_ids:
        blabel = "No related AI evaluations found."
        aievalsdf = pd.DataFrame()
    else:
        blabel = f"AI Evaluations related to {riskid}"
        bsamples = related_ai_eval_ids
        bsample_labels = [ran.get_evaluation(i).description for i in bsamples]  # type: ignore
        bsample_name = [ran.get_evaluation(i).name for i in bsamples]  # type: ignore
        aievalsdf = pd.DataFrame({"AI Evaluation": bsample_name, "Description": bsample_labels})

    status = gr.Markdown(" ") if len(mitdf) > 0 else gr.Markdown("No mitigations found.")
    fig = gr.Markdown(" ") if not selected_risk else generate_subgraph(usecase, selected_risk)

    return (gr.Markdown(risk_sec),
            gr.Dataset(samples=samples, label=label, sample_labels=sample_labels, visible=True),
            gr.DataFrame(mitdf, wrap=True, show_copy_button=True, show_search="search", label=alabel, visible=True),
            gr.DataFrame(aievalsdf, wrap=True, show_copy_button=True, show_search="search", label=blabel, visible=True),
            status, fig)
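# Wiring sketch (assumption, not part of the original file): one plausible way the
# callbacks above could be attached to Gradio components in a gr.Blocks layout.
# Component names and layout here are hypothetical; the Space presumably defines its
# actual UI separately.
#
#   with gr.Blocks() as demo:
#       usecase_box = gr.Textbox(label="Describe your AI use case")
#       model_dd = gr.Dropdown(choices=["meta-llama/llama-3-3-70b-instruct"],
#                              value="meta-llama/llama-3-3-70b-instruct", label="Model")
#       taxonomy_dd = gr.Dropdown(choices=["ibm-risk-atlas"], value="ibm-risk-atlas",
#                                 label="Taxonomy")
#       submit_btn = gr.Button("Identify risks")
#
#       risks_header = gr.Markdown("<h2> Potential Risks </h2>")
#       risks_state = gr.State([])
#       risks_dataset = gr.Dataset(components=[gr.Textbox(visible=False)],
#                                  samples=[], visible=False)
#       download_btn = gr.DownloadButton("Download JSON", visible=False)
#
#       # risk_identifier returns updates for exactly these four components.
#       submit_btn.click(risk_identifier,
#                        inputs=[usecase_box, model_dd, taxonomy_dd],
#                        outputs=[risks_header, risks_state, risks_dataset, download_btn])
#
#   demo.launch()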