diff --git a/.gitignore b/.gitignore index 978a840..d3e4efe 100644 --- a/.gitignore +++ b/.gitignore @@ -124,3 +124,4 @@ tm_example.dot #Others plantuml.jar tm/ +/sqldump diff --git a/CHANGELOG.md b/CHANGELOG.md index 5739ddd..6773f86 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ ## New features +- Add JSON output [#102](https://github.com/izar/pytm/pull/102) - Use numbered dataflow labels in sequence diagram [#94](https://github.com/izar/pytm/pull/94) - Move authenticateDestination to base Element [#88](https://github.com/izar/pytm/pull/88) - Assign inputs and outputs to all elements [#89](https://github.com/izar/pytm/pull/89) diff --git a/pytm/__init__.py b/pytm/__init__.py index 03edb1f..6224c5c 100644 --- a/pytm/__init__.py +++ b/pytm/__init__.py @@ -1,5 +1,4 @@ -__all__ = ['Element', 'Server', 'ExternalEntity', 'Datastore', 'Actor', 'Process', 'SetOfProcesses', 'Dataflow', 'Boundary', 'TM', 'Action', 'Lambda', 'Threat', 'Classification', 'Data'] +__all__ = ['Element', 'Server', 'ExternalEntity', 'Datastore', 'Actor', 'Process', 'SetOfProcesses', 'Dataflow', 'Boundary', 'TM', 'Action', 'Lambda', 'Threat', 'Classification', 'Data', 'load', 'loads'] from .pytm import Element, Server, ExternalEntity, Dataflow, Datastore, Actor, Process, SetOfProcesses, Boundary, TM, Action, Lambda, Threat, Classification, Data - - +from .json import load, loads diff --git a/pytm/json.py b/pytm/json.py new file mode 100644 index 0000000..58d9c9f --- /dev/null +++ b/pytm/json.py @@ -0,0 +1,104 @@ +import json +import sys + +from .pytm import ( + TM, + Boundary, + Element, + Dataflow, + Server, + ExternalEntity, + Datastore, + Actor, + Process, + SetOfProcesses, + Action, + Lambda, +) + + +def loads(s): + result = json.loads(s, object_hook=decode) + if not isinstance(result, TM): + raise ValueError("Failed to decode JSON input as TM") + return result + + +def load(fp): + result = json.load(fp, object_hook=decode) + if not isinstance(result, TM): + raise ValueError("Failed to decode JSON input as TM") + return result + + +def decode(data): + if "elements" not in data and "flows" not in data and "boundaries" not in data: + return data + + boundaries = decode_boundaries(data.pop("boundaries", [])) + elements = decode_elements(data.pop("elements", []), boundaries) + decode_flows(data.pop("flows", []), elements) + + if "name" not in data: + raise ValueError("name property missing for threat model") + if "onDuplicates" in data: + data["onDuplicates"] = Action(data["onDuplicates"]) + return TM(data.pop("name"), **data) + + +def decode_boundaries(flat): + boundaries = {} + refs = {} + for i, e in enumerate(flat): + name = e.pop("name", None) + if name is None: + raise ValueError(f"name property missing in boundary {i}") + if "inBoundary" in e: + refs[name] = e.pop("inBoundary") + e = Boundary(name, **e) + boundaries[name] = e + + # do a second pass to resolve self-references + for b in boundaries.values(): + if b.name not in refs: + continue + b.inBoundary = boundaries[refs[b.name]] + + return boundaries + + +def decode_elements(flat, boundaries): + elements = {} + for i, e in enumerate(flat): + klass = getattr(sys.modules[__name__], e.pop("__class__", "Asset")) + name = e.pop("name", None) + if name is None: + raise ValueError(f"name property missing in element {i}") + if "inBoundary" in e: + if e["inBoundary"] not in boundaries: + raise ValueError( + f"element {name} references invalid boundary {e['inBoundary']}" + ) + e["inBoundary"] = boundaries[e["inBoundary"]] + e = klass(name, **e) + elements[name] = e + + return elements + + +def decode_flows(flat, elements): + for i, e in enumerate(flat): + name = e.pop("name", None) + if name is None: + raise ValueError(f"name property missing in dataflow {i}") + if "source" not in e: + raise ValueError(f"dataflow {name} is missing source property") + if e["source"] not in elements: + raise ValueError(f"dataflow {name} references invalid source {e['source']}") + source = elements[e.pop("source")] + if "sink" not in e: + raise ValueError(f"dataflow {name} is missing sink property") + if e["sink"] not in elements: + raise ValueError(f"dataflow {name} references invalid sink {e['sink']}") + sink = elements[e.pop("sink")] + Dataflow(source, sink, name, **e) diff --git a/pytm/pytm.py b/pytm/pytm.py index a8dabd1..c02531d 100644 --- a/pytm/pytm.py +++ b/pytm/pytm.py @@ -1,5 +1,5 @@ - import argparse +import errno import inspect import json import logging @@ -7,19 +7,17 @@ import random import sys import uuid -import errno from collections import defaultdict from collections.abc import Iterable from enum import Enum +from functools import singledispatch, lru_cache from hashlib import sha224 from itertools import combinations -from os.path import dirname -from os import mkdir +from shutil import rmtree from textwrap import indent, wrap - from weakref import WeakKeyDictionary + from pydal import DAL, Field -from shutil import rmtree from .template_engine import SuperFormatter @@ -484,15 +482,16 @@ def __init__( def __str__(self): return f"{self.target}: {self.description}\n{self.details}\n{self.severity}" + class TM(): """Describes the threat model administratively, and holds all details during a run""" - _BagOfFlows = [] - _BagOfElements = [] - _BagOfThreats = [] - _BagOfBoundaries = [] - _BagOfData = [] + _flows = [] + _elements = [] + _threats = [] + _boundaries = [] + _data = [] _threatsExcluded = [] _sf = None _duplicate_ignored_attrs = "name", "note", "order", "response", "responseTo" @@ -519,14 +518,14 @@ def __init__(self, name, **kwargs): @classmethod def reset(cls): - cls._BagOfFlows = [] - cls._BagOfElements = [] - cls._BagOfThreats = [] - cls._BagOfBoundaries = [] - cls._BagOfData = [] + cls._flows = [] + cls._elements = [] + cls._threats = [] + cls._boundaries = [] + cls._data = [] def _init_threats(self): - TM._BagOfThreats = [] + TM._threats = [] self._add_threats() def _add_threats(self): @@ -534,15 +533,15 @@ def _add_threats(self): threats_json = json.load(threat_file) for i in threats_json: - TM._BagOfThreats.append(Threat(**i)) + TM._threats.append(Threat(**i)) def resolve(self): findings = [] elements = defaultdict(list) - for e in TM._BagOfElements: + for e in TM._elements: if not e.inScope: continue - for t in TM._BagOfThreats: + for t in TM._threats: if not t.apply(e): continue f = Finding( @@ -565,20 +564,20 @@ def check(self): if self.description is None: raise ValueError("""Every threat model should have at least a brief description of the system being modeled.""") - TM._BagOfFlows = _match_responses(_sort(TM._BagOfFlows, self.isOrdered)) - self._check_duplicates(TM._BagOfFlows) - _apply_defaults(TM._BagOfFlows, TM._BagOfData) + TM._flows = _match_responses(_sort(TM._flows, self.isOrdered)) + self._check_duplicates(TM._flows) + _apply_defaults(TM._flows, TM._data) if self.ignoreUnused: - TM._BagOfElements, TM._BagOfBoundaries = _get_elements_and_boundaries( - TM._BagOfFlows + TM._elements, TM._boundaries = _get_elements_and_boundaries( + TM._flows ) result = True - for e in (TM._BagOfElements): + for e in (TM._elements): if not e.check(): result = False if self.ignoreUnused: # cannot rely on user defined order if assets are re-used in multiple models - TM._BagOfElements = _sort_elem(TM._BagOfElements) + TM._elements = _sort_elem(TM._elements) return result def _check_duplicates(self, flows): @@ -633,13 +632,13 @@ def _dfd_template(self): def dfd(self): edges = [] - for b in TM._BagOfBoundaries: + for b in TM._boundaries: edges.append(b.dfd()) if self.mergeResponses: - for e in TM._BagOfFlows: + for e in TM._flows: if e.response is not None: e.response._is_drawn = True - for e in TM._BagOfElements: + for e in TM._elements: if not e._is_drawn and not isinstance(e, Boundary) and e.inBoundary is None: edges.append(e.dfd(mergeResponses=self.mergeResponses)) @@ -654,7 +653,7 @@ def _seq_template(self): def seq(self): participants = [] - for e in TM._BagOfElements: + for e in TM._elements: if isinstance(e, Actor): participants.append( 'actor {0} as "{1}"'.format(e._uniq_name(), e.display_name()) @@ -669,7 +668,7 @@ def seq(self): ) messages = [] - for e in TM._BagOfFlows: + for e in TM._flows: message = "{0} -> {1}: {2}".format( e.source._uniq_name(), e.sink._uniq_name(), e.display_name() ) @@ -690,12 +689,12 @@ def report(self, *args, **kwargs): data = { "tm": self, - "dataflows": self._BagOfFlows, - "threats": self._BagOfThreats, + "dataflows": TM._flows, + "threats": TM._threats, "findings": self.findings, - "elements": self._BagOfElements, - "boundaries": self._BagOfBoundaries, - "data": self._BagOfData, + "elements": TM._elements, + "boundaries": TM._boundaries, + "data": TM._data, } return self._sf.format(template, **data) @@ -705,72 +704,90 @@ def process(self): logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") if result.debug: logger.setLevel(logging.DEBUG) + if result.seq is True: print(self.seq()) + if result.dfd is True: print(self.dfd()) + + if ( + result.report is not None + or result.json is not None + or result.sqldump is not None + ): + self.resolve() + if result.sqldump is not None: self.sqlDump(result.sqldump) + + if result.json: + with open(result.json, "w", encoding="utf8") as f: + json.dump(self, f, default=to_serializable) + if result.report is not None: - self.resolve() print(self.report()) + if result.exclude is not None: TM._threatsExcluded = result.exclude.split(",") + if result.describe is not None: _describe_classes(result.describe.split()) + if result.list is True: - [print("{} - {}".format(t.id, t.description)) for t in TM._BagOfThreats] - - def _dumpElement(self, db, table, e, f): - - args = {} - - # status 09/01 - Findings need to be dumped to a separate table - logger.debug("Dumping " + str(e)) - for fieldname in f: - if fieldname == "findings": - # dump findings in Findings table - for finding in getattr(e, fieldname): - finding_args = {} - for field in [x for x in dir(finding) if not x.startswith('_') and x != 'id' and not callable(getattr(finding, x))]: - finding_args[field] = getattr(finding, field ) - db["Finding"].bulk_insert([finding_args]) - continue - else: - args[fieldname] = getattr(e, fieldname) - db[table].bulk_insert([args]) + [print("{} - {}".format(t.id, t.description)) for t in TM._threats] def sqlDump(self, filename): - fields = {} - table = {} - try: rmtree('./sqldump') - mkdir('./sqldump') + os.mkdir('./sqldump') except OSError as e: if e.errno != errno.ENOENT: raise - else: - mkdir('./sqldump') - - db = DAL('sqlite://' + filename, folder='sqldump') - - # fill everything up - self.resolve() + else: + os.mkdir('./sqldump') - # create tables - for e in Server, ExternalEntity, Dataflow, Datastore, Actor, Process, SetOfProcesses, Boundary, TM, Lambda, Threat, Finding: - # id is internal and a reserved field name - fields[e.__name__] = [x for x in dir(e) if not x.startswith('_') and x != 'id' and not callable(getattr(e, x))] - logger.debug("Creating table " + e.__name__) - table[e.__name__] = db.define_table(e.__name__, [Field(x) for x in fields[e.__name__]]) + db = DAL('sqlite://' + filename, folder='sqldump') - for el in TM._BagOfElements: - self._dumpElement(db, table[el.__class__.__name__], el, fields[el.__class__.__name__]) + for klass in ( + Server, + ExternalEntity, + Dataflow, + Datastore, + Actor, + Process, + SetOfProcesses, + Boundary, + TM, + Threat, + Lambda, + Data, + Finding, + ): + self.get_table(db, klass) + + for e in TM._threats + TM._data + TM._elements + self.findings + [self]: + table = self.get_table(db, e.__class__) + row = {} + for k, v in serialize(e).items(): + if k == "id": + k = "SID" + row[k] = ", ".join(v) if isinstance(v, list) else v + db[table].bulk_insert([row]) - # close database db.close() + @lru_cache + def get_table(self, db, klass): + name = klass.__name__ + fields = [ + Field("SID" if i == "id" else i) + for i in dir(klass) + if not i.startswith("_") + and not callable(getattr(klass, i)) + ] + return db.define_table(name, fields) + class Element(): """A generic element""" @@ -779,25 +796,11 @@ class Element(): description = varString("") inBoundary = varBoundary(None, doc="Trust boundary this element exists in") inScope = varBool(True, doc="Is the element in scope of the threat model") - onAWS = varBool(False) - isHardened = varBool(False) maxClassification = varClassification( Classification.UNKNOWN, required=False, doc="Maximum data classification this element can handle.", ) - implementsAuthenticationScheme = varBool(False) - implementsNonce = varBool(False, doc="""Nonce is an arbitrary number -that can be used just once in a cryptographic communication. -It is often a random or pseudo-random number issued in an authentication protocol -to ensure that old communications cannot be reused in replay attacks. -They can also be useful as initialization vectors and in cryptographic -hash functions.""") - handlesResources = varBool(False) - definesConnectionTimeout = varBool(False) - authenticatesDestination = varBool(False) - OS = varString("") - isAdmin = varBool(False) findings = varFindings([]) def __init__(self, name, **kwargs): @@ -806,7 +809,7 @@ def __init__(self, name, **kwargs): self.name = name self.uuid = uuid.UUID(int=random.getrandbits(128)) self._is_drawn = False - TM._BagOfElements.append(self) + TM._elements.append(self) def __repr__(self): return "<{0}.{1}({2}) at {3}>".format( @@ -952,7 +955,7 @@ def __init__(self, name, **kwargs): for key, value in kwargs.items(): setattr(self, key, value) self.name = name - TM._BagOfData.append(self) + TM._data.append(self) def __repr__(self): return "<{0}.{1}({2}) at {3}>".format( @@ -963,27 +966,45 @@ def __str__(self): return "{0}({1})".format(type(self).__name__, self.name) -class Lambda(Element): - """A lambda function running in a Function-as-a-Service (FaaS) environment""" - - port = varInt(-1, doc="Default TCP port for outgoing data flows") - protocol = varString("", doc="Default network protocol for outgoing data flows") - data = varData([], doc="Default type of data in outgoing data flows") - onAWS = varBool(True) +class Asset(Element): + """An asset with outgoing or incoming dataflows""" + port = varInt(-1, doc="Default TCP port for incoming data flows") + isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted") + protocol = varString("", doc="Default network protocol for incoming data flows") + data = varData([], doc="Default type of data in incoming data flows") + inputs = varElements([], doc="incoming Dataflows") + outputs = varElements([], doc="outgoing Dataflows") + onAWS = varBool(False) + isHardened = varBool(False) + implementsAuthenticationScheme = varBool(False) + implementsNonce = varBool(False, doc="""Nonce is an arbitrary number +that can be used just once in a cryptographic communication. +It is often a random or pseudo-random number issued in an authentication protocol +to ensure that old communications cannot be reused in replay attacks. +They can also be useful as initialization vectors and in cryptographic +hash functions.""") + handlesResources = varBool(False) + definesConnectionTimeout = varBool(False) + authenticatesDestination = varBool(False) authenticatesSource = varBool(False) + authorizesSource = varBool(False) hasAccessControl = varBool(False) + validatesInput = varBool(False) sanitizesInput = varBool(False) + checksInputBounds = varBool(False) encodesOutput = varBool(False) handlesResourceConsumption = varBool(False) authenticationScheme = varString("") usesEnvironmentVariables = varBool(False) - validatesInput = varBool(False) - checksInputBounds = varBool(False) + OS = varString("") + + +class Lambda(Asset): + """A lambda function running in a Function-as-a-Service (FaaS) environment""" + + onAWS = varBool(True) environment = varString("") implementsAPI = varBool(False) - authorizesSource = varBool(False) - inputs = varElements([], doc="incoming Dataflows") - outputs = varElements([], doc="outgoing Dataflows") def __init__(self, name, **kwargs): super().__init__(name, **kwargs) @@ -1018,33 +1039,19 @@ def _shape(self): return "none" -class Server(Element): +class Server(Asset): """An entity processing data""" - port = varInt(-1, doc="Default TCP port for incoming data flows") - isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted") - protocol = varString("", doc="Default network protocol for incoming data flows") - data = varData([], doc="Default type of data in incoming data flows") - inputs = varElements([], doc="incoming Dataflows") - outputs = varElements([], doc="outgoing Dataflows") providesConfidentiality = varBool(False) providesIntegrity = varBool(False) - authenticatesSource = varBool(False) - sanitizesInput = varBool(False) - encodesOutput = varBool(False) - hasAccessControl = varBool(False) - implementsCSRFToken = varBool(False) - handlesResourceConsumption = varBool(False) - isResilient = varBool(False) - authenticationScheme = varString("") - validatesInput = varBool(False) validatesHeaders = varBool(False) encodesHeaders = varBool(False) + implementsCSRFToken = varBool(False) + isResilient = varBool(False) usesSessionTokens = varBool(False) usesEncryptionAlgorithm = varString("") usesCache = varBool(False) usesVPN = varBool(False) - authorizesSource = varBool(False) usesCodeSigning = varBool(False) validatesContentType = varBool(False) invokesScriptFilters = varBool(False) @@ -1053,7 +1060,6 @@ class Server(Element): implementsServerSideValidation = varBool(False) usesXMLParser = varBool(False) disablesDTD = varBool(False) - checksInputBounds = varBool(False) implementsStrictHTTPValidation = varBool(False) implementsPOLP = varBool(False, doc="""The principle of least privilege (PoLP), also known as the principle of minimal privilege or the principle of least authority, @@ -1069,22 +1075,16 @@ def _shape(self): return "circle" -class ExternalEntity(Element): +class ExternalEntity(Asset): hasPhysicalAccess = varBool(False) def __init__(self, name, **kwargs): super().__init__(name, **kwargs) -class Datastore(Element): +class Datastore(Asset): """An entity storing data""" - port = varInt(-1, doc="Default TCP port for incoming data flows") - isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted") - protocol = varString("", doc="Default network protocol for incoming data flows") - data = varData([], doc="Default type of data in incoming data flows") - inputs = varElements([], doc="incoming Dataflows") - outputs = varElements([], doc="outgoing Dataflows") onRDS = varBool(False) storesLogData = varBool(False) storesPII = varBool(False, doc="""Personally Identifiable Information @@ -1093,17 +1093,12 @@ class Datastore(Element): isSQL = varBool(True) providesConfidentiality = varBool(False) providesIntegrity = varBool(False) - authenticatesSource = varBool(False) isShared = varBool(False) hasWriteAccess = varBool(False) handlesResourceConsumption = varBool(False) isResilient = varBool(False) handlesInterruptions = varBool(False) - authorizesSource = varBool(False) - hasAccessControl = varBool(False) - authenticationScheme = varString("") usesEncryptionAlgorithm = varString("") - validatesInput = varBool(False) implementsPOLP = varBool(False, doc="""The principle of least privilege (PoLP), also known as the principle of minimal privilege or the principle of least authority, requires that in a particular abstraction layer of a computing environment, @@ -1139,41 +1134,29 @@ class Actor(Element): data = varData([], doc="Default type of data in outgoing data flows") inputs = varElements([], doc="incoming Dataflows") outputs = varElements([], doc="outgoing Dataflows") + authenticatesDestination = varBool(False) + isAdmin = varBool(False) def __init__(self, name, **kwargs): super().__init__(name, **kwargs) -class Process(Element): +class Process(Asset): """An entity processing data""" - port = varInt(-1, doc="Default TCP port for incoming data flows") - isEncrypted = varBool(False, doc="Requires incoming data flow to be encrypted") - protocol = varString("", doc="Default network protocol for incoming data flows") - data = varData([], doc="Default type of data in incoming data flows") - inputs = varElements([], doc="incoming Dataflows") - outputs = varElements([], doc="outgoing Dataflows") codeType = varString("Unmanaged") implementsCommunicationProtocol = varBool(False) providesConfidentiality = varBool(False) providesIntegrity = varBool(False) - authenticatesSource = varBool(False) isResilient = varBool(False) - hasAccessControl = varBool(False) tracksExecutionFlow = varBool(False) implementsCSRFToken = varBool(False) handlesResourceConsumption = varBool(False) handlesCrashes = varBool(False) handlesInterruptions = varBool(False) - authorizesSource = varBool(False) - authenticationScheme = varString("") - checksInputBounds = varBool(False) - validatesInput = varBool(False) - sanitizesInput = varBool(False) implementsAPI = varBool(False) usesSecureFunctions = varBool(False) environment = varString("") - usesEnvironmentVariables = varBool(False) disablesiFrames = varBool(False) implementsPOLP = varBool(False, doc="""The principle of least privilege (PoLP), also known as the principle of minimal privilege or the principle of least authority, @@ -1181,7 +1164,6 @@ class Process(Element): every module (such as a process, a user, or a program, depending on the subject) must be able to access only the information and resources that are necessary for its legitimate purpose.""") - encodesOutput = varBool(False) usesParameterizedInput = varBool(False) allowsClientSideScripting = varBool(False) usesStrongSessionIdentifiers = varBool(False) @@ -1222,9 +1204,11 @@ class Dataflow(Element): dstPort = varInt(-1, doc="Destination TCP port") isEncrypted = varBool(False, doc="Is the data encrypted") protocol = varString("", doc="Protocol used in this data flow") - data = varData([], "Type of data carried in this data flow") + data = varData([], doc="Default type of data in incoming data flows") + authenticatesDestination = varBool(False) authenticatedWith = varBool(False) order = varInt(-1, doc="Number of this data flow in the threat model") + implementsAuthenticationScheme = varBool(False) implementsCommunicationProtocol = varBool(False) note = varString("") usesVPN = varBool(False) @@ -1236,7 +1220,7 @@ def __init__(self, source, sink, name, **kwargs): self.source = source self.sink = sink super().__init__(name, **kwargs) - TM._BagOfFlows.append(self) + TM._flows.append(self) def display_name(self): if self.order == -1: @@ -1286,8 +1270,8 @@ class Boundary(Element): def __init__(self, name, **kwargs): super().__init__(name, **kwargs) - if name not in TM._BagOfBoundaries: - TM._BagOfBoundaries.append(self) + if name not in TM._boundaries: + TM._boundaries.append(self) def _dfd_template(self): return """subgraph cluster_{uniq_name} {{ @@ -1310,7 +1294,7 @@ def dfd(self): self._is_drawn = True logger.debug("Now drawing boundary " + self.name) edges = [] - for e in TM._BagOfElements: + for e in TM._elements: if e.inBoundary != self or e._is_drawn: continue # The content to draw can include Boundary objects @@ -1327,9 +1311,71 @@ def _color(self): return "firebrick2" +@singledispatch +def to_serializable(val): + """Used by default.""" + return str(val) + + +@to_serializable.register(TM) +def ts_tm(obj): + return serialize(obj, nested=True) + + +@to_serializable.register(Data) +@to_serializable.register(Threat) +@to_serializable.register(Element) +@to_serializable.register(Finding) +def ts_element(obj): + return serialize(obj, nested=False) + + +def serialize(obj, nested=False): + """Used if *obj* is an instance of TM, Element, Threat or Finding.""" + klass = obj.__class__ + result = {} + if isinstance(obj, (Actor, Asset)): + result["__class__"] = klass.__name__ + for i in dir(obj): + if ( + i.startswith("__") + or callable(getattr(klass, i, {})) + or ( + isinstance(obj, TM) + and i in ("_sf", "_duplicate_ignored_attrs", "_threats") + ) + or ( + isinstance(obj, Element) + and i in ("_is_drawn", "uuid") + ) + or (isinstance(obj, Finding) and i == "element") + ): + continue + value = getattr(obj, i) + if isinstance(obj, TM) and i == "_elements": + value = [e for e in value if isinstance(e, (Actor, Asset))] + if value is not None: + if isinstance(value, (Element, Data)): + value = value.name + elif isinstance(obj, Threat) and i == "target": + value = [v.__name__ for v in value] + elif ( + not nested + and not isinstance(value, str) + and isinstance(value, Iterable) + ): + value = [v.id if isinstance(v, Finding) else v.name for v in value] + result[i.lstrip("_")] = value + return result + + def get_args(): _parser = argparse.ArgumentParser() - _parser.add_argument('--sqldump', help='dumps all threat model elements and findings into the named sqlite file (erased if exists)') + _parser.add_argument( + "--sqldump", + help="""dumps all threat model elements and findings +into the named sqlite file (erased if exists)""", + ) _parser.add_argument("--debug", action="store_true", help="print debug messages") _parser.add_argument("--dfd", action="store_true", help="output DFD") _parser.add_argument( @@ -1345,6 +1391,7 @@ def get_args(): _parser.add_argument( "--describe", help="describe the properties available for a given element" ) + _parser.add_argument('--json', help='output a JSON file') _args = _parser.parse_args() return _args diff --git a/tests/input.json b/tests/input.json new file mode 100644 index 0000000..0cdcff8 --- /dev/null +++ b/tests/input.json @@ -0,0 +1,54 @@ +{ + "name": "my test tm", + "description": "aaa", + "isOrdered": true, + "onDuplicates": "IGNORE", + "boundaries": [ + { + "name": "Internet" + }, + { + "name": "Server/DB" + } + ], + "elements": [ + { + "__class__": "Actor", + "name": "User", + "inBoundary": "Internet" + }, + { + "__class__": "Server", + "name": "Web Server" + }, + { + "__class__": "Datastore", + "name": "SQL Database", + "inBoundary": "Server/DB" + } + ], + "flows": [ + { + "name": "Request", + "source": "User", + "sink": "Web Server", + "note": "bbb" + }, + { + "name": "Insert", + "source": "Web Server", + "sink": "SQL Database", + "note": "ccc" + }, + { + "name": "Select", + "source": "SQL Database", + "sink": "Web Server" + }, + { + "name": "Response", + "source": "Web Server", + "sink": "User" + } + ] +} diff --git a/tests/output.json b/tests/output.json new file mode 100644 index 0000000..f3e347f --- /dev/null +++ b/tests/output.json @@ -0,0 +1,426 @@ +{ + "boundaries": [ + { + "description": "", + "findings": [], + "inBoundary": null, + "inScope": true, + "maxClassification": "Classification.UNKNOWN", + "name": "Internet" + }, + { + "description": "", + "findings": [], + "inBoundary": null, + "inScope": true, + "maxClassification": "Classification.UNKNOWN", + "name": "Server/DB" + } + ], + "data": [], + "description": "aaa", + "elements": [ + { + "__class__": "Actor", + "authenticatesDestination": false, + "data": [], + "description": "", + "findings": [], + "inBoundary": "Internet", + "inScope": true, + "inputs": [ + "Show comments (*)" + ], + "isAdmin": false, + "maxClassification": "Classification.UNKNOWN", + "name": "User", + "outputs": [ + "User enters comments (*)" + ], + "port": -1, + "protocol": "" + }, + { + "OS": "", + "__class__": "Server", + "authenticatesDestination": false, + "authenticatesSource": false, + "authenticationScheme": "", + "authorizesSource": false, + "checksInputBounds": false, + "data": [], + "definesConnectionTimeout": false, + "description": "", + "disablesDTD": false, + "encodesHeaders": false, + "encodesOutput": false, + "findings": [], + "handlesResourceConsumption": false, + "handlesResources": false, + "hasAccessControl": false, + "implementsAuthenticationScheme": false, + "implementsCSRFToken": false, + "implementsNonce": false, + "implementsPOLP": false, + "implementsServerSideValidation": false, + "implementsStrictHTTPValidation": false, + "inBoundary": null, + "inScope": true, + "inputs": [ + "User enters comments (*)", + "Retrieve comments" + ], + "invokesScriptFilters": false, + "isEncrypted": false, + "isHardened": false, + "isResilient": false, + "maxClassification": "Classification.UNKNOWN", + "name": "Web Server", + "onAWS": false, + "outputs": [ + "Insert query with comments", + "Call func", + "Show comments (*)" + ], + "port": -1, + "protocol": "", + "providesConfidentiality": false, + "providesIntegrity": false, + "sanitizesInput": false, + "usesCache": false, + "usesCodeSigning": false, + "usesEncryptionAlgorithm": "", + "usesEnvironmentVariables": false, + "usesLatestTLSversion": false, + "usesSessionTokens": false, + "usesStrongSessionIdentifiers": false, + "usesVPN": false, + "usesXMLParser": false, + "validatesContentType": false, + "validatesHeaders": false, + "validatesInput": false + }, + { + "OS": "", + "__class__": "Lambda", + "authenticatesDestination": false, + "authenticatesSource": false, + "authenticationScheme": "", + "authorizesSource": false, + "checksInputBounds": false, + "data": [], + "definesConnectionTimeout": false, + "description": "", + "encodesOutput": false, + "environment": "", + "findings": [], + "handlesResourceConsumption": false, + "handlesResources": false, + "hasAccessControl": false, + "implementsAPI": false, + "implementsAuthenticationScheme": false, + "implementsNonce": false, + "inBoundary": null, + "inScope": true, + "inputs": [ + "Call func" + ], + "isEncrypted": false, + "isHardened": false, + "maxClassification": "Classification.UNKNOWN", + "name": "Lambda func", + "onAWS": true, + "outputs": [], + "port": -1, + "protocol": "", + "sanitizesInput": false, + "usesEnvironmentVariables": false, + "validatesInput": false + }, + { + "OS": "", + "__class__": "Process", + "allowsClientSideScripting": false, + "authenticatesDestination": false, + "authenticatesSource": false, + "authenticationScheme": "", + "authorizesSource": false, + "checksInputBounds": false, + "codeType": "Unmanaged", + "data": [], + "definesConnectionTimeout": false, + "description": "", + "disablesiFrames": false, + "encodesOutput": false, + "encryptsCookies": false, + "encryptsSessionData": false, + "environment": "", + "findings": [], + "handlesCrashes": false, + "handlesInterruptions": false, + "handlesResourceConsumption": false, + "handlesResources": false, + "hasAccessControl": false, + "implementsAPI": false, + "implementsAuthenticationScheme": false, + "implementsCSRFToken": false, + "implementsCommunicationProtocol": false, + "implementsNonce": false, + "implementsPOLP": false, + "inBoundary": null, + "inScope": true, + "inputs": [], + "isEncrypted": false, + "isHardened": false, + "isResilient": false, + "maxClassification": "Classification.UNKNOWN", + "name": "Task queue worker", + "onAWS": false, + "outputs": [ + "Query for tasks" + ], + "port": -1, + "protocol": "", + "providesConfidentiality": false, + "providesIntegrity": false, + "sanitizesInput": false, + "tracksExecutionFlow": false, + "usesEnvironmentVariables": false, + "usesMFA": false, + "usesParameterizedInput": false, + "usesSecureFunctions": false, + "usesStrongSessionIdentifiers": false, + "validatesInput": false, + "verifySessionIdentifiers": false + }, + { + "OS": "", + "__class__": "Datastore", + "authenticatesDestination": false, + "authenticatesSource": false, + "authenticationScheme": "", + "authorizesSource": false, + "checksInputBounds": false, + "data": [], + "definesConnectionTimeout": false, + "description": "", + "encodesOutput": false, + "findings": [], + "handlesInterruptions": false, + "handlesResourceConsumption": false, + "handlesResources": false, + "hasAccessControl": false, + "hasWriteAccess": false, + "implementsAuthenticationScheme": false, + "implementsNonce": false, + "implementsPOLP": false, + "inBoundary": "Server/DB", + "inScope": true, + "inputs": [ + "Insert query with comments", + "Query for tasks" + ], + "isEncrypted": false, + "isHardened": false, + "isResilient": false, + "isSQL": true, + "isShared": false, + "maxClassification": "Classification.UNKNOWN", + "name": "SQL Database", + "onAWS": false, + "onRDS": false, + "outputs": [ + "Retrieve comments" + ], + "port": -1, + "protocol": "", + "providesConfidentiality": false, + "providesIntegrity": false, + "sanitizesInput": false, + "storesLogData": false, + "storesPII": false, + "storesSensitiveData": false, + "usesEncryptionAlgorithm": "", + "usesEnvironmentVariables": false, + "validatesInput": false + } + ], + "findings": [], + "flows": [ + { + "authenticatedWith": false, + "authenticatesDestination": false, + "authorizesSource": false, + "data": [], + "description": "", + "dstPort": -1, + "findings": [], + "implementsAuthenticationScheme": false, + "implementsCommunicationProtocol": false, + "inBoundary": null, + "inScope": true, + "isEncrypted": false, + "isResponse": false, + "maxClassification": "Classification.UNKNOWN", + "name": "User enters comments (*)", + "note": "bbb", + "order": 1, + "protocol": "", + "response": null, + "responseTo": null, + "sink": "Web Server", + "source": "User", + "srcPort": -1, + "usesLatestTLSversion": false, + "usesSessionTokens": false, + "usesVPN": false + }, + { + "authenticatedWith": false, + "authenticatesDestination": false, + "authorizesSource": false, + "data": [], + "description": "", + "dstPort": -1, + "findings": [], + "implementsAuthenticationScheme": false, + "implementsCommunicationProtocol": false, + "inBoundary": null, + "inScope": true, + "isEncrypted": false, + "isResponse": false, + "maxClassification": "Classification.UNKNOWN", + "name": "Insert query with comments", + "note": "ccc", + "order": 2, + "protocol": "", + "response": null, + "responseTo": null, + "sink": "SQL Database", + "source": "Web Server", + "srcPort": -1, + "usesLatestTLSversion": false, + "usesSessionTokens": false, + "usesVPN": false + }, + { + "authenticatedWith": false, + "authenticatesDestination": false, + "authorizesSource": false, + "data": [], + "description": "", + "dstPort": -1, + "findings": [], + "implementsAuthenticationScheme": false, + "implementsCommunicationProtocol": false, + "inBoundary": null, + "inScope": true, + "isEncrypted": false, + "isResponse": false, + "maxClassification": "Classification.UNKNOWN", + "name": "Call func", + "note": "", + "order": 3, + "protocol": "", + "response": null, + "responseTo": null, + "sink": "Lambda func", + "source": "Web Server", + "srcPort": -1, + "usesLatestTLSversion": false, + "usesSessionTokens": false, + "usesVPN": false + }, + { + "authenticatedWith": false, + "authenticatesDestination": false, + "authorizesSource": false, + "data": [], + "description": "", + "dstPort": -1, + "findings": [], + "implementsAuthenticationScheme": false, + "implementsCommunicationProtocol": false, + "inBoundary": null, + "inScope": true, + "isEncrypted": false, + "isResponse": false, + "maxClassification": "Classification.UNKNOWN", + "name": "Retrieve comments", + "note": "", + "order": 4, + "protocol": "", + "response": null, + "responseTo": null, + "sink": "Web Server", + "source": "SQL Database", + "srcPort": -1, + "usesLatestTLSversion": false, + "usesSessionTokens": false, + "usesVPN": false + }, + { + "authenticatedWith": false, + "authenticatesDestination": false, + "authorizesSource": false, + "data": [], + "description": "", + "dstPort": -1, + "findings": [], + "implementsAuthenticationScheme": false, + "implementsCommunicationProtocol": false, + "inBoundary": null, + "inScope": true, + "isEncrypted": false, + "isResponse": false, + "maxClassification": "Classification.UNKNOWN", + "name": "Show comments (*)", + "note": "", + "order": 5, + "protocol": "", + "response": null, + "responseTo": null, + "sink": "User", + "source": "Web Server", + "srcPort": -1, + "usesLatestTLSversion": false, + "usesSessionTokens": false, + "usesVPN": false + }, + { + "authenticatedWith": false, + "authenticatesDestination": false, + "authorizesSource": false, + "data": [], + "description": "", + "dstPort": -1, + "findings": [], + "implementsAuthenticationScheme": false, + "implementsCommunicationProtocol": false, + "inBoundary": null, + "inScope": true, + "isEncrypted": false, + "isResponse": false, + "maxClassification": "Classification.UNKNOWN", + "name": "Query for tasks", + "note": "", + "order": 6, + "protocol": "", + "response": null, + "responseTo": null, + "sink": "SQL Database", + "source": "Task queue worker", + "srcPort": -1, + "usesLatestTLSversion": false, + "usesSessionTokens": false, + "usesVPN": false + } + ], + "ignoreUnused": false, + "isOrdered": true, + "mergeResponses": false, + "name": "my test tm", + "onDuplicates": "Action.NO_ACTION", + "threatsExcluded": [], + "threatsFile": "pytm/threatlib/threats.json" +} diff --git a/tests/test_private_func.py b/tests/test_private_func.py index 746cee2..bd0bb69 100644 --- a/tests/test_private_func.py +++ b/tests/test_private_func.py @@ -36,7 +36,7 @@ def test_kwargs(self): def test_load_threats(self): tm = TM("TM") - self.assertNotEqual(len(TM._BagOfThreats), 0) + self.assertNotEqual(len(TM._threats), 0) with self.assertRaises(FileNotFoundError): tm.threatsFile = "threats.json" diff --git a/tests/test_pytmfunc.py b/tests/test_pytmfunc.py index f81b7ed..9503a7c 100644 --- a/tests/test_pytmfunc.py +++ b/tests/test_pytmfunc.py @@ -16,7 +16,9 @@ Process, Server, Threat, + loads ) +from pytm.pytm import to_serializable with open( os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) @@ -176,24 +178,86 @@ def test_resolve(self): results = Dataflow(db, web, "Retrieve comments") resp = Dataflow(web, user, "Show comments (*)") - TM._BagOfThreats = [ + TM._threats = [ Threat(SID=klass, target=klass) for klass in ["Actor", "Server", "Datastore", "Dataflow"] ] tm.resolve() self.maxDiff = None - self.assertListEqual( + self.assertEqual( [f.id for f in tm.findings], ["Server", "Datastore", "Dataflow", "Dataflow", "Dataflow", "Dataflow"], ) - self.assertListEqual([f.id for f in user.findings], []) - self.assertListEqual([f.id for f in web.findings], ["Server"]) - self.assertListEqual([f.id for f in db.findings], ["Datastore"]) - self.assertListEqual([f.id for f in req.findings], ["Dataflow"]) - self.assertListEqual([f.id for f in query.findings], ["Dataflow"]) - self.assertListEqual([f.id for f in results.findings], ["Dataflow"]) - self.assertListEqual([f.id for f in resp.findings], ["Dataflow"]) + self.assertEqual([f.id for f in user.findings], []) + self.assertEqual([f.id for f in web.findings], ["Server"]) + self.assertEqual([f.id for f in db.findings], ["Datastore"]) + self.assertEqual([f.id for f in req.findings], ["Dataflow"]) + self.assertEqual([f.id for f in query.findings], ["Dataflow"]) + self.assertEqual([f.id for f in results.findings], ["Dataflow"]) + self.assertEqual([f.id for f in resp.findings], ["Dataflow"]) + + def test_json_dumps(self): + random.seed(0) + dir_path = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(dir_path, 'output.json')) as x: + expected = x.read().strip() + + TM.reset() + tm = TM( + "my test tm", description="aaa", threatsFile="pytm/threatlib/threats.json" + ) + tm.isOrdered = True + internet = Boundary("Internet") + server_db = Boundary("Server/DB") + user = Actor("User", inBoundary=internet) + web = Server("Web Server") + func = Lambda("Lambda func") + worker = Process("Task queue worker") + db = Datastore("SQL Database", inBoundary=server_db) + + Dataflow(user, web, "User enters comments (*)", note="bbb") + Dataflow(web, db, "Insert query with comments", note="ccc") + Dataflow(web, func, "Call func") + Dataflow(db, web, "Retrieve comments") + Dataflow(web, user, "Show comments (*)") + Dataflow(worker, db, "Query for tasks") + + self.assertTrue(tm.check()) + output = json.dumps(tm, default=to_serializable, sort_keys=True, indent=4) + + self.maxDiff = None + self.assertEqual(output, expected) + + def test_json_loads(self): + random.seed(0) + dir_path = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(dir_path, 'input.json')) as x: + contents = x.read().strip() + + TM.reset() + tm = loads(contents) + self.assertTrue(tm.check()) + + self.maxDiff = None + self.assertEqual([b.name for b in tm._boundaries], ["Internet", "Server/DB"]) + self.assertEqual( + [e.name for e in tm._elements], + [ + "Internet", + "Server/DB", + "User", + "Web Server", + "SQL Database", + "Request", + "Insert", + "Select", + "Response", + ], + ) + self.assertEqual( + [f.name for f in tm._flows], ["Request", "Insert", "Select", "Response"] + ) class Testpytm(unittest.TestCase):