Skip to content

Commit

Permalink
Block extensions disallowed by policy (#3259)
Browse files Browse the repository at this point in the history
* Block disallowed extension processing
* Enable policy e2e tests
---------

Co-authored-by: Norberto Arrieta <[email protected]>
  • Loading branch information
mgunnala and narrieta authored Jan 22, 2025
1 parent 5a646ff commit 9e2ada3
Show file tree
Hide file tree
Showing 18 changed files with 1,546 additions and 45 deletions.
149 changes: 123 additions & 26 deletions azurelinuxagent/ga/exthandlers.py

Large diffs are not rendered by default.

12 changes: 4 additions & 8 deletions azurelinuxagent/ga/policy/policy_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,6 @@
_MAX_SUPPORTED_POLICY_VERSION = "0.1.0"


class PolicyError(AgentError):
"""
Error raised during agent policy enforcement.
"""


class InvalidPolicyError(AgentError):
"""
Error raised if user-provided policy is invalid.
Expand All @@ -50,7 +44,6 @@ def __init__(self, msg, inner=None):
msg = "Customer-provided policy file ('{0}') is invalid, please correct the following error: {1}".format(conf.get_policy_file_path(), msg)
super(InvalidPolicyError, self).__init__(msg, inner)


class _PolicyEngine(object):
"""
Implements base policy engine API.
Expand All @@ -61,6 +54,7 @@ def __init__(self):
if not self.policy_enforcement_enabled:
return

_PolicyEngine._log_policy_event("Policy enforcement is enabled.")
self._policy = self._parse_policy(self.__read_policy())

@staticmethod
Expand Down Expand Up @@ -98,8 +92,10 @@ def __read_policy():
with open(conf.get_policy_file_path(), 'r') as f:
try:
contents = f.read()
# TODO: Consider copying the policy file contents to the history folder, and only log the policy locally
# in the case of policy-related failure.
_PolicyEngine._log_policy_event(
"Policy enforcement is enabled. Enforcing policy using policy file found at '{0}'. File contents:\n{1}"
"Enforcing policy using policy file found at '{0}'. File contents:\n{1}"
.format(conf.get_policy_file_path(), contents))
# json.loads will raise error if file contents are not a valid json (including empty file).
custom_policy = json.loads(contents)
Expand Down
2 changes: 1 addition & 1 deletion tests/data/test_waagent.conf
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,4 @@ OS.SshDir=/notareal/path
# - The default is false to protect the state of existing VMs
OS.EnableFirewall=n

Debug.EnableExtensionPolicy=y
Debug.EnableExtensionPolicy=n
219 changes: 214 additions & 5 deletions tests/ga/test_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def test_migration_ignores_tree_remove_errors(self, shutil_mock): # pylint: dis
class TestExtensionBase(AgentTestCase):
def _assert_handler_status(self, report_vm_status, expected_status,
expected_ext_count, version,
expected_handler_name="OSTCExtensions.ExampleHandlerLinux", expected_msg=None):
expected_handler_name="OSTCExtensions.ExampleHandlerLinux", expected_msg=None, expected_code=None):
self.assertTrue(report_vm_status.called)
args, kw = report_vm_status.call_args # pylint: disable=unused-variable
vm_status = args[0]
Expand All @@ -443,6 +443,9 @@ def _assert_handler_status(self, report_vm_status, expected_status,
if expected_msg is not None:
self.assertIn(expected_msg, handler_status.message)

if expected_code is not None:
self.assertEqual(expected_code, handler_status.code)


# Deprecated. New tests should be added to the TestExtension class
@patch('time.sleep', side_effect=lambda _: mock_sleep(0.001))
Expand Down Expand Up @@ -1649,13 +1652,13 @@ def test_extensions_disabled(self, _, *args):
vm_status = args[0]
self.assertEqual(1, len(vm_status.vmAgent.extensionHandlers))
exthandler = vm_status.vmAgent.extensionHandlers[0]
self.assertEqual(-1, exthandler.code)
self.assertEqual(ExtensionErrorCodes.PluginEnableProcessingFailed, exthandler.code)
self.assertEqual('NotReady', exthandler.status)
self.assertEqual("Extension will not be processed since extension processing is disabled. To enable extension processing, set Extensions.Enabled=y in '/etc/waagent.conf'", exthandler.message)
self.assertEqual("Extension 'OSTCExtensions.ExampleHandlerLinux' will not be processed since extension processing is disabled. To enable extension processing, set Extensions.Enabled=y in '/etc/waagent.conf'", exthandler.message)
ext_status = exthandler.extension_status
self.assertEqual(-1, ext_status.code)
self.assertEqual(ExtensionErrorCodes.PluginEnableProcessingFailed, ext_status.code)
self.assertEqual('error', ext_status.status)
self.assertEqual("Extension will not be processed since extension processing is disabled. To enable extension processing, set Extensions.Enabled=y in '/etc/waagent.conf'", ext_status.message)
self.assertEqual("Extension 'OSTCExtensions.ExampleHandlerLinux' will not be processed since extension processing is disabled. To enable extension processing, set Extensions.Enabled=y in '/etc/waagent.conf'", ext_status.message)

def test_extensions_deleted(self, *args):
# Ensure initial enable is successful
Expand Down Expand Up @@ -3507,5 +3510,211 @@ def test_report_msg_if_handler_manifest_contains_invalid_values(self):
self.assertIn("'supportsMultipleExtensions' has a non-boolean value", kw_messages[2]['message'])


class TestExtensionPolicy(TestExtensionBase):
def setUp(self):
AgentTestCase.setUp(self)
self.policy_path = os.path.join(self.tmp_dir, "waagent_policy.json")

# Patch attributes to enable policy feature
self.patch_policy_path = patch('azurelinuxagent.common.conf.get_policy_file_path',
return_value=str(self.policy_path))
self.patch_policy_path.start()
self.patch_conf_flag = patch('azurelinuxagent.ga.policy.policy_engine.conf.get_extension_policy_enabled',
return_value=True)
self.patch_conf_flag.start()
self.maxDiff = None # When long error messages don't match, display the entire diff.

def tearDown(self):
patch.stopall()
AgentTestCase.tearDown(self)

def _create_policy_file(self, policy):
with open(self.policy_path, mode='w') as policy_file:
if isinstance(policy, dict):
json.dump(policy, policy_file, indent=4)
else:
policy_file.write(policy)
policy_file.flush()

def _test_policy_case(self, policy, op, expected_status_code, expected_handler_status, expected_ext_count,
expected_status_msg=None):

# Set up a mock protocol instance.
with mock_wire_protocol(wire_protocol_data.DATA_FILE) as protocol:
if op == ExtensionRequestedState.Uninstall:
# Generate a new mock goal state to uninstall the extension - increment the incarnation
protocol.mock_wire_data.set_incarnation(2)
protocol.mock_wire_data.set_extensions_config_state(ExtensionRequestedState.Uninstall)
protocol.client.update_goal_state()
protocol.aggregate_status = None
protocol.report_vm_status = MagicMock()
exthandlers_handler = get_exthandlers_handler(protocol)

# Create policy file and process extensions.
self._create_policy_file(policy)
exthandlers_handler.run()
exthandlers_handler.report_ext_handlers_status()

# Assert that agent is reporting the expected handler status
report_vm_status = protocol.report_vm_status
self.assertTrue(report_vm_status.called)
self._assert_handler_status(report_vm_status, expected_handler_status, expected_ext_count=expected_ext_count,
version="1.0.0", expected_handler_name='OSTCExtensions.ExampleHandlerLinux',
expected_msg=expected_status_msg, expected_code=expected_status_code)

def test_should_fail_enable_if_extension_disallowed(self):
policy = \
{
"policyVersion": "0.1.0",
"extensionPolicies": {
"allowListedExtensionsOnly": True,
}
}
expected_msg = "failed to run extension 'OSTCExtensions.ExampleHandlerLinux' because it is not specified as an allowed extension."
self._test_policy_case(policy=policy, op=ExtensionRequestedState.Enabled, expected_status_code=ExtensionErrorCodes.PluginEnableProcessingFailed,
expected_handler_status='NotReady', expected_ext_count=1, expected_status_msg=expected_msg)

def test_should_fail_enable_for_invalid_policy(self):
policy = \
{
"policyVersion": "0.1.0",
"extensionPolicies": {
"allowListedExtensionsOnly": "False"
}
}
expected_msg = "attribute 'extensionPolicies.allowListedExtensionsOnly'; must be 'boolean'"
self._test_policy_case(policy=policy, op=ExtensionRequestedState.Enabled, expected_status_code=ExtensionErrorCodes.PluginEnableProcessingFailed,
expected_handler_status='NotReady', expected_ext_count=1, expected_status_msg=expected_msg)

def test_should_fail_extension_if_error_thrown_during_policy_engine_init(self):
policy = \
{
"policyVersion": "0.1.0"
}
with patch('azurelinuxagent.ga.policy.policy_engine.ExtensionPolicyEngine.__init__',
side_effect=Exception("mock exception")):
expected_msg = "Extension will not be processed: mock exception"
self._test_policy_case(policy=policy, op=ExtensionRequestedState.Enabled,
expected_status_code=ExtensionErrorCodes.PluginEnableProcessingFailed,
expected_handler_status='NotReady', expected_ext_count=1, expected_status_msg=expected_msg)

def test_should_fail_uninstall_if_extension_disallowed(self):
policy = \
{
"policyVersion": "0.1.0",
"extensionPolicies": {
"allowListedExtensionsOnly": True,
"signatureRequired": False,
"extensions": {}
},
}
expected_msg = "failed to uninstall extension 'OSTCExtensions.ExampleHandlerLinux' because it is not specified as an allowed extension."
self._test_policy_case(policy=policy, op=ExtensionRequestedState.Uninstall, expected_status_code=ExtensionErrorCodes.PluginDisableProcessingFailed,
expected_handler_status='NotReady', expected_ext_count=1, expected_status_msg=expected_msg)

def test_should_fail_enable_if_dependent_extension_disallowed(self):
self._create_policy_file({
"policyVersion": "0.1.0",
"extensionPolicies": {
"allowListedExtensionsOnly": True,
"extensions": {
"OSTCExtensions.ExampleHandlerLinux": {}
}
}
})
with mock_wire_protocol(wire_protocol_data.DATA_FILE_EXT_SEQUENCING) as protocol:
protocol.aggregate_status = None
protocol.report_vm_status = MagicMock()
exthandlers_handler = get_exthandlers_handler(protocol)
dep_ext_level_2 = extension_emulator(name="OSTCExtensions.ExampleHandlerLinux")
dep_ext_level_1 = extension_emulator(name="OSTCExtensions.OtherExampleHandlerLinux")

exthandlers_handler.run()
exthandlers_handler.report_ext_handlers_status()

# OtherExampleHandlerLinux should be disallowed by policy, ExampleHandlerLinux should be skipped because
# dependent extension failed
self._assert_handler_status(protocol.report_vm_status, expected_status="NotReady", expected_ext_count=1,
version="1.0.0", expected_handler_name="OSTCExtensions.OtherExampleHandlerLinux",
expected_msg=("failed to run extension 'OSTCExtensions.OtherExampleHandlerLinux' "
"because it is not specified as an allowed extension."))

self._assert_handler_status(protocol.report_vm_status, expected_status="NotReady", expected_ext_count=0,
version="1.0.0", expected_handler_name="OSTCExtensions.ExampleHandlerLinux",
expected_msg="Skipping processing of extensions since execution of dependent "
"extension OSTCExtensions.OtherExampleHandlerLinux failed")

# check handler list and dependency levels
self.assertTrue(exthandlers_handler.ext_handlers is not None)
self.assertTrue(exthandlers_handler.ext_handlers is not None)
self.assertEqual(len(exthandlers_handler.ext_handlers), 2)
self.assertEqual(1, next(handler for handler in exthandlers_handler.ext_handlers if
handler.name == dep_ext_level_1.name).settings[0].dependencyLevel)
self.assertEqual(2, next(handler for handler in exthandlers_handler.ext_handlers if
handler.name == dep_ext_level_2.name).settings[0].dependencyLevel)

def test_enable_should_succeed_if_extension_allowed(self):
policy_cases = [
{
"policyVersion": "0.1.0",
"extensionPolicies": {
"allowListedExtensionsOnly": False,
}
},
{
"policyVersion": "0.1.0",
"extensionPolicies": {
"allowListedExtensionsOnly": True,
"extensions": {
"OSTCExtensions.ExampleHandlerLinux": {}
}
}
}
]
for policy in policy_cases:
self._test_policy_case(policy=policy, op=ExtensionRequestedState.Enabled, expected_status_code=0,
expected_handler_status='Ready', expected_ext_count=1)

def test_uninstall_should_succeed_if_extension_allowed(self):
policy_cases = [
{
"policyVersion": "0.1.0",
"extensionPolicies": {
"allowListedExtensionsOnly": False,
}
},
{
"policyVersion": "0.1.0",
"extensionPolicies": {
"allowListedExtensionsOnly": True,
"extensions": {
"OSTCExtensions.ExampleHandlerLinux": {}
}
}
}
]
for policy in policy_cases:
with mock_wire_protocol(wire_protocol_data.DATA_FILE) as protocol:
# Generate a new mock goal state to uninstall the extension - increment the incarnation
protocol.mock_wire_data.set_incarnation(2)
protocol.mock_wire_data.set_extensions_config_state(ExtensionRequestedState.Uninstall)
protocol.client.update_goal_state()
protocol.aggregate_status = None
protocol.report_vm_status = MagicMock()
exthandlers_handler = get_exthandlers_handler(protocol)

# Create policy file and process extensions.
self._create_policy_file(policy)
exthandlers_handler.run()
exthandlers_handler.report_ext_handlers_status()

# Assert that no status is being reported for the extension, to confirm that uninstall was successful.
report_vm_status = protocol.report_vm_status
self.assertTrue(report_vm_status.called)
args, kw = report_vm_status.call_args # pylint: disable=unused-variable
vm_status = args[0]
self.assertEqual(0, len(vm_status.vmAgent.extensionHandlers))


if __name__ == '__main__':
unittest.main()
Loading

0 comments on commit 9e2ada3

Please sign in to comment.