Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Async UDP Socket for QUIC #3

Draft
wants to merge 3 commits into
base: feature/quic
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 45 additions & 0 deletions quic/lib/async_socket.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
extern "C" {
#include <runtime/smalloc.h>
#include <runtime/udp.h>
}

#include <quic/lib/async_socket.h>
#include <quic/lib/async_socket_exception.h>

#include <cstddef>

namespace quic {

void AsyncUDPSocket::Init(sa_family_t family) {
NetworkSocket socket =
netops::socket(family, SOCK_DGRAM, family != AF_UNIX ? IPPROTO_UDP : 0);
if (socket == NetworkSocket()) {
throw AsyncSocketException(AsyncSocketException::NOT_OPEN,
"error creating async udp socket", errno);
}

fd_ = socket;
ownership_ = FDOwnership::OWNS;
}

void AsyncUDPSocket::Bind(const quic::SocketAddress& address) {
Init(address.GetFamily());

sockaddr_storage addr_storage;
address.GetAddress(&addr_storage);
auto& saddr = reinterpret_cast<sockaddr&>(addr_storage);
if (netops::bind(fd_, &saddr, address.GetActualSize()) != 0) {
throw AsyncSocketException(
AsyncSocketException::NOT_OPEN,
"failed to bind the async udp socket for: " + address.Describe(),
errno);
}

if (address.GetFamily() == AF_UNIX || address.GetPort() != 0) {
local_address_ = address;
} else {
local_address_.SetFromLocalAddress(fd_);
}
}

} // namespace quic
164 changes: 164 additions & 0 deletions quic/lib/async_socket.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
#pragma once

extern "C" {
#include <base/kref.h>
#include <runtime/net/defs.h>
#include <runtime/net/waitq.h>
#include <runtime/udp.h>
}

#include <thread>

#include <quic/lib/netops.h>
#include <quic/lib/network_socket.h>
#include <quic/lib/socket_address.h>

namespace quic {

class AsyncUDPSocket {
public:
enum class FDOwnership { OWNS, SHARED };

AsyncUDPSocket(const AsyncUDPSocket&) = delete;
AsyncUDPSocket& operator=(const AsyncUDPSocket&) = delete;

class ReadCallback {
public:
struct OnDataAvailableParams {
int gro = -1;
};

/**
* Invoked when the socket becomes readable and we want buffer
* to write to.
*
* NOTE: From socket we will end up reading at most `len` bytes
* and if there were more bytes in datagram, we will end up
* dropping them.
*/
virtual void GetReadBuffer(void** buf, size_t* len) noexcept = 0;

/**
* Invoked when a new datagram is available on the socket. `len`
* is the number of bytes read and `truncated` is true if we had
* to drop few bytes because of running out of buffer space.
* OnDataAvailableParams::gro is the GRO segment size
*/
virtual void OnDataAvailable(const quic::SocketAddress& client, size_t len,
bool truncated,
OnDataAvailableParams params) noexcept = 0;

/**
* Notifies when data is available. This is only invoked when
* shouldNotifyOnly() returns true.
*/
virtual void OnNotifyDataAvailable(AsyncUDPSocket&) noexcept {}

/**
* Returns whether or not the read callback should only notify
* but not call getReadBuffer.
* If shouldNotifyOnly() returns true, AsyncUDPSocket will invoke
* onNotifyDataAvailable() instead of getReadBuffer().
* If shouldNotifyOnly() returns false, AsyncUDPSocket will invoke
* getReadBuffer() and onDataAvailable().
*/
virtual bool ShouldNotifyOnly() { return false; }

/**
* Invoked when there is an error reading from the socket.
*
* NOTE: Since UDP is connectionless, you can still read from the socket.
* But you have to re-register readCallback yourself after
* onReadError.
*/
virtual void OnReadError(const AsyncSocketException& ex) noexcept = 0;

/**
* Invoked when socket is closed and a read callback is registered.
*/
virtual void OnReadClosed() noexcept = 0;

virtual ~ReadCallback() = default;
};

class ErrMessageCallback {
public:
virtual ~ErrMessageCallback() = default;

/**
* errMessage() will be invoked when kernel puts a message to
* the error queue associated with the socket.
*
* @param cmsg Reference to cmsghdr structure describing
* a message read from error queue associated
* with the socket.
*/
virtual void ErrMessage(const cmsghdr& cmsg) noexcept = 0;

/**
* errMessageError() will be invoked if an error occurs reading a message
* from the socket error stream.
*
* @param ex An exception describing the error that occurred.
*/
virtual void ErrMessageError(const AsyncSocketException& ex) noexcept = 0;
};

struct WriteOptions {
WriteOptions() = default;
WriteOptions(int gsoVal) : gso(gsoVal) {}
int gso{0};
};

/**
* Returns the address server is listening on
*/
virtual const quic::SocketAddress& Address() const {
// CHECK_NE(NetworkSocket(), fd_) << "Server not yet bound to an address";
return local_address_;
}

/**
* Bind the socket to the following address. If port is not
* set in the `address` an ephemeral port is chosen and you can
* use `Address()` method above to get it after this method successfully
* returns.
*/
virtual void Bind(const quic::SocketAddress& address);

/**
* Connects the UDP socket to a remote destination address provided in
* address. This can speed up UDP writes on linux because it will cache flow
* state on connects.
* Using connect has many quirks, and you should be aware of them before using
* this API:
* 1. If this is called before bind, the socket will be automatically bound to
* the IP address of the current default network interface.
* 2. Normally UDP can use the 2 tuple (src ip, src port) to steer packets
* sent by the peer to the socket, however after connecting the socket, only
* packets destined to the destination address specified in connect() will be
* forwarded and others will be dropped. If the server can send a packet
* from a different destination port / IP then you probably do not want to use
* this API.
* 3. It can be called repeatedly on either the client or server however it's
* normally only useful on the client and not server.
*/
virtual void Connect(const quic::SocketAddress& address);

private:
void Init(sa_family_t family);

NetworkSocket fd_;
FDOwnership ownership_;

// Temp space to receive client address.
quic::SocketAddress client_address_;

quic::SocketAddress local_address_;

// If the socket is connected.
quic::SocketAddress connected_address_;
bool connected_{false};
};

} // namespace quic
33 changes: 33 additions & 0 deletions quic/lib/async_socket_exception.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#include <quic/lib/async_socket_exception.h>

#include <cerrno>
#include <cstring>

namespace quic {

std::string AsyncSocketException::GetExceptionTypeString(
AsyncSocketExceptionType type) {
switch (type) {
case UNKNOWN:
return "Unknown async socked exception";
case NOT_OPEN:
return "Socket not open";
default:
return "Invalid exception type";
}
}

std::string AsyncSocketException::GetMessage(AsyncSocketExceptionType type,
const std::string& message,
int errno_copy) {
if (errno != 0) {
return "AsyncSocketException: " + message +
", type = " + GetExceptionTypeString(type) +
", errno = " + std::strerror(errno_copy);
} else {
return "AsyncSocketException: " + message +
", type = " + GetExceptionTypeString(type);
}
}

} // namespace quic
39 changes: 39 additions & 0 deletions quic/lib/async_socket_exception.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#pragma once

#include <stdexcept>
#include <string>

namespace quic {

class AsyncSocketException : public std::runtime_error {
public:
enum AsyncSocketExceptionType {
UNKNOWN = 0,
NOT_OPEN = 1,
};

AsyncSocketException(AsyncSocketExceptionType type,
const std::string& message, int errno_copy = 0)
: std::runtime_error(GetMessage(type, message, errno_copy)),
type_(type),
errno_(errno_copy) {}

AsyncSocketExceptionType GetType() const noexcept { return type_; }

int getErrno() const noexcept { return errno_; }

protected:
static std::string GetExceptionTypeString(AsyncSocketExceptionType type);

// Return a message based on the input.
static std::string GetMessage(AsyncSocketExceptionType type,
const std::string& message, int errno_copy);

// Error code.
AsyncSocketExceptionType type_;

// A copy of the errno.
int errno_;
};

} // namespace quic
39 changes: 39 additions & 0 deletions quic/lib/netops.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#include <quic/lib/netops.h>

#include <runtime/smalloc.h>

#define UDP_IN_DEFAULT_CAP 512
#define UDP_OUT_DEFAULT_CAP 2048

namespace quic {

NetworkSocket socket(int af, int type, int protocol) {
assert(type == SOCK_DGRAM);

quic::NativeSocket sock;
sock = static_cast<quic::NativeSocket>(smalloc(sizeof sock));
if (!sock) return NetworkSocket();

sock->shutdown = false;

// initialize ingress fields
spin_lock_init(&sock->inq_lock);
sock->inq_cap = UDP_IN_DEFAULT_CAP;
sock->inq_len = 0;
sock->inq_err = 0;
waitq_init(&sock->inq_wq);
mbufq_init(&sock->inq);

// initialize egress fields
spin_lock_init(&sock->outq_lock);
sock->outq_free = false;
sock->outq_cap = UDP_OUT_DEFAULT_CAP;
sock->outq_len = 0;
waitq_init(&sock->outq_wq);

kref_init(&sock->ref);

return NetworkSocket(sock);
}

} // namespace quic
35 changes: 35 additions & 0 deletions quic/lib/netops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#pragma once

#include <sys/socket.h>

#include <quic/lib/network_socket.h>

namespace quic {

namespace netops {

NetworkSocket accept(NetworkSocket s, sockaddr* addr, socklen_t* addrlen);
int bind(NetworkSocket s, const sockaddr* name, socklen_t namelen);
int close(NetworkSocket s);
int connect(NetworkSocket s, const sockaddr* name, socklen_t namelen);
int getpeername(NetworkSocket s, sockaddr* name, socklen_t* namelen);
int getsockname(NetworkSocket s, sockaddr* name, socklen_t* namelen);
int listen(NetworkSocket s, int backlog);
ssize_t recv(NetworkSocket s, void* buf, size_t len, int flag);
ssize_t recvfrom(NetworkSocket s, void* buf, size_t len, int flags,
sockaddr* from, socklen_t* fromlen);
ssize_t recvmsg(NetworkSocket s, msghdr* message, int flags);
int recvmmsg(NetworkSocket s, mmsghdr* msgvec, unsigned int vlen,
unsigned int flags, timespec* timeout);
ssize_t send(NetworkSocket s, const void* buf, size_t len, int flags);
ssize_t sendto(NetworkSocket s, const void* buf, size_t len, int flags,
const sockaddr* to, socklen_t tolen);
ssize_t sendmsg(NetworkSocket socket, const msghdr* message, int flags);
int sendmmsg(NetworkSocket socket, mmsghdr* msgvec, unsigned int vlen,
int flags);
int shutdown(NetworkSocket s, int how);
NetworkSocket socket(int af, int type, int protocol);

} // namespace netops

} // namespace quic
Loading