Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add feature for SymbolicSocket through sys_accept #1618

Merged
merged 5 commits into from
Apr 7, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 135 additions & 29 deletions manticore/platforms/linux.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import struct
import time
import resource
from typing import Union, List, TypeVar, cast
from typing import Union, List, TypeVar, cast, Deque

import io
import os
Expand All @@ -22,12 +22,13 @@
from . import linux_syscalls
from .linux_syscall_stubs import SyscallStubs
from ..core.state import TerminateState
from ..core.smtlib import ConstraintSet, Operators, Expression, issymbolic
from ..core.smtlib import ConstraintSet, Operators, Expression, issymbolic, ArrayProxy
from ..core.smtlib.solver import Z3Solver
from ..exceptions import SolverError
from ..native.cpu.abstractcpu import Syscall, ConcretizeArgument, Interruption
from ..native.cpu.cpufactory import CpuFactory
from ..native.memory import SMemory32, SMemory64, Memory32, Memory64, LazySMemory32, LazySMemory64
from ..native.state import State
from ..platforms.platform import Platform, SyscallNotImplemented, unimplemented

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -368,14 +369,30 @@ def pair():
a.connect(b)
return a, b

def __init__(self):
def __init__(self, net: bool = False):
"""
Builds a normal socket that does not introduce symbolic bytes.

:param net: Whether this is a network socket
"""
from collections import deque

self.buffer = deque() # queue os bytes
self.buffer: Deque[
Union[bytes, Expression]
] = deque() # current bytes received but not read
self.peer = None
self.net = net

def __getstate__(self):
state = {"buffer": self.buffer, "net": self.net}
return state

def __setstate__(self, state):
self.buffer = state["buffer"]
self.net = state["net"]

def __repr__(self):
return f"SOCKET({hash(self):x}, {self.buffer!r}, {hash(self.peer):x})"
return f"SOCKET({hash(self):x}, buffer={self.buffer!r}, net={self.net}, peer={hash(self.peer):x})"

def is_connected(self):
return self.peer is not None
Expand Down Expand Up @@ -404,7 +421,14 @@ def receive(self, size):
return ret

def write(self, buf):
assert self.is_connected()
if self.net:
# Just return like we were able to send all data
# TODO: We should put this data somewhere to let the user know what is being sent.
# I don't think stdout is correct. It would be nice to know what is being sent
# over each network connection
return len(buf)
# If not a network Socket, it should be connected
assert self.is_connected(), f"Non-network socket is not connected: {self.__repr__()}"
return self.peer._transmit(buf)

def _transmit(self, buf):
Expand All @@ -425,6 +449,84 @@ def close(self):
pass


class SymbolicSocket(Socket):
"""
Symbolic sockets are generally used for network communications that contain user-controlled input.
"""

def __init__(
self,
constraints: ConstraintSet,
name: str,
max_recv_symbolic: int = 80,
net: bool = True,
wildcard: str = "+",
):
"""
Builds a symbolic socket.

:param constraints: the SMT constraints
:param name: The name of the SymbolicSocket, which is propagated to the symbolic variables introduced
:param max_recv_symbolic: Maximum number of bytes allowed to be read from this socket. 0 for unlimited
:param net: Whether this is a network connection socket
:param wildcard: Wildcard to be used for symbolic bytes in socket. Not supported, yet
"""
super().__init__(net=net)
self._constraints = constraints
self.symb_name = name
self.max_recv_symbolic = max_recv_symbolic # 0 for unlimited. Unlimited is not tested
# Keep track of the symbolic inputs we create
self.inputs_recvd: List[ArrayProxy] = []
self.recv_pos = 0

def __getstate__(self):
state = super().__getstate__()
state["inputs_recvd"] = self.inputs_recvd
state["symb_name"] = self.symb_name
state["recv_pos"] = self.recv_pos
state["max_recv_symbolic"] = self.max_recv_symbolic
state["constraints"] = self._constraints
return state

def __setstate__(self, state):
super().__setstate__(state)
self.inputs_recvd = state["inputs_recvd"]
self.symb_name = state["symb_name"]
self.recv_pos = state["recv_pos"]
self.max_recv_symbolic = state["max_recv_symbolic"]
self._constraints = state["constraints"]

def __repr__(self):
return f"SymbolicSocket({hash(self):x}, inputs_recvd={self.inputs_recvd}, buffer={self.buffer}, net={self.net}"

def _next_symb_name(self) -> str:
"""
Return the next name for a symbolic array, based on previous number of other receives
"""
return f"{self.symb_name}-{len(self.inputs_recvd)}"

def receive(self, size: int) -> Union[ArrayProxy, List[bytes]]:
"""
Return a symbolic array of either `size` or rest of remaining symbolic bytes
:param size: Size of receive
:return: Symbolic array or list of concrete bytes
"""
# NOTE: self.buffer isn't used at all for SymbolicSocket. Not sure if there is a better
# way to use it for on-demand generation of symbolic data or not.
rx_bytes = (
size
if self.max_recv_symbolic == 0
else min(size, self.max_recv_symbolic - self.recv_pos)
)
if rx_bytes == 0:
# If no symbolic bytes left, return empty list
return []
ret = self._constraints.new_array(name=self._next_symb_name(), index_max=rx_bytes)
self.recv_pos += rx_bytes
self.inputs_recvd.append(ret)
return ret


class Linux(Platform):
"""
A simple Linux Operating System Platform.
Expand Down Expand Up @@ -610,7 +712,7 @@ def __getstate__(self):
state_files = []
for fd in self.files:
if isinstance(fd, Socket):
state_files.append(("Socket", fd.buffer))
state_files.append(("Socket", fd))
else:
state_files.append(("File", fd))
state["files"] = state_files
Expand Down Expand Up @@ -657,13 +759,8 @@ def __setstate__(self, state):

# fetch each file descriptor (Socket or File())
self.files = []
for ty, file_or_buffer in state["files"]:
if ty == "Socket":
f = Socket()
f.buffer = file_or_buffer
self.files.append(f)
else:
self.files.append(file_or_buffer)
for ty, file_or_socket in state["files"]:
self.files.append(file_or_socket)

# If file descriptors for stdin/stdout/stderr aren't closed, propagate them
if self.files[0]:
Expand Down Expand Up @@ -2129,12 +2226,11 @@ def sys_accept4(self, sockfd, addr, addrlen, flags):
if ret != 0:
return ret

sock = Socket()
sock = Socket(net=True)
fd = self._open(sock)
return fd

def sys_recv(self, sockfd, buf, count, flags, trace_str="_recv"):
data: bytes = bytes()
if not self.current.memory.access_ok(slice(buf, buf + count), "w"):
logger.info("RECV: buf within invalid memory. Returning EFAULT")
return -errno.EFAULT
Expand All @@ -2148,6 +2244,8 @@ def sys_recv(self, sockfd, buf, count, flags, trace_str="_recv"):
return -errno.ENOTSOCK

data = sock.read(count)
if len(data) == 0:
return 0
self.syscall_trace.append((trace_str, sockfd, data))
self.current.write_bytes(buf, data)

Expand Down Expand Up @@ -3021,20 +3119,28 @@ def sys_recvfrom(self, sockfd, buf, count, flags, src_addr, addrlen):
return super().sys_recvfrom(sockfd, buf, count, flags, src_addr, addrlen)

def sys_accept(self, sockfd, addr, addrlen):
# TODO(yan): Transmit some symbolic bytes as soon as we start.
# Remove this hack once no longer needed.

fd = super().sys_accept(sockfd, addr, addrlen)
if fd < 0:
return fd
sock = self._get_fd(fd)
nbytes = 32
symb = self.constraints.new_array(
name=f"socket{fd}", index_max=nbytes, avoid_collisions=True
)
for i in range(nbytes):
sock.buffer.append(symb[i])
if issymbolic(sockfd):
logger.debug("Symbolic sockfd")
raise ConcretizeArgument(self, 0)

if issymbolic(addr):
logger.debug("Symbolic address")
raise ConcretizeArgument(self, 1)

if issymbolic(addrlen):
logger.debug("Symbolic address length")
raise ConcretizeArgument(self, 2)

ret = self._is_sockfd(sockfd)
if ret != 0:
return ret

# TODO: maybe combine name with addr?
sock = SymbolicSocket(self.constraints, "SymbSocket", net=True)
fd = self._open(sock)
return fd
# TODO: Make a concrete connection actually an option
# return super().sys_accept(sockfd, addr, addrlen)

def sys_open(self, buf, flags, mode):
"""
Expand Down
23 changes: 18 additions & 5 deletions tests/native/test_syscalls.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from manticore.core.smtlib import *
from manticore.platforms import linux, linux_syscall_stubs
from manticore.platforms.linux import SymbolicSocket
from manticore.platforms.platform import SyscallNotImplemented


Expand Down Expand Up @@ -156,23 +157,35 @@ def test_recvfrom(self):
self.assertEqual(conn_fd, 4)

sock_obj = self.linux.files[conn_fd]
# Any socket that comes from an accept, should probably be symbolic for now
assert isinstance(sock_obj, SymbolicSocket)

# Start with 0 symbolic bytes
init_len = len(sock_obj.buffer)
self.assertEqual(init_len, 0)

# Try to receive 5 symbolic bytes
BYTES = 5
wrote = self.linux.sys_recvfrom(conn_fd, 0x1100, BYTES, 0x0, 0x0, 0x0)
self.assertEqual(wrote, BYTES)

# Try to receive into address 0x0
wrote = self.linux.sys_recvfrom(conn_fd, 0x0, 100, 0x0, 0x0, 0x0)
self.assertEqual(wrote, -errno.EFAULT)

remain_len = init_len - BYTES
self.assertEqual(remain_len, len(sock_obj.buffer))

wrote = self.linux.sys_recvfrom(conn_fd, 0x1100, remain_len + 10, 0x0, 0x0, 0x0)
self.assertEqual(wrote, remain_len)
# Try to receive all remaining symbolic bytes plus some more
recvd_bytes = sock_obj.recv_pos
remaining_bytes = sock_obj.max_recv_symbolic - sock_obj.recv_pos
BYTES = remaining_bytes + 10
wrote = self.linux.sys_recvfrom(conn_fd, 0x1100, BYTES, 0x0, 0x0, 0x0)
self.assertNotEqual(wrote, BYTES)
self.assertEqual(wrote, remaining_bytes)

# Try to receive 10 more bytes when already at max
wrote = self.linux.sys_recvfrom(conn_fd, 0x1100, 10, 0x0, 0x0, 0x0)
self.assertEqual(wrote, 0)

# Close and make sure we can't write more stuff
self.linux.sys_close(conn_fd)
wrote = self.linux.sys_recvfrom(conn_fd, 0x1100, 10, 0x0, 0x0, 0x0)
self.assertEqual(wrote, -errno.EBADF)
Expand Down