Skip to content

Commit

Permalink
load thread model from a json file
Browse files Browse the repository at this point in the history
  • Loading branch information
Jan Waś committed Sep 11, 2020
1 parent 2d96688 commit fa510ca
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pytm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
__all__ = ['Element', 'Server', 'ExternalEntity', 'Datastore', 'Actor', 'Process', 'SetOfProcesses', 'Dataflow', 'Boundary', 'TM', 'Action', 'Lambda', 'Threat']
__all__ = ['Element', 'Server', 'ExternalEntity', 'Datastore', 'Actor', 'Process', 'SetOfProcesses', 'Dataflow', 'Boundary', 'TM', 'Action', 'Lambda', 'Threat', 'load', 'loads']

from .pytm import Element, Server, ExternalEntity, Dataflow, Datastore, Actor, Process, SetOfProcesses, Boundary, TM, Action, Lambda, Threat
from .json import load, loads

104 changes: 104 additions & 0 deletions pytm/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import json
import sys

from .pytm import (
TM,
Boundary,
Element,
Dataflow,
Server,
ExternalEntity,
Datastore,
Actor,
Process,
SetOfProcesses,
Action,
Lambda,
)


def loads(s):
result = json.loads(s, object_hook=decode)
if not isinstance(result, TM):
raise ValueError("Failed to decode JSON input as TM")
return result


def load(fp):
result = json.load(fp, object_hook=decode)
if not isinstance(result, TM):
raise ValueError("Failed to decode JSON input as TM")
return result


def decode(data):
if "elements" not in data and "flows" not in data and "boundaries" not in data:
return data

boundaries = decode_boundaries(data.pop("boundaries", []))
elements = decode_elements(data.pop("elements", []), boundaries)
decode_flows(data.pop("flows", []), elements)

if "name" not in data:
raise ValueError("name property missing for threat model")
if "onDuplicates" in data:
data["onDuplicates"] = Action(data["onDuplicates"])
return TM(data.pop("name"), **data)


def decode_boundaries(flat):
boundaries = {}
refs = {}
for i, e in enumerate(flat):
name = e.pop("name", None)
if name is None:
raise ValueError(f"name property missing in boundary {i}")
if "inBoundary" in e:
refs[name] = e.pop("inBoundary")
e = Boundary(name, **e)
boundaries[name] = e

# do a second pass to resolve self-references
for b in boundaries.values():
if b.name not in refs:
continue
b.inBoundary = boundaries[refs[b.name]]

return boundaries


def decode_elements(flat, boundaries):
elements = {}
for i, e in enumerate(flat):
klass = getattr(sys.modules[__name__], e.pop("__class__", "Asset"))
name = e.pop("name", None)
if name is None:
raise ValueError(f"name property missing in element {i}")
if "inBoundary" in e:
if e["inBoundary"] not in boundaries:
raise ValueError(
f"element {name} references invalid boundary {e['inBoundary']}"
)
e["inBoundary"] = boundaries[e["inBoundary"]]
e = klass(name, **e)
elements[name] = e

return elements


def decode_flows(flat, elements):
for i, e in enumerate(flat):
name = e.pop("name", None)
if name is None:
raise ValueError(f"name property missing in dataflow {i}")
if "source" not in e:
raise ValueError(f"dataflow {name} is missing source property")
if e["source"] not in elements:
raise ValueError(f"dataflow {name} references invalid source {e['source']}")
source = elements[e.pop("source")]
if "sink" not in e:
raise ValueError(f"dataflow {name} is missing sink property")
if e["sink"] not in elements:
raise ValueError(f"dataflow {name} references invalid sink {e['sink']}")
sink = elements[e.pop("sink")]
Dataflow(source, sink, name, **e)
54 changes: 54 additions & 0 deletions tests/input.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{
"name": "my test tm",
"description": "aaa",
"isOrdered": true,
"onDuplicates": "IGNORE",
"boundaries": [
{
"name": "Internet"
},
{
"name": "Server/DB"
}
],
"elements": [
{
"__class__": "Actor",
"name": "User",
"inBoundary": "Internet"
},
{
"__class__": "Server",
"name": "Web Server"
},
{
"__class__": "Datastore",
"name": "SQL Database",
"inBoundary": "Server/DB"
}
],
"flows": [
{
"name": "Request",
"source": "User",
"sink": "Web Server",
"note": "bbb"
},
{
"name": "Insert",
"source": "Web Server",
"sink": "SQL Database",
"note": "ccc"
},
{
"name": "Select",
"source": "SQL Database",
"sink": "Web Server"
},
{
"name": "Response",
"source": "Web Server",
"sink": "User"
}
]
}
31 changes: 31 additions & 0 deletions tests/test_pytmfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Process,
Server,
Threat,
loads
)
from pytm.pytm import to_serializable

Expand Down Expand Up @@ -222,6 +223,36 @@ def test_json_dumps(self):
self.maxDiff = None
self.assertEqual(output, expected)

def test_json_loads(self):
random.seed(0)
dir_path = os.path.dirname(os.path.realpath(__file__))
with open(os.path.join(dir_path, 'input.json')) as x:
contents = x.read().strip()

TM.reset()
tm = loads(contents)
self.assertTrue(tm.check())

self.maxDiff = None
self.assertEqual([b.name for b in tm._boundaries], ["Internet", "Server/DB"])
self.assertEqual(
[e.name for e in tm._elements],
[
"Internet",
"Server/DB",
"User",
"Web Server",
"SQL Database",
"Request",
"Insert",
"Select",
"Response",
],
)
self.assertEqual(
[f.name for f in tm._flows], ["Request", "Insert", "Select", "Response"]
)


class Testpytm(unittest.TestCase):
# Test for all the threats in threats.py - test Threat.apply() function
Expand Down

0 comments on commit fa510ca

Please sign in to comment.