Skip to content

Commit

Permalink
Merge pull request #210 from stratosphereips/ondra-data-object-redesign
Browse files Browse the repository at this point in the history
Ondra data object redesign
  • Loading branch information
ondrej-lukas authored Apr 25, 2024
2 parents d67b377 + 8bd2fed commit 1cae2a5
Show file tree
Hide file tree
Showing 4 changed files with 95 additions and 19 deletions.
6 changes: 4 additions & 2 deletions docs/Components.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@ Service class holds information about services running in hosts. Each Service ha
Example: `s = Service('postgresql', 'passive', '14.3.0', False)`

## Data
Data class holds information about datapoints (files) present in the NetSecGame.
Data class holds information about datapoints (files) present in the NetSecGame. Datapoints DO NOT hold the content of files.
Each data instance has two parameters:
- `owner`:str - specifying the user who ownes this datapoint
- `id`: str - unique identifier of the datapoint
- `id`: str - unique identifier of the datapoint in a host
- `size`: int - size of the datapoint (optional, default=0)
- `type`: str - indetification of a type of the file (optional, default="")

Example:`Data("User1", "DatabaseData")`

Expand Down
18 changes: 16 additions & 2 deletions env/game_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,11 @@ class Data():
"""
owner: str
id: str

size: int = 0
type: str = ""

def __hash__(self) -> int:
return hash((self.owner, self.id, self.size, self.type))
@enum.unique
class ActionType(enum.Enum):
"""
Expand Down Expand Up @@ -445,4 +449,14 @@ def from_string(cls, string:str):
case "GameStatus.BAD_REQUEST":
return GameStatus.BAD_REQUEST
def __repr__(self) -> str:
return str(self)
return str(self)
if __name__ == "__main__":
data1 = Data(owner="test", id="test_data", content="content", type="db")
data2 = Data(owner="test", id="test_data", content="content", type="db")
# print(data)
# print(data.size)

# s = set()
# s.add(data)
# s.add( Data("test", "test_data", content="new_content", type="db"))
# print(s)
26 changes: 24 additions & 2 deletions env/network_security_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(self, task_config_file) -> None:
self._services = {} # Dict of all services in the environment. Keys: hostname (`str`), values: `set` of `Service` objetcs.
self._data = {} # Dict of all services in the environment. Keys: hostname (`str`), values `set` of `Service` objetcs.
self._firewall = {} # dict of all the allowed connections in the environment. Keys `IP` ,values: `set` of `IP` objects.
self._data_content = {} #content of each datapoint from self._data
# All exploits in the environment
self._exploits = {}
# A list of all the hosts where the attacker can start in a random start
Expand Down Expand Up @@ -228,6 +229,7 @@ def __init__(self, task_config_file) -> None:

# Make a copy of data placements so it is possible to reset to it when episode ends
self._data_original = copy.deepcopy(self._data)
self._data_content_original = copy.deepcopy(self._data_content)

# CURRENT STATE OF THE GAME - all set to None until self.reset()
self._current_state = None
Expand Down Expand Up @@ -515,7 +517,10 @@ def process_node_config(node_obj:NodeConfig) -> None:
for data in service.private_data:
if node_obj.id not in self._data:
self._data[node_obj.id] = set()
self._data[node_obj.id].add(components.Data(data.owner, data.description))
datapoint = components.Data(data.owner, data.description)
self._data[node_obj.id].add(datapoint)
# add content
self._data_content[node_obj.id, datapoint.id] = f"Content of {datapoint.id}"
except AttributeError:
pass
#service does not contain any data
Expand Down Expand Up @@ -732,7 +737,7 @@ def _get_services_from_host(self, host_ip:str, controlled_hosts:set)-> set:
found_services = {}
if host_ip in self._ip_to_hostname: #is it existing IP?
if self._ip_to_hostname[host_ip] in self._services: #does it have any services?
if host_ip in controlled_hosts: # Shoul local services be included ?
if host_ip in controlled_hosts: # Should local services be included ?
found_services = {s for s in self._services[self._ip_to_hostname[host_ip]]}
else:
found_services = {s for s in self._services[self._ip_to_hostname[host_ip]] if not s.is_local}
Expand Down Expand Up @@ -766,6 +771,21 @@ def _get_data_in_host(self, host_ip:str, controlled_hosts:set)->set:
logger.debug("\t\t\tCan't get data in host. The host is not controlled.")
return data

def _get_data_content(self, host_ip:str, data_id:str)->str:
"""
Returns content of data identified by a host_ip and data_ip.
"""
content = None
if host_ip in self._ip_to_hostname: #is it existing IP?
hostname = self._ip_to_hostname[host_ip]
if (hostname, data_id) in self._data_content:
content = self._data_content[hostname,data_id]
else:
logger.info(f"\tData '{data_id}' not found in host '{hostname}'({host_ip})")
else:
logger.debug("\Data content not found because target IP does not exists.")
return content

def _execute_action(self, current:components.GameState, action:components.Action, action_type='netsecenv')-> components.GameState:
"""
Execute the action and update the values in the state
Expand Down Expand Up @@ -1120,6 +1140,8 @@ def reset(self, trajectory_filename=None)->components.Observation:
self._create_new_network_mapping()
# reset self._data to orignal state
self._data = copy.deepcopy(self._data_original)
# reset self._data_content to orignal state
self._data_content_original = copy.deepcopy(self._data_content_original)
# create starting state (randomized if needed)
self._current_state = self._create_starting_state()
# create win conditions for this episode (randomize if needed)
Expand Down
64 changes: 51 additions & 13 deletions tests/test_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,29 +128,66 @@ class TestData:
"""
Test cases for the Data class
"""
def test_create_data(self):
def test_create_data_minimal(self):
"""
Test that the data are created and all elements can be accessed
Test that the data object is created with ONLY required fields (using default for the rest)
"""
data = Data("Ondra", "Password")
data = Data(owner="Ondra", id="Password")
assert data.owner == "Ondra"
assert data.id == "Password"
assert data.type == ""
assert data.size == 0

def test_create_data_all(self):
"""
Test that the data object is created with ALL fields (using default for the rest)
"""
data = Data(owner="Ondra", id="Password",size=42, type="txt")
assert data.owner == "Ondra"
assert data.id == "Password"
assert data.type == "txt"
assert data.size == 42

def test_data_equal(self):
"""
Test that two data objects with the same parameters are equal
Test that two data objects with the same required parameters are equal
"""
data = Data("Ondra", "Password")
data2 = Data("Ondra", "Password")
# test equality with all fields used
data3 = Data(owner="Ondra", id="Password",size=42, type="txt")
data4 = Data(owner="Ondra", id="Password", size=42, type="txt")
assert data == data2
assert data3 == data4

def test_data_not_equal(self):
"""
Test that two data objects with different parameters are not equal
Test that two data objects with different required parameters are NOT equal
"""
data = Data("Ondra", "Password")
data2 = Data("User2", "WebData")
data2 = Data("ChuckNorris", "Password")
data3 = Data(owner="Ondra", id="Password",size=42, type="txt")
data4 = Data(owner="Ondra", id="DifferentPassword",size=41, type="rsa")
assert data != data2
assert data3 != data4

def test_data_hash_equal(self):
data = Data("Ondra", "Password")
data2 = Data("Ondra", "Password")
# test equality with all fields used
data3 = Data(owner="Ondra", id="Password",size=42, type="txt")
data4 = Data(owner="Ondra", id="Password",size=42, type="txt")
assert hash(data) == hash(data2)
assert hash(data3) == hash(data4)

def test_data_hash_not_equal(self):
data = Data("Ondra", "Password")
data2 = Data("Ondra", "NewPassword")
# test equality with all fields used
data3 = Data(owner="Ondra", id="Password",size=42, type="txt")
data4 = Data(owner="Ondra", id="Password",size=41, type="rsa")
assert hash(data) != hash(data2)
assert hash(data3) != hash(data4)

class TestAction:
"""
Expand Down Expand Up @@ -359,7 +396,7 @@ def test_action_as_json(self):

# Exfiltrate Data
action = Action(action_type=ActionType.ExfiltrateData, params={"target_host":IP("172.16.1.3"),
"source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey")})
"source_host": IP("172.16.1.2"), "data":Data("User2", "PublicKey", size=42, type="pub")})
action_json = action.as_json()
try:
data = json.loads(action_json)
Expand All @@ -369,7 +406,7 @@ def test_action_as_json(self):
assert "ActionType.ExfiltrateData" in data["action_type"]
assert ("parameters", {"target_host": {"ip": "172.16.1.3"},
"source_host" : {"ip": "172.16.1.2"},
"data":{"owner":"User2", "id":"PublicKey"}}) in data.items()
"data":{"owner":"User2", "id":"PublicKey", "size":42 ,"type":"pub"}}) in data.items()

def test_action_scan_network_serialization(self):
action = Action(action_type=ActionType.ScanNetwork,
Expand Down Expand Up @@ -653,7 +690,7 @@ def test_game_state_as_json(self):
known_hosts={IP("192.168.1.2"), IP("192.168.1.3")}, controlled_hosts={IP("192.168.1.2")},
known_services={IP("192.168.1.3"):{Service("service1", "public", "1.01", True)}},
known_data={IP("192.168.1.3"):{Data("ChuckNorris", "data1"), Data("ChuckNorris", "data2")},
IP("192.168.1.2"):{Data("McGiver", "data2")}})
IP("192.168.1.2"):{Data("McGiver", "data2", 42, "txt")}})
game_json = game_state.as_json()
try:
data = json.loads(game_json)
Expand All @@ -664,8 +701,9 @@ def test_game_state_as_json(self):
assert {"ip": "192.168.1.3"} in data["known_hosts"]
assert {"ip": "192.168.1.2"} in data["controlled_hosts"]
assert ("192.168.1.3", [{"name": "service1", "type": "public", "version": "1.01", "is_local": True}]) in data["known_services"].items()
assert {"owner": "ChuckNorris", "id": "data1"} in data["known_data"]["192.168.1.3"]
assert {"owner": "ChuckNorris", "id": "data2"} in data["known_data"]["192.168.1.3"]
assert {"owner": "ChuckNorris", "id": "data1", "size":0, "type":""} in data["known_data"]["192.168.1.3"]
assert {"owner": "ChuckNorris", "id": "data2", "size":0, "type":""} in data["known_data"]["192.168.1.3"]
assert {"owner": "McGiver", "id": "data2", "size":42, "type":"txt"} in data["known_data"]["192.168.1.2"]

def test_game_state_json_deserialized(self):
game_state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)},
Expand All @@ -690,8 +728,8 @@ def test_game_state_as_dict(self):
assert {"ip": "192.168.1.3"} in game_dict["known_hosts"]
assert {"ip": "192.168.1.2"} in game_dict["controlled_hosts"]
assert ("192.168.1.3", [{"name": "service1", "type": "public", "version": "1.01", "is_local": True}]) in game_dict["known_services"].items()
assert {"owner": "ChuckNorris", "id": "data1"} in game_dict["known_data"]["192.168.1.3"]
assert {"owner": "ChuckNorris", "id": "data2"} in game_dict["known_data"]["192.168.1.3"]
assert {"owner": "ChuckNorris", "id": "data1", "size":0, "type":""} in game_dict["known_data"]["192.168.1.3"]
assert {"owner": "ChuckNorris", "id": "data2", "size":0, "type":""} in game_dict["known_data"]["192.168.1.3"]

def test_game_state_from_dict(self):
game_state = GameState(known_networks={Network("1.1.1.1", 24),Network("1.1.1.2", 24)},
Expand Down

0 comments on commit 1cae2a5

Please sign in to comment.