Skip to content

Commit

Permalink
lint code and fix unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
allisonrobbins committed Jun 19, 2024
1 parent a78aa7d commit 2f32e5d
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 19 deletions.
2 changes: 1 addition & 1 deletion src/attack_flow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def graphviz(args):
converted = attack_flow.graphviz.convert_attack_tree(flow_bundle)
else:
converted = attack_flow.graphviz.convert_attack_flow(flow_bundle)

with open(args.output, "w") as out:
out.write(converted)
return 0
Expand Down
41 changes: 26 additions & 15 deletions src/attack_flow/graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
def label_escape(text):
return graphviz.escape(html.escape(text))


def convert_attack_flow(bundle):
"""
Convert an Attack Flow STIX bundle into Graphviz format.
Expand Down Expand Up @@ -69,6 +70,7 @@ def convert_attack_flow(bundle):

return gv.source


def convert_attack_tree(bundle):
"""
Convert an Attack Flow STIX bundle into Graphviz format.
Expand All @@ -77,41 +79,46 @@ def convert_attack_tree(bundle):
:rtype: str
"""

gv = graphviz.Digraph(graph_attr={'rankdir':'BT'})
gv = graphviz.Digraph(graph_attr={"rankdir": "BT"})
gv.body = _get_body_label(bundle)
ignored_ids = get_viz_ignored_ids(bundle)

objects = bundle.objects

id_to_remove = []
ids = []
for i,o in enumerate(objects):
for i, o in enumerate(objects):
if o.type == "attack-operator":
id_to_remove.append(
{"id": o.id, "prev_id": objects[i-1].id, "next_id":o.effect_refs[0], "type": o.operator
{
"id": o.id,
"prev_id": objects[i - 1].id,
"next_id": o.effect_refs[0],
"type": o.operator,
}
)

ids = [i["id"] for i in id_to_remove]
objects = [item for item in objects if item.id not in ids]
new_operator_ids = [i["next_id"] for i in id_to_remove]
for operator in id_to_remove:
for i,o in enumerate(objects):
if o.type=="relationship" and o.source_ref == operator["id"]:
o.source_ref = operator.prev_id
if o.type=="relationship" and o.target_ref == operator["id"]:
o.target_ref = operator.next_id
if o.get("effect_refs") and operator["id"] in o.effect_refs:
for i,j in enumerate(o.effect_refs):
if j == operator["id"]:
o.effect_refs[i] = operator["next_id"]

for i, o in enumerate(objects):
if o.type == "relationship" and o.source_ref == operator["id"]:
o.source_ref = operator.prev_id
if o.type == "relationship" and o.target_ref == operator["id"]:
o.target_ref = operator.next_id
if o.get("effect_refs") and operator["id"] in o.effect_refs:
for i, j in enumerate(o.effect_refs):
if j == operator["id"]:
o.effect_refs[i] = operator["next_id"]

for o in objects:
logger.debug("Processing object id=%s", o.id)
if o.type == "attack-action":
if o.id in new_operator_ids:
operator_type = [item["type"] for item in id_to_remove if item["next_id"] == o.id][0]
operator_type = [
item["type"] for item in id_to_remove if item["next_id"] == o.id
][0]
gv.node(
o.id,
label=_get_operator_label(o, operator_type),
Expand Down Expand Up @@ -191,6 +198,8 @@ def _get_action_label(action):
"</TABLE>>",
]
)


def _get_attack_tree_action_label(action):
"""
Generate the GraphViz label for an action node as a table.
Expand All @@ -217,6 +226,7 @@ def _get_attack_tree_action_label(action):
]
)


def _get_asset_label(asset):
"""
Generate the GraphViz label for an asset node as a table.
Expand Down Expand Up @@ -283,6 +293,7 @@ def _get_condition_label(condition):
]
)


def _get_operator_label(action, operator_type):
"""
Generate the GraphViz label for an action node as a table.
Expand Down Expand Up @@ -311,4 +322,4 @@ def _get_operator_label(action, operator_type):
f'<TR><TD ALIGN="LEFT" BALIGN="LEFT"><B>Confidence</B></TD><TD ALIGN="LEFT" BALIGN="LEFT">{confidence}</TD></TR>',
"</TABLE>>",
]
)
)
5 changes: 3 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ def test_doc_schema(schema_mock, generate_mock, insert_mock, exit_mock):


@patch("sys.exit")
@patch("attack_flow.graphviz.convert")
@patch("attack_flow.graphviz.convert_attack_flow")
@patch("attack_flow.model.load_attack_flow_bundle")
def test_graphviz(load_mock, convert_mock, exit_mock):
"""
Test that the script parses a JSON file and passes the resulting object
to convert().
to convert_attack_flow().
"""
convert_mock.return_value = dedent(
r"""\
Expand All @@ -111,6 +111,7 @@ def test_graphviz(load_mock, convert_mock, exit_mock):
)
bundle = stix2.Bundle()
load_mock.return_value = bundle
print("printing resp bundle ", bundle)
with NamedTemporaryFile() as flow, NamedTemporaryFile() as graphviz:
sys.argv = ["af", "graphviz", flow.name, graphviz.name]
runpy.run_module("attack_flow.cli", run_name="__main__")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_graphviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


def test_convert_attack_flow_to_graphviz():
output = attack_flow.graphviz.convert(get_flow_bundle())
output = attack_flow.graphviz.convert_attack_flow(get_flow_bundle())
assert output == dedent(
"""\
digraph {
Expand Down

0 comments on commit 2f32e5d

Please sign in to comment.