Skip to content

Commit

Permalink
api/admin: firewall-related methods
Browse files Browse the repository at this point in the history
In the end firewall is implemented as .Get and .Set rules, with policy
statically set to 'drop'. This way allow atomic firewall updates.

Since we already have appropriate firewall format handling in
qubes.firewall module - reuse it from there, but adjust the code to be
prepared for potentially malicious input. And also mark such variables
with untrusted_ prefix.

There is also third method: .Reload - which cause firewall reload
without making any change.

QubesOS/qubes-issues#2622
Fixes QubesOS/qubes-issues#2869
  • Loading branch information
marmarek committed Jun 26, 2017
1 parent 842efb5 commit 0200fda
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 46 deletions.
35 changes: 35 additions & 0 deletions qubes/api/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@

import qubes.api
import qubes.devices
import qubes.firewall
import qubes.storage
import qubes.utils
import qubes.vm
Expand Down Expand Up @@ -991,3 +992,37 @@ def vm_device_detach(self, endpoint):
dev.backend_domain, dev.ident)
self.dest.devices[devclass].detach(assignment)
self.app.save()

@qubes.api.method('admin.vm.firewall.Get', no_payload=True)
@asyncio.coroutine
def vm_firewall_get(self):
assert not self.arg

self.fire_event_for_permission()

return ''.join('{}\n'.format(rule.api_rule)
for rule in self.dest.firewall.rules)

@qubes.api.method('admin.vm.firewall.Set')
@asyncio.coroutine
def vm_firewall_set(self, untrusted_payload):
assert not self.arg
rules = []
for untrusted_line in untrusted_payload.decode('ascii',
errors='strict').splitlines():
rule = qubes.firewall.Rule.from_api_string(untrusted_line)
rules.append(rule)

self.fire_event_for_permission(rules=rules)

self.dest.firewall.rules = rules
self.dest.firewall.save()

@qubes.api.method('admin.vm.firewall.Reload', no_payload=True)
@asyncio.coroutine
def vm_firewall_reload(self):
assert not self.arg

self.fire_event_for_permission()

self.dest.fire_event('firewall-changed')
153 changes: 116 additions & 37 deletions qubes/firewall.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#

import datetime
import string
import subprocess

import itertools
Expand All @@ -34,13 +35,22 @@


class RuleOption(object):
def __init__(self, value):
self._value = str(value)
def __init__(self, untrusted_value):
# subset of string.punctuation
safe_set = string.ascii_letters + string.digits + \
':;,./-_[]'
assert all(x in safe_set for x in str(untrusted_value))
value = str(untrusted_value)
self._value = value

@property
def rule(self):
raise NotImplementedError

@property
def api_rule(self):
return self.rule

def __str__(self):
return self._value

Expand All @@ -50,14 +60,15 @@ def __eq__(self, other):
# noinspection PyAbstractClass
class RuleChoice(RuleOption):
# pylint: disable=abstract-method
def __init__(self, value):
super(RuleChoice, self).__init__(value)
def __init__(self, untrusted_value):
# preliminary validation
super(RuleChoice, self).__init__(untrusted_value)
self.allowed_values = \
[v for k, v in self.__class__.__dict__.items()
if not k.startswith('__') and isinstance(v, str) and
not v.startswith('__')]
if value not in self.allowed_values:
raise ValueError(value)
if untrusted_value not in self.allowed_values:
raise ValueError(untrusted_value)


class Action(RuleChoice):
Expand All @@ -81,14 +92,14 @@ def rule(self):

class DstHost(RuleOption):
'''Represent host/network address: either IPv4, IPv6, or DNS name'''
def __init__(self, value, prefixlen=None):
# TODO: in python >= 3.3 ipaddress module could be used
if value.count('/') > 1:
raise ValueError('Too many /: ' + value)
elif not value.count('/'):
def __init__(self, untrusted_value, prefixlen=None):
if untrusted_value.count('/') > 1:
raise ValueError('Too many /: ' + untrusted_value)
elif not untrusted_value.count('/'):
# add prefix length to bare IP addresses
try:
socket.inet_pton(socket.AF_INET6, value)
socket.inet_pton(socket.AF_INET6, untrusted_value)
value = untrusted_value
self.prefixlen = prefixlen or 128
if self.prefixlen < 0 or self.prefixlen > 128:
raise ValueError(
Expand All @@ -97,10 +108,11 @@ def __init__(self, value, prefixlen=None):
self.type = 'dst6'
except socket.error:
try:
socket.inet_pton(socket.AF_INET, value)
if value.count('.') != 3:
socket.inet_pton(socket.AF_INET, untrusted_value)
if untrusted_value.count('.') != 3:
raise ValueError(
'Invalid number of dots in IPv4 address')
value = untrusted_value
self.prefixlen = prefixlen or 32
if self.prefixlen < 0 or self.prefixlen > 32:
raise ValueError(
Expand All @@ -110,28 +122,33 @@ def __init__(self, value, prefixlen=None):
except socket.error:
self.type = 'dsthost'
self.prefixlen = 0
safe_set = string.ascii_lowercase + string.digits + '-._'
assert all(c in safe_set for c in untrusted_value)
value = untrusted_value
else:
host, prefixlen = value.split('/', 1)
prefixlen = int(prefixlen)
untrusted_host, untrusted_prefixlen = untrusted_value.split('/', 1)
prefixlen = int(untrusted_prefixlen)
if prefixlen < 0:
raise ValueError('netmask must be non-negative')
self.prefixlen = prefixlen
try:
socket.inet_pton(socket.AF_INET6, host)
socket.inet_pton(socket.AF_INET6, untrusted_host)
value = untrusted_value
if prefixlen > 128:
raise ValueError('netmask for IPv6 must be <= 128')
self.type = 'dst6'
except socket.error:
try:
socket.inet_pton(socket.AF_INET, host)
socket.inet_pton(socket.AF_INET, untrusted_host)
if prefixlen > 32:
raise ValueError('netmask for IPv4 must be <= 32')
self.type = 'dst4'
if host.count('.') != 3:
if untrusted_host.count('.') != 3:
raise ValueError(
'Invalid number of dots in IPv4 address')
value = untrusted_value
except socket.error:
raise ValueError('Invalid IP address: ' + host)
raise ValueError('Invalid IP address: ' + untrusted_host)

super(DstHost, self).__init__(value)

Expand All @@ -141,15 +158,15 @@ def rule(self):


class DstPorts(RuleOption):
def __init__(self, value):
if isinstance(value, int):
value = str(value)
if value.count('-') == 1:
self.range = [int(x) for x in value.split('-', 1)]
elif not value.count('-'):
self.range = [int(value), int(value)]
def __init__(self, untrusted_value):
if isinstance(untrusted_value, int):
untrusted_value = str(untrusted_value)
if untrusted_value.count('-') == 1:
self.range = [int(x) for x in untrusted_value.split('-', 1)]
elif not untrusted_value.count('-'):
self.range = [int(untrusted_value), int(untrusted_value)]
else:
raise ValueError(value)
raise ValueError(untrusted_value)
if any(port < 0 or port > 65536 for port in self.range):
raise ValueError('Ports out of range')
if self.range[0] > self.range[1]:
Expand All @@ -164,11 +181,11 @@ def rule(self):


class IcmpType(RuleOption):
def __init__(self, value):
super(IcmpType, self).__init__(value)
value = int(value)
if value < 0 or value > 255:
def __init__(self, untrusted_value):
untrusted_value = int(untrusted_value)
if untrusted_value < 0 or untrusted_value > 255:
raise ValueError('ICMP type out of range')
super(IcmpType, self).__init__(untrusted_value)

@property
def rule(self):
Expand All @@ -184,24 +201,42 @@ def rule(self):


class Expire(RuleOption):
def __init__(self, value):
super(Expire, self).__init__(value)
self.datetime = datetime.datetime.utcfromtimestamp(int(value))
def __init__(self, untrusted_value):
super(Expire, self).__init__(untrusted_value)
self.datetime = datetime.datetime.utcfromtimestamp(int(untrusted_value))

@property
def rule(self):
return None

@property
def api_rule(self):
return 'expire=' + str(self)

@property
def expired(self):
return self.datetime < datetime.datetime.utcnow()


class Comment(RuleOption):
# noinspection PyMissingConstructor
def __init__(self, untrusted_value):
# pylint: disable=super-init-not-called
# subset of string.punctuation
safe_set = string.ascii_letters + string.digits + \
':;,./-_[] '
assert all(x in safe_set for x in str(untrusted_value))
value = str(untrusted_value)
self._value = value

@property
def rule(self):
return None

@property
def api_rule(self):
return 'comment=' + str(self)


class Rule(qubes.PropertyHolder):
def __init__(self, xml=None, **kwargs):
Expand Down Expand Up @@ -311,6 +346,20 @@ def rule(self):
values.append(value.rule)
return ' '.join(values)

@property
def api_rule(self):
values = []
# put comment at the end
for prop in sorted(self.property_list(),
key=(lambda p: p.__name__ == 'comment')):
value = getattr(self, prop.__name__)
if value is None:
continue
if value.api_rule is None:
continue
values.append(value.api_rule)
return ' '.join(values)

@classmethod
def from_xml_v1(cls, node, action):
netmask = node.get('netmask')
Expand Down Expand Up @@ -358,8 +407,39 @@ def from_xml_v1(cls, node, action):

return cls(**kwargs)

@classmethod
def from_api_string(cls, untrusted_rule):
'''Parse a single line of firewall rule'''
# comment is allowed to have spaces
untrusted_options, _, untrusted_comment = untrusted_rule.partition(
'comment=')
# appropriate handlers in __init__ of individual options will perform
# option-specific validation
kwargs = {}
if untrusted_comment:
kwargs['comment'] = untrusted_comment

for untrusted_option in untrusted_options.strip().split(' '):
untrusted_key, untrusted_value = untrusted_option.split('=', 1)
if untrusted_key in kwargs:
raise ValueError('Option \'{}\' already set'.format(
untrusted_key))
if untrusted_key in [str(prop) for prop in cls.property_list()]:
kwargs[untrusted_key] = untrusted_value
elif untrusted_key in ('dst4', 'dst6', 'dstname'):
kwargs['dsthost'] = untrusted_value
else:
raise ValueError('Unknown firewall option')

return cls(**kwargs)

def __eq__(self, other):
return self.rule == other.rule
if isinstance(other, Rule):
return self.api_rule == other.api_rule
return self.api_rule == str(other)

def __hash__(self):
return hash(self.api_rule)


class Firewall(object):
Expand Down Expand Up @@ -496,7 +576,6 @@ def save(self):
subprocess.call(["sudo", "systemctl", "start",
"qubes-reload-firewall@%s.timer" % self.vm.name])


def qdb_entries(self, addr_family=None):
'''Return firewall settings serialized for QubesDB entries
Expand Down
8 changes: 6 additions & 2 deletions qubes/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def fire_event(self, event, **kwargs):
effects = super(TestEmitter, self).fire_event(event, **kwargs)
ev_kwargs = frozenset(
(key,
frozenset(value.items()) if isinstance(value, dict) else value)
frozenset(value.items()) if isinstance(value, dict)
else tuple(value) if isinstance(value, list)
else value)
for key, value in kwargs.items()
)
self.fired_events[(event, ev_kwargs)] += 1
Expand All @@ -161,7 +163,9 @@ def fire_event_pre(self, event, **kwargs):
effects = super(TestEmitter, self).fire_event_pre(event, **kwargs)
ev_kwargs = frozenset(
(key,
frozenset(value.items()) if isinstance(value, dict) else value)
frozenset(value.items()) if isinstance(value, dict)
else tuple(value) if isinstance(value, list)
else value)
for key, value in kwargs.items()
)
self.fired_events[(event, ev_kwargs)] += 1
Expand Down
Loading

0 comments on commit 0200fda

Please sign in to comment.