diff --git a/src/attack_flow/cli.py b/src/attack_flow/cli.py index dfffad7d..09109759 100644 --- a/src/attack_flow/cli.py +++ b/src/attack_flow/cli.py @@ -97,7 +97,7 @@ def graphviz(args): converted = attack_flow.graphviz.convert_attack_tree(flow_bundle) else: converted = attack_flow.graphviz.convert_attack_flow(flow_bundle) - + with open(args.output, "w") as out: out.write(converted) return 0 diff --git a/src/attack_flow/graphviz.py b/src/attack_flow/graphviz.py index 238cef61..08901a3d 100644 --- a/src/attack_flow/graphviz.py +++ b/src/attack_flow/graphviz.py @@ -18,6 +18,7 @@ def label_escape(text): return graphviz.escape(html.escape(text)) + def convert_attack_flow(bundle): """ Convert an Attack Flow STIX bundle into Graphviz format. @@ -69,6 +70,7 @@ def convert_attack_flow(bundle): return gv.source + def convert_attack_tree(bundle): """ Convert an Attack Flow STIX bundle into Graphviz format. @@ -77,7 +79,7 @@ def convert_attack_tree(bundle): :rtype: str """ - gv = graphviz.Digraph(graph_attr={'rankdir':'BT'}) + gv = graphviz.Digraph(graph_attr={"rankdir": "BT"}) gv.body = _get_body_label(bundle) ignored_ids = get_viz_ignored_ids(bundle) @@ -85,10 +87,14 @@ def convert_attack_tree(bundle): id_to_remove = [] ids = [] - for i,o in enumerate(objects): + for i, o in enumerate(objects): if o.type == "attack-operator": id_to_remove.append( - {"id": o.id, "prev_id": objects[i-1].id, "next_id":o.effect_refs[0], "type": o.operator + { + "id": o.id, + "prev_id": objects[i - 1].id, + "next_id": o.effect_refs[0], + "type": o.operator, } ) @@ -96,22 +102,23 @@ def convert_attack_tree(bundle): objects = [item for item in objects if item.id not in ids] new_operator_ids = [i["next_id"] for i in id_to_remove] for operator in id_to_remove: - for i,o in enumerate(objects): - if o.type=="relationship" and o.source_ref == operator["id"]: - o.source_ref = operator.prev_id - if o.type=="relationship" and o.target_ref == operator["id"]: - o.target_ref = operator.next_id - if o.get("effect_refs") and operator["id"] in o.effect_refs: - for i,j in enumerate(o.effect_refs): - if j == operator["id"]: - o.effect_refs[i] = operator["next_id"] - + for i, o in enumerate(objects): + if o.type == "relationship" and o.source_ref == operator["id"]: + o.source_ref = operator.prev_id + if o.type == "relationship" and o.target_ref == operator["id"]: + o.target_ref = operator.next_id + if o.get("effect_refs") and operator["id"] in o.effect_refs: + for i, j in enumerate(o.effect_refs): + if j == operator["id"]: + o.effect_refs[i] = operator["next_id"] for o in objects: logger.debug("Processing object id=%s", o.id) if o.type == "attack-action": if o.id in new_operator_ids: - operator_type = [item["type"] for item in id_to_remove if item["next_id"] == o.id][0] + operator_type = [ + item["type"] for item in id_to_remove if item["next_id"] == o.id + ][0] gv.node( o.id, label=_get_operator_label(o, operator_type), @@ -191,6 +198,8 @@ def _get_action_label(action): ">", ] ) + + def _get_attack_tree_action_label(action): """ Generate the GraphViz label for an action node as a table. @@ -217,6 +226,7 @@ def _get_attack_tree_action_label(action): ] ) + def _get_asset_label(asset): """ Generate the GraphViz label for an asset node as a table. @@ -283,6 +293,7 @@ def _get_condition_label(condition): ] ) + def _get_operator_label(action, operator_type): """ Generate the GraphViz label for an action node as a table. @@ -311,4 +322,4 @@ def _get_operator_label(action, operator_type): f'