Skip to content

Commit

Permalink
Added support for sending streams directly from TcpClient. (#2341)
Browse files Browse the repository at this point in the history
  • Loading branch information
slaff authored Jun 29, 2021
1 parent fbf8caa commit 5da5592
Show file tree
Hide file tree
Showing 10 changed files with 180 additions and 55 deletions.
16 changes: 4 additions & 12 deletions Sming/Components/Network/src/Network/MqttClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

#include "MqttClient.h"

#include "Data/Stream/MemoryDataStream.h"
#include "Data/Stream/StreamChain.h"
#include "Data/Stream/DataSourceStream.h"

const mqtt_parser_callbacks_t MqttClient::callbacks PROGMEM = {
.on_message_begin = staticOnMessageBegin,
Expand Down Expand Up @@ -387,7 +386,7 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent)
}

if(outgoingMessage->common.type == MQTT_TYPE_PUBLISH && payloadStream != nullptr) {
// The packetLength should be big enought for the header ONLY.
// The packetLength should be big enough for the header ONLY.
// Payload will be attached as a second stream
packetLength -= outgoingMessage->publish.content.length;
outgoingMessage->publish.content.data = nullptr;
Expand All @@ -396,16 +395,9 @@ void MqttClient::onReadyToSendData(TcpConnectionEvent sourceEvent)
uint8_t packet[packetLength];
mqtt_serialiser_write(&serialiser, outgoingMessage, packet, packetLength);

delete stream;
auto headerStream = new MemoryDataStream();
headerStream->write(packet, packetLength);
send(reinterpret_cast<const char*>(packet), packetLength);
if(payloadStream != nullptr) {
auto streamChain = new StreamChain();
streamChain->attachStream(headerStream);
streamChain->attachStream(payloadStream);
stream = streamChain;
} else {
stream = headerStream;
send(payloadStream);
}

state = eMCS_SendingData;
Expand Down
19 changes: 10 additions & 9 deletions Sming/Components/Network/src/Network/MqttClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ class MqttClient : protected TcpClient

#ifndef MQTT_NO_COMPAT
/**
* @todo deprecate: Use setWill(const String& topic, const String& message,uint8_t flags) instead
* @deprecated: Use setWill(const String& topic, const String& message,uint8_t flags) instead
*/
bool setWill(const String& topic, const String& message, int QoS, bool retained = false)
bool setWill(const String& topic, const String& message, int QoS, bool retained = false) SMING_DEPRECATED
{
uint8_t flags = (uint8_t)(retained + (QoS << 1));
return setWill(topic, message, flags);
Expand All @@ -197,12 +197,12 @@ class MqttClient : protected TcpClient
*/

/**
* @todo deprecate: Use publish(const String& topic, const String& message, uint8_t flags = 0) instead.
* @deprecated: Use publish(const String& topic, const String& message, uint8_t flags = 0) instead.
* If you want to have a callback that should be triggered on successful delivery of messages
* then use setEventHandler(MQTT_TYPE_PUBACK, youCallback) instead.
*/
bool publishWithQoS(const String& topic, const String& message, int QoS, bool retained = false,
MqttMessageDeliveredCallback onDelivery = nullptr)
MqttMessageDeliveredCallback onDelivery = nullptr) SMING_DEPRECATED
{
if(onDelivery) {
if(QoS == 1) {
Expand All @@ -220,10 +220,11 @@ class MqttClient : protected TcpClient
return publish(topic, message, flags);
}

/** @brief Provide a function to be called when a message is received from the broker
* @todo deprecate: Use setEventHandler(MQTT_TYPE_PUBLISH, MqttDelegate handler) instead.
/**
* @brief Provide a function to be called when a message is received from the broker
* @deprecated: Use setEventHandler(MQTT_TYPE_PUBLISH, MqttDelegate handler) instead.
*/
void setCallback(MqttStringSubscriptionCallback subscriptionCallback = nullptr)
void setCallback(MqttStringSubscriptionCallback subscriptionCallback = nullptr) SMING_DEPRECATED
{
this->subscriptionCallback = subscriptionCallback;
setEventHandler(MQTT_TYPE_PUBLISH, onPublish);
Expand Down Expand Up @@ -331,8 +332,8 @@ class MqttClient : protected TcpClient
*/

#ifndef MQTT_NO_COMPAT
MqttMessageDeliveredCallback onDelivery = nullptr; ///< @deprecated
MqttStringSubscriptionCallback subscriptionCallback = nullptr; ///< @deprecated
SMING_DEPRECATED MqttMessageDeliveredCallback onDelivery = nullptr; ///< @deprecated
SMING_DEPRECATED MqttStringSubscriptionCallback subscriptionCallback = nullptr; ///< @deprecated
#endif
};

Expand Down
65 changes: 43 additions & 22 deletions Sming/Components/Network/src/Network/TcpClient.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,14 @@

#include "TcpClient.h"
#include "Data/Stream/MemoryDataStream.h"
#include "Data/Stream/StreamChain.h"

void TcpClient::freeStreams()
{
if(buffer != nullptr) {
if(buffer != stream) {
debug_e("TcpClient: buffer doesn't match stream");
delete buffer;
}
buffer = nullptr;
}

delete stream;
stream = nullptr;
}

void TcpClient::setBuffer(ReadWriteStream* stream)
{
freeStreams();
buffer = stream;
this->stream = buffer;
}

bool TcpClient::connect(const String& server, int port, bool useSsl)
{
if(isProcessing()) {
Expand All @@ -58,21 +44,56 @@ bool TcpClient::send(const char* data, uint16_t len, bool forceCloseAfterSent)
return false;
}

if(buffer == nullptr) {
setBuffer(new MemoryDataStream());
if(buffer == nullptr) {
return false;
auto memoryStream = static_cast<MemoryDataStream*>(stream);
if(memoryStream == nullptr || memoryStream->getStreamType() != eSST_MemoryWritable) {
memoryStream = new MemoryDataStream();
if(stream == nullptr) {
stream = memoryStream;
}
}

if(buffer->write((const uint8_t*)data, len) != len) {
if(memoryStream->write(data, len) != len) {
debug_e("TcpClient::send ERROR: Unable to store %d bytes in buffer", len);
return false;
}

debug_d("Storing %d bytes in stream", len);
return send(memoryStream, forceCloseAfterSent);
}

bool TcpClient::send(IDataSourceStream* source, bool forceCloseAfterSent)
{
if(state != eTCS_Connecting && state != eTCS_Connected) {
return false;
}

if(source == nullptr) {
return false;
}

if(stream == nullptr) {
stream = source;
}
else if(stream != source){
auto chainStream = static_cast<StreamChain*>(stream);
if(chainStream != nullptr && chainStream->getStreamType() == eSST_Chain) {
chainStream->attachStream(source);
}
else {
debug_d("Creating stream chain ...");
chainStream = new StreamChain();
chainStream->attachStream(stream);
chainStream->attachStream(source);
stream = chainStream;
}
}

int length = source->available();
if(length > 0) {
totalSentBytes += length;
}

debug_d("Sending stream. Bytes to send: %d", length);

totalSentBytes += len;
closeAfterSent = forceCloseAfterSent ? eTCCASS_AfterSent : eTCCASS_None;

return true;
Expand Down
5 changes: 2 additions & 3 deletions Sming/Components/Network/src/Network/TcpClient.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ class TcpClient : public TcpConnection
return send(data.c_str(), data.length(), forceCloseAfterSent);
}

bool send(IDataSourceStream* source, bool forceCloseAfterSent = false);

bool isProcessing()
{
return state == eTCS_Connected || state == eTCS_Connecting;
Expand Down Expand Up @@ -151,9 +153,6 @@ class TcpClient : public TcpConnection
void freeStreams();

protected:
void setBuffer(ReadWriteStream* stream);

ReadWriteStream* buffer = nullptr; ///< Used internally to buffer arbitrary data via send() methods
IDataSourceStream* stream = nullptr; ///< The currently active stream being sent

private:
Expand Down
1 change: 0 additions & 1 deletion Sming/Components/Network/src/Network/TcpServer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,6 @@ void TcpServer::onClientComplete(TcpClient& client, bool successful)
bool TcpServer::onClientReceive(TcpClient& client, char* data, int size)
{
debug_d("TcpSever onReceive: %s, %d bytes\r\n", client.getRemoteIp().toString().c_str(), size);
debug_d("Data: %s", data);
if(clientReceiveDelegate) {
return clientReceiveDelegate(client, data, size);
}
Expand Down
17 changes: 10 additions & 7 deletions Sming/Core/Data/Stream/DataSourceStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,16 @@
* @ingroup constants
*/
enum StreamType {
eSST_Invalid, ///< Stream content not valid
eSST_Memory, ///< Memory data stream
eSST_File, ///< File data stream
eSST_Template, ///< Template data stream
eSST_JsonObject, ///< JSON object data stream
eSST_User, ///< User defined data stream
eSST_Unknown ///< Unknown data stream type
eSST_Invalid, ///< Stream content not valid
eSST_Memory, ///< Memory stream
eSST_MemoryWritable, /// < Memory stream where data can be safely written to.
// Expands on demand and does not transform the data.
eSST_File, ///< File data stream
eSST_Template, ///< Template data stream
eSST_JsonObject, ///< JSON object data stream
eSST_User, ///< User defined data stream
eSST_Chain, ///< A stream (chain) containing multiple streams
eSST_Unknown ///< Unknown data stream type
};

/**
Expand Down
2 changes: 1 addition & 1 deletion Sming/Core/Data/Stream/MemoryDataStream.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class MemoryDataStream : public ReadWriteStream

StreamType getStreamType() const override
{
return eSST_Memory;
return eSST_MemoryWritable;
}

/** @brief Get a pointer to the current position
Expand Down
5 changes: 5 additions & 0 deletions Sming/Core/Data/Stream/StreamChain.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class StreamChain : public MultiStream
return queue.enqueue(stream);
}

StreamType getStreamType() const override
{
return eSST_Chain;
}

protected:
IDataSourceStream* getNextStream() override
{
Expand Down
1 change: 1 addition & 0 deletions tests/HostTests/include/modules.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@
XX(Clocks) \
XX(Timers) \
XX(HttpRequest) \
XX(TcpClient) \
XX(Hosted)
104 changes: 104 additions & 0 deletions tests/HostTests/modules/TcpClient.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
#include <HostTests.h>

#include <Network/TcpClient.h>
#include <Network/TcpServer.h>
#include <Data/Stream/MemoryDataStream.h>
#include <Platform/Station.h>

class TcpClientTest : public TestGroup
{
public:
TcpClientTest() : TestGroup(_F("TcpClient"))
{
}

void execute() override
{
if(!WifiStation.isConnected()) {
Serial.println("No network, skipping tests");
return;
}

constexpr int port = 9876;
String inputData = "This is very long and complex text that will be sent using multiple complicated streams.";

// Tcp Server
server = new TcpServer(
[this](TcpClient& client, char* data, int size) -> bool {
// on data
return receivedData.concat(data, size);
},
[this, inputData](TcpClient& client, bool successful) {
// on client close
if(finished) {
return;
}
REQUIRE(successful == true);
REQUIRE(receivedData == inputData);
finished = true;
shutdown();
});
server->listen(port);
server->setTimeOut(USHRT_MAX); // disable connection timeout
server->setKeepAlive(USHRT_MAX); // disable connection timeout

// Tcp Client
bool connected = client.connect(WifiStation.getIP(), port);
debug_d("Connected: %d", connected);

TEST_CASE("TcpClient::send stream")
{
size_t offset = 0;

// Send text using bytes
client.send(inputData.c_str(), 5);
offset += 5;

// send data using more bytes
client.send(inputData.c_str() + offset, 7);
offset += 7;

// send data as stream
auto stream1 = new MemoryDataStream();
stream1->write(inputData.c_str() + offset, 3);
client.send(stream1);
offset += 3;
client.commit();

// more stream
auto stream2 = new LimitedMemoryStream(4);
stream2->write(reinterpret_cast<const uint8_t*>(inputData.c_str()) + offset, 4);
client.send(stream2);
offset += 4;

// and finally the rest of the bytes
String rest = inputData.substring(offset);
client.send(rest.c_str(), rest.length());
client.setTimeOut(1);

pending();
}
}

void shutdown()
{
server->shutdown();
server = nullptr;
timer.initializeMs<1000>([this]() { complete(); });
timer.startOnce();
}

private:
String receivedData;
TcpServer* server{nullptr};
TcpClient client{false};
Timer timer;
volatile bool finished = false;
};

void REGISTER_TEST(TcpClient)
{
#ifdef ARCH_HOST
registerGroup<TcpClientTest>();
#endif
}

0 comments on commit 5da5592

Please sign in to comment.