Skip to content

Commit

Permalink
Use correct async read handling
Browse files Browse the repository at this point in the history
  • Loading branch information
pcgod committed Sep 4, 2010
1 parent 22f1eb1 commit cce131f
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 72 deletions.
138 changes: 87 additions & 51 deletions client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,6 @@
#include "settings.h"
#include "user.h"

using MumbleClient::mumble_message::MessageHeader;
using MumbleClient::mumble_message::Message;

///////////////////////////////////////////////////////////////////////////////

#define MUMBLE_VERSION(x, y, z) ((x << 16) | (y << 8) | (z & 0xFF))
Expand All @@ -35,6 +32,36 @@ template <class T> T ConstructProtobufObject(void* buffer, int32_t length, bool

namespace MumbleClient {

class MessageHeader {
public:
int16_t type() const { return (d_[0] << 8) + d_[1]; }
int32_t length() const { return (d_[2] << 24) + (d_[3] << 16) + (d_[4] << 8) + d_[5]; }
void type(int16_t t_) { d_[0] = t_ >> 8; d_[1] = t_ & 0xFF; }
void length(int32_t l_) {
d_[2] = static_cast<unsigned char>(l_ >> 24);
d_[3] = static_cast<unsigned char>(l_ >> 16);
d_[4] = static_cast<unsigned char>(l_ >> 8);
d_[5] = static_cast<unsigned char>(l_ & 0xFF);
}

const unsigned char* data() const { return d_; }

friend std::istream& operator>>(std::istream& is, MessageHeader& header) {
return is.read(reinterpret_cast<char*>(header.d_), 6);
}

private:
unsigned char d_[6];
};

class Message {
public:
MessageHeader header_;
std::string msg_;

Message(const MessageHeader& header, const std::string& msg) : header_(header), msg_(msg) {};
};

///////////////////////////////////////////////////////////////////////////////
// MumbleClient, private:

Expand All @@ -60,46 +87,46 @@ void MumbleClient::DoPing(const boost::system::error_code& error) {
}

void MumbleClient::ParseMessage(const MessageHeader& msg_header, void* buffer) {
switch (msg_header.type) {
switch (msg_header.type()) {
case PbMessageType::Version: {
MumbleProto::Version v = ConstructProtobufObject<MumbleProto::Version>(buffer, msg_header.length, true);
MumbleProto::Version v = ConstructProtobufObject<MumbleProto::Version>(buffer, msg_header.length(), true);
// NOT_IMPLEMENTED
break;
}
case PbMessageType::Ping: {
MumbleProto::Ping p = ConstructProtobufObject<MumbleProto::Ping>(buffer, msg_header.length, false);
MumbleProto::Ping p = ConstructProtobufObject<MumbleProto::Ping>(buffer, msg_header.length(), false);
// NOT_IMPLEMENTED
break;
}
case PbMessageType::ChannelRemove: {
MumbleProto::ChannelRemove cr = ConstructProtobufObject<MumbleProto::ChannelRemove>(buffer, msg_header.length, true);
MumbleProto::ChannelRemove cr = ConstructProtobufObject<MumbleProto::ChannelRemove>(buffer, msg_header.length(), true);
HandleChannelRemove(cr);
break;
}
case PbMessageType::ChannelState: {
MumbleProto::ChannelState cs = ConstructProtobufObject<MumbleProto::ChannelState>(buffer, msg_header.length, true);
MumbleProto::ChannelState cs = ConstructProtobufObject<MumbleProto::ChannelState>(buffer, msg_header.length(), true);
HandleChannelState(cs);
break;
}
case PbMessageType::UserRemove: {
MumbleProto::UserRemove ur =ConstructProtobufObject<MumbleProto::UserRemove>(buffer, msg_header.length, true);
MumbleProto::UserRemove ur =ConstructProtobufObject<MumbleProto::UserRemove>(buffer, msg_header.length(), true);
HandleUserRemove(ur);
break;
}
case PbMessageType::UserState: {
MumbleProto::UserState us = ConstructProtobufObject<MumbleProto::UserState>(buffer, msg_header.length, true);
MumbleProto::UserState us = ConstructProtobufObject<MumbleProto::UserState>(buffer, msg_header.length(), true);
HandleUserState(us);
break;
}
case PbMessageType::TextMessage: {
MumbleProto::TextMessage tm = ConstructProtobufObject<MumbleProto::TextMessage>(buffer, msg_header.length, true);
MumbleProto::TextMessage tm = ConstructProtobufObject<MumbleProto::TextMessage>(buffer, msg_header.length(), true);

if (text_message_callback_)
text_message_callback_(tm.message());
break;
}
case PbMessageType::CryptSetup: {
MumbleProto::CryptSetup cs = ConstructProtobufObject<MumbleProto::CryptSetup>(buffer, msg_header.length, true);
MumbleProto::CryptSetup cs = ConstructProtobufObject<MumbleProto::CryptSetup>(buffer, msg_header.length(), true);
if (cs.has_key() && cs.has_client_nonce() && cs.has_server_nonce()) {
cs_->setKey(reinterpret_cast<const unsigned char *>(cs.key().data()), reinterpret_cast<const unsigned char *>(cs.client_nonce().data()), reinterpret_cast<const unsigned char *>(cs.server_nonce().data()));
} else if (cs.has_server_nonce()) {
Expand All @@ -113,12 +140,12 @@ void MumbleClient::ParseMessage(const MessageHeader& msg_header, void* buffer) {
break;
}
case PbMessageType::CodecVersion: {
MumbleProto::CodecVersion cv = ConstructProtobufObject<MumbleProto::CodecVersion>(buffer, msg_header.length, true);
MumbleProto::CodecVersion cv = ConstructProtobufObject<MumbleProto::CodecVersion>(buffer, msg_header.length(), true);
// NOT_IMPLEMENTED
break;
}
case PbMessageType::ServerSync: {
MumbleProto::ServerSync ss = ConstructProtobufObject<MumbleProto::ServerSync>(buffer, msg_header.length, true);
MumbleProto::ServerSync ss = ConstructProtobufObject<MumbleProto::ServerSync>(buffer, msg_header.length(), true);
state_ = kStateAuthenticated;
session_ = ss.session();

Expand All @@ -131,11 +158,11 @@ void MumbleClient::ParseMessage(const MessageHeader& msg_header, void* buffer) {
}
case PbMessageType::UDPTunnel: {
if (raw_udp_tunnel_callback_)
raw_udp_tunnel_callback_(msg_header.length, buffer);
raw_udp_tunnel_callback_(msg_header.length(), buffer);
break;
}
default:
DLOG(WARNING) << ">> IN: Unhandled message - Type: " << msg_header.type << " Length: " << msg_header.length;
DLOG(WARNING) << ">> IN: Unhandled message - Type: " << msg_header.type() << " Length: " << msg_header.length();
}
}

Expand Down Expand Up @@ -282,11 +309,26 @@ void MumbleClient::SendFirstQueued() {
boost::shared_ptr<Message>& msg = send_queue_.front();

std::vector<boost::asio::const_buffer> bufs;
bufs.push_back(boost::asio::buffer(reinterpret_cast<char *>(&msg->header_), sizeof(msg->header_)));
bufs.push_back(boost::asio::buffer(msg->header_.data(), 6));
bufs.push_back(boost::asio::buffer(msg->msg_, msg->msg_.size()));

async_write(*tcp_socket_, bufs, boost::bind(&MumbleClient::ProcessTCPSendQueue, this, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred));
DLOG(INFO) << "<< ASYNC Type: " << ntohs(msg->header_.type) << " Length: 6+" << msg->msg_.size();
DLOG(INFO) << "<< ASYNC Type: " << msg->header_.type() << " Length: 6+" << msg->msg_.size();
}

void MumbleClient::HandleMessageContent(std::istream& is, const MessageHeader& msg_header) {
if (static_cast<int32_t>(recv_buffer_.size()) < msg_header.length()) {
// The message is incomplete, read the rest
if (tcp_socket_)
async_read(*tcp_socket_, recv_buffer_, boost::asio::transfer_at_least(msg_header.length() - recv_buffer_.size()), boost::bind(&MumbleClient::ReadHandlerContinue, this, msg_header, boost::asio::placeholders::error));
return;
}

// Receive message body
char* buffer = new char[msg_header.length()];
is.read(buffer, msg_header.length());
ParseMessage(msg_header, buffer);
delete[] buffer;
}

void MumbleClient::ReadHandler(const boost::system::error_code& error) {
Expand All @@ -295,41 +337,35 @@ void MumbleClient::ReadHandler(const boost::system::error_code& error) {
return;
}

// Receive message header
MessageHeader msg_header;
boost::system::error_code err;
read(*tcp_socket_, boost::asio::buffer(reinterpret_cast<char *>(&msg_header), 6), boost::asio::transfer_all(), err);
// FIXME(pcgod): This is not the correct solution... We should use async_read with a buffer
if (err) {
#if defined(_WIN32)
if (err.value() == WSAEWOULDBLOCK) {
#else
if (err.value() == EWOULDBLOCK) {
#endif
if (tcp_socket_)
tcp_socket_->async_read_some(boost::asio::null_buffers(), boost::bind(&MumbleClient::ReadHandler, this, boost::asio::placeholders::error));
} else {
LOG(ERROR) << "read error: " << err.message() << " " << err.value();
}
return;
}
std::istream is(&recv_buffer_);
do {
// Receive message header
MessageHeader msg_header;
is >> msg_header;

msg_header.type = ntohs(msg_header.type);
msg_header.length = ntohl(msg_header.length);
if (msg_header.length() >= 0x7FFFF)
return;

if (msg_header.length >= 0x7FFFF)
return;
HandleMessageContent(is, msg_header);
} while (recv_buffer_.size() >= 6);

// Receive message body
char* buffer = static_cast<char *>(malloc(msg_header.length));
read(*tcp_socket_, boost::asio::buffer(buffer, msg_header.length));
// Requeue read
if (tcp_socket_)
async_read(*tcp_socket_, recv_buffer_, boost::asio::transfer_at_least(6), boost::bind(&MumbleClient::ReadHandler, this, boost::asio::placeholders::error));
}

ParseMessage(msg_header, buffer);
free(buffer);
void MumbleClient::ReadHandlerContinue(const MessageHeader msg_header, const boost::system::error_code& error) {
if (error) {
LOG(ERROR) << "read error: " << error.message();
return;
}

std::istream is(&recv_buffer_);
HandleMessageContent(is, msg_header);

// Requeue read
if (tcp_socket_)
tcp_socket_->async_read_some(boost::asio::null_buffers(), boost::bind(&MumbleClient::ReadHandler, this, boost::asio::placeholders::error));
async_read(*tcp_socket_, recv_buffer_, boost::asio::transfer_at_least(6), boost::bind(&MumbleClient::ReadHandler, this, boost::asio::placeholders::error));
}

///////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -413,7 +449,7 @@ void MumbleClient::Connect(const Settings& s) {
a.add_celt_versions(0x8000000b);
SendMessage(PbMessageType::Authenticate, a, true);

tcp_socket_->async_read_some(boost::asio::null_buffers(), boost::bind(&MumbleClient::ReadHandler, this, boost::asio::placeholders::error));
boost::asio::async_read(*tcp_socket_, recv_buffer_, boost::asio::transfer_at_least(6), boost::bind(&MumbleClient::ReadHandler, this, boost::asio::placeholders::error));
}

void MumbleClient::Disconnect() {
Expand Down Expand Up @@ -445,8 +481,8 @@ void MumbleClient::SendMessage(PbMessageType::MessageType type, const ::google::
bool write_in_progress = !send_queue_.empty();
int32_t length = new_msg.ByteSize();
MessageHeader msg_header;
msg_header.type = htons(static_cast<int16_t>(type));
msg_header.length = htonl(length);
msg_header.type(static_cast<int16_t>(type));
msg_header.length(length);

std::string pb_message = new_msg.SerializeAsString();
boost::shared_ptr<Message> m = boost::make_shared<Message>(msg_header, pb_message);
Expand All @@ -460,8 +496,8 @@ void MumbleClient::SendMessage(PbMessageType::MessageType type, const ::google::
void MumbleClient::SendRawUdpTunnel(const char* buffer, int32_t len) {
bool write_in_progress = !send_queue_.empty();
MessageHeader msg_header;
msg_header.type = htons(static_cast<int16_t>(PbMessageType::UDPTunnel));
msg_header.length = htonl(len);
msg_header.type(PbMessageType::UDPTunnel);
msg_header.length(len);

std::string data(buffer, len);
boost::shared_ptr<Message> m = boost::make_shared<Message>(msg_header, data);
Expand Down
28 changes: 7 additions & 21 deletions client.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,11 @@ using boost::asio::ssl::stream;

class Channel;
class CryptState;
class Message;
class MessageHeader;
class Settings;
class User;

namespace mumble_message {

#pragma pack(push)
#pragma pack(1)
struct MessageHeader {
int16_t type;
int32_t length;
};
#pragma pack(pop)

struct Message {
MessageHeader header_;
std::string msg_;

Message(const MessageHeader& header, const std::string& msg) : header_(header), msg_(msg) {};
};

} // namespace mumble_message

typedef std::list< boost::shared_ptr<User> >::iterator user_list_iterator;
typedef std::list< boost::shared_ptr<Channel> >::iterator channel_list_iterator;

Expand Down Expand Up @@ -91,10 +74,12 @@ class DLL_PUBLIC MumbleClient {
DLL_LOCAL MumbleClient(boost::asio::io_service* io_service);

DLL_LOCAL void DoPing(const boost::system::error_code& error);
DLL_LOCAL void ParseMessage(const mumble_message::MessageHeader& msg_header, void* buffer);
DLL_LOCAL void ParseMessage(const MessageHeader& msg_header, void* buffer);
DLL_LOCAL void ProcessTCPSendQueue(const boost::system::error_code& error, const size_t bytes_transferred);
DLL_LOCAL void SendFirstQueued();
DLL_LOCAL void HandleMessageContent(std::istream& is, const MessageHeader& msg_header);
DLL_LOCAL void ReadHandler(const boost::system::error_code& error);
DLL_LOCAL void ReadHandlerContinue(const MessageHeader msg_header, const boost::system::error_code& error);
DLL_LOCAL void HandleUserRemove(const MumbleProto::UserRemove& ur);
DLL_LOCAL void HandleUserState(const MumbleProto::UserState& us);
DLL_LOCAL void HandleChannelState(const MumbleProto::ChannelState& cs);
Expand All @@ -109,8 +94,9 @@ class DLL_PUBLIC MumbleClient {
tcp::socket* tcp_socket_;
#endif
udp::socket* udp_socket_;
boost::asio::streambuf recv_buffer_;
CryptState* cs_;
std::deque< boost::shared_ptr<mumble_message::Message> > send_queue_;
std::deque< boost::shared_ptr<Message> > send_queue_;
State state_;
boost::asio::deadline_timer* ping_timer_;
int32_t session_;
Expand Down

0 comments on commit cce131f

Please sign in to comment.