Skip to content

Commit

Permalink
Merge pull request #255 from stratosphereips/fix-block-action
Browse files Browse the repository at this point in the history
Fix block action
  • Loading branch information
ondrej-lukas authored Nov 15, 2024
2 parents fd9ead4 + 5179471 commit 00ce5be
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 27 deletions.
82 changes: 55 additions & 27 deletions env/worlds/network_security_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, task_config_file, world_name="NetSecEnv") -> None:
self._services = {} # Dict of all services in the environment. Keys: hostname (`str`), values: `set` of `Service` objetcs.
self._data = {} # Dict of all services in the environment. Keys: hostname (`str`), values `set` of `Service` objetcs.
self._firewall = {} # dict of all the allowed connections in the environment. Keys `IP` ,values: `set` of `IP` objects.
self._fw_blocks = {}
self._data_content = {} #content of each datapoint from self._data
# All exploits in the environment
self._exploits = {}
Expand Down Expand Up @@ -76,6 +77,7 @@ def __init__(self, task_config_file, world_name="NetSecEnv") -> None:
# Make a copy of data placements so it is possible to reset to it when episode ends
self._data_original = copy.deepcopy(self._data)
self._data_content_original = copy.deepcopy(self._data_content)
self._firewall_original = copy.deepcopy(self._firewall)

self._actions_played = []
self.logger.info("Environment initialization finished")
Expand Down Expand Up @@ -457,6 +459,16 @@ def _get_data_in_host(self, host_ip:str, controlled_hosts:set)->set:
else:
self.logger.debug("\t\t\tCan't get data in host. The host is not controlled.")
return data

def _get_known_blocks_in_host(self, host_ip:str, controlled_hosts:set)->set:
known_blocks = set()
if host_ip in controlled_hosts: #only return data if the agent controls the host
if host_ip in self._ip_to_hostname:
if host_ip in self._fw_blocks:
known_blocks = self._fw_blocks[host_ip]
else:
self.logger.debug("\t\t\tCan't get data in host. The host is not controlled.")
return known_blocks

def _get_data_content(self, host_ip:str, data_id:str)->str:
"""
Expand Down Expand Up @@ -579,6 +591,13 @@ def _execute_find_data_action(self, current:components.GameState, action:compone
next_data[action.parameters["target_host"]] = new_data
else:
next_data[action.parameters["target_host"]] = next_data[action.parameters["target_host"]].union(new_data)
# ADD KNOWN FW BLOCKS
new_blocks = self._get_known_blocks_in_host(action.parameters["target_host"], current.controlled_hosts)
if len(new_blocks) > 0:
if action.parameters["target_host"] not in next_blocked.keys():
next_blocked[action.parameters["target_host"]] = new_blocks
else:
next_blocked[action.parameters["target_host"]] = next_blocked[action.parameters["target_host"]].union(new_blocks)
else:
self.logger.debug(f"\t\t\tConnection {action.parameters['source_host']} -> {action.parameters['target_host']} blocked by FW. Skipping")
else:
Expand Down Expand Up @@ -680,41 +699,48 @@ def _execute_block_ip_action(self, current_state:components.GameState, action:co
- Add the rule to the FW list
- Update the state
"""
blocked_host = action.parameters['blocked_host']

next_nets, next_known_h, next_controlled_h, next_services, next_data, next_blocked = self._state_parts_deep_copy(current_state)
self.logger.info(f"\t\tBlockIP {action.parameters['target_host']}")
# Is the src in the controlled hosts?
if "source_host" in action.parameters.keys() and action.parameters["source_host"] in current_state.controlled_hosts:
# Is the target in the controlled hosts?
if "target_host" in action.parameters.keys() and action.parameters["target_host"] in current_state.controlled_hosts:
# For now there is only one FW in the main router, but this should change in the future.
# This means we ignore the 'target_host' that would be the router where this is applied.

# Stop the blocked host to connect _to_ any other IP
try:
self._firewall[blocked_host] = set()
self.logger.debug(f"Removing all allowed connections from {blocked_host}")
except KeyError:
# The blocked_host host was not in the list
pass
# Stop the other hosts to connect _to the blocked_host_
for host in self._firewall.keys():
try:
self._firewall[host].remove(blocked_host)
self.logger.debug(f"Removing {blocked_host} from allowed connections from {host}")
except KeyError:
# The blocked_host host was not in the list
pass
# Update the state of blocked ips. It is a dict with key target_host and a set with blocked hosts inside
new_blocked = set()
# Store the blocked host IP in the set of blocked hosts
new_blocked.add(action.parameters["blocked_host"])
if len(new_blocked) > 0:
if action.parameters["target_host"] not in next_blocked.keys():
next_blocked[action.parameters["target_host"]] = new_blocked
if self._firewall_check(action.parameters["source_host"], action.parameters["target_host"]):
if action.parameters["target_host"] != action.parameters['blocked_host']:
self.logger.info(f"\t\tBlockConnection {action.parameters['target_host']} <-> {action.parameters['blocked_host']}")
try:
#remove connection target_host -> blocked_host
self._firewall[action.parameters["target_host"]].discard(action.parameters["blocked_host"])
self.logger.debug(f"\t\t\t Removed rule:'{action.parameters['target_host']}' -> {action.parameters['blocked_host']}")
except KeyError:
pass
try:
#remove blocked_host -> target_host
self._firewall[action.parameters["blocked_host"]].discard(action.parameters["target_host"])
self.logger.debug(f"\t\t\t Removed rule:'{action.parameters['blocked_host']}' -> {action.parameters['target_host']}")
except KeyError:
pass

#Update the FW_Rules visible to agents
if action.parameters["target_host"] not in self._fw_blocks.keys():
self._fw_blocks[action.parameters["target_host"]] = set()
self._fw_blocks[action.parameters["target_host"]].add(action.parameters["blocked_host"])
if action.parameters["blocked_host"] not in self._fw_blocks.keys():
self._fw_blocks[action.parameters["blocked_host"]] = set()
self._fw_blocks[action.parameters["blocked_host"]].add(action.parameters["target_host"])

# update the state
if action.parameters["target_host"] not in next_blocked.keys():
next_blocked[action.parameters["target_host"]] = set()
if action.parameters["blocked_host"] not in next_blocked.keys():
next_blocked[action.parameters["blocked_host"]] = set()
next_blocked[action.parameters["target_host"]].add(action.parameters["blocked_host"])
next_blocked[action.parameters["blocked_host"]].add(action.parameters["target_host"])
else:
next_blocked[action.parameters["target_host"]] = next_blocked[action.parameters["target_host"]].union(new_blocked)
self.logger.info(f"\t\t\t Cant block connection form :'{action.parameters['target_host']}' to '{action.parameters['blocked_host']}'")
else:
self.logger.debug(f"\t\t\t Connection from '{action.parameters['source_host']}->'{action.parameters['target_host']} is blocked blocked by FW")
else:
self.logger.info(f"\t\t\t Invalid target_host:'{action.parameters['target_host']}'")
else:
Expand Down Expand Up @@ -870,6 +896,8 @@ def reset(self)->None:
self._data = copy.deepcopy(self._data_original)
# reset self._data_content to orignal state
self._data_content_original = copy.deepcopy(self._data_content_original)
self._firewall = copy.deepcopy(self._firewall_original)
self._fw_blocks = {}


self._actions_played = []
Expand Down
59 changes: 59 additions & 0 deletions tests/test_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,21 @@ def env_obs_found_data2(env_obs_exploited_service2):
new_state = env.step(state=state, action=action, agent_id=None)
return (env, new_state)

@pytest.fixture
def env_obs_blocked_connection(env_obs_exploited_service):
"After blocking"
env, state = env_obs_exploited_service
source_host = components.IP('192.168.2.2')
target_host = components.IP('192.168.1.3')
blocked_host = components.IP('192.168.2.2')
parameters = {
"target_host":target_host,
"source_host":source_host,
"blocked_host": blocked_host}
action = components.Action(components.ActionType.BlockIP, parameters)
new_state = env.step(state=state, action=action, agent_id=None)
return (env, new_state)

class TestActionsNoDefender:
def test_scan_network_not_exist(self, env_obs):
"""
Expand Down Expand Up @@ -303,3 +318,47 @@ def test_exploit_service_witout_find_service_in_host(self, env_obs_scan):
new_state = env.step(state=state, action=action, agent_id=None)
assert state == new_state
assert components.IP('192.168.1.3') not in new_state.known_services

def test_block_ip_same_host(self, env_obs_exploited_service):
env, state = env_obs_exploited_service
target_host = components.IP('192.168.2.2')
blocked_host = components.IP("1.1.1.1")
parameters = {
"target_host":target_host,
"source_host":target_host,
"blocked_host": blocked_host}
action = components.Action(components.ActionType.BlockIP, parameters)
new_state = env.step(state=state, action=action, agent_id=None)
assert target_host in new_state.known_blocks.keys()
assert blocked_host in new_state.known_blocks[target_host]
assert target_host in env._fw_blocks.keys()
assert blocked_host in env._fw_blocks[target_host]

def test_block_ip_same_different_source(self, env_obs_exploited_service):
env, state = env_obs_exploited_service
source_host = components.IP('192.168.2.2')
target_host = components.IP("192.168.1.3")
blocked_host = components.IP("1.1.1.1")
parameters = {
"target_host":target_host,
"source_host":source_host,
"blocked_host": blocked_host}
action = components.Action(components.ActionType.BlockIP, parameters)
new_state = env.step(state=state, action=action, agent_id=None)
assert target_host in new_state.known_blocks.keys()
assert blocked_host in new_state.known_blocks[target_host]
assert target_host in env._fw_blocks.keys()
assert blocked_host in env._fw_blocks[target_host]


def test_block_ip_self_block(self, env_obs_exploited_service):
env, state = env_obs_exploited_service
target_host = components.IP('192.168.2.2')
parameters = {
"target_host":target_host,
"source_host":components.IP('192.168.2.2'),
"blocked_host": target_host}
action = components.Action(components.ActionType.BlockIP, parameters)
new_state = env.step(state=state, action=action, agent_id=None)
assert target_host not in new_state.known_blocks.keys()
assert target_host not in env._fw_blocks.keys()

0 comments on commit 00ce5be

Please sign in to comment.