Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rollback InvokeRequestMessage when AddResponseData fails #33849

Merged
36 changes: 35 additions & 1 deletion src/app/CommandSender.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,13 +540,18 @@ CHIP_ERROR CommandSender::FinishCommand(FinishCommandParameters & aFinishCommand
CHIP_ERROR CommandSender::AddRequestData(const CommandPathParams & aCommandPath, const DataModel::EncodableToTLV & aEncodable,
AddRequestDataParameters & aAddRequestDataParams)
{
ReturnErrorOnFailure(AllocateBuffer());

RollbackInvokeRequest rollback(*this);
PrepareCommandParameters prepareCommandParams(aAddRequestDataParams);
ReturnErrorOnFailure(PrepareCommand(aCommandPath, prepareCommandParams));
TLV::TLVWriter * writer = GetCommandDataIBTLVWriter();
VerifyOrReturnError(writer != nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorOnFailure(aEncodable.EncodeTo(*writer, TLV::ContextTag(CommandDataIB::Tag::kFields)));
FinishCommandParameters finishCommandParams(aAddRequestDataParams);
return FinishCommand(finishCommandParams);
ReturnErrorOnFailure(FinishCommand(finishCommandParams));
rollback.DisableAutomaticRollback();
return CHIP_NO_ERROR;
}

CHIP_ERROR CommandSender::FinishCommandInternal(FinishCommandParameters & aFinishCommandParams)
Expand Down Expand Up @@ -657,5 +662,34 @@ void CommandSender::MoveToState(const State aTargetState)
ChipLogDetail(DataManagement, "ICR moving to [%10.10s]", GetStateStr());
}

CommandSender::RollbackInvokeRequest::RollbackInvokeRequest(CommandSender & aCommandSender) : mCommandSender(aCommandSender)
{
VerifyOrReturn(mCommandSender.mBufferAllocated);
VerifyOrReturn(mCommandSender.mState == State::Idle || mCommandSender.mState == State::AddedCommand);
VerifyOrReturn(mCommandSender.mInvokeRequestBuilder.GetInvokeRequests().GetError() == CHIP_NO_ERROR);
VerifyOrReturn(mCommandSender.mInvokeRequestBuilder.GetError() == CHIP_NO_ERROR);
mCommandSender.mInvokeRequestBuilder.Checkpoint(mBackupWriter);
mBackupState = mCommandSender.mState;
mRollbackInDestructor = true;
}

CommandSender::RollbackInvokeRequest::~RollbackInvokeRequest()
{
VerifyOrReturn(mRollbackInDestructor);
VerifyOrReturn(mCommandSender.mState == State::AddingCommand);
ChipLogDetail(DataManagement, "Rolling back response");
// TODO(#30453): Rollback of mInvokeRequestBuilder should handle resetting
tehampson marked this conversation as resolved.
Show resolved Hide resolved
// InvokeRequest.
tehampson marked this conversation as resolved.
Show resolved Hide resolved
mCommandSender.mInvokeRequestBuilder.GetInvokeRequests().ResetError();
mCommandSender.mInvokeRequestBuilder.Rollback(mBackupWriter);
mCommandSender.MoveToState(mBackupState);
mRollbackInDestructor = false;
}

void CommandSender::RollbackInvokeRequest::DisableAutomaticRollback()
{
mRollbackInDestructor = false;
}

} // namespace app
} // namespace chip
34 changes: 34 additions & 0 deletions src/app/CommandSender.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ class CommandSender final : public Messaging::ExchangeDelegate

AddRequestDataParameters(const Optional<uint16_t> & aTimedInvokeTimeoutMs) : timedInvokeTimeoutMs(aTimedInvokeTimeoutMs) {}

AddRequestDataParameters & SetCommandRef(uint16_t aCommandRef)
{
commandRef.SetValue(aCommandRef);
return *this;
}

// When a value is provided for timedInvokeTimeoutMs, this invoke becomes a timed
// invoke. CommandSender will use the minimum of all provided timeouts for execution.
const Optional<uint16_t> timedInvokeTimeoutMs;
Expand Down Expand Up @@ -511,6 +517,34 @@ class CommandSender final : public Messaging::ExchangeDelegate
AwaitingDestruction, ///< The object has completed its work and is awaiting destruction by the application.
};

/**
* Class to help backup CommandSender's buffer containing InvokeRequestMessage when adding InvokeRequest
* in case there is a failure to add InvokeRequest. Intended usage is as follows:
* - Allocate RollbackInvokeRequest on the stack.
* - Attempt adding InvokeRequest into InvokeRequestMessage buffer.
* - If modification is added successfully, call DisableAutomaticRollback() to prevent destructor from
* rolling back InvokeReqestMessage.
* - If there is an issue adding InvokeRequest, destructor will take care of rolling back
* InvokeRequestMessage to previously saved state.
*/
class RollbackInvokeRequest
{
public:
explicit RollbackInvokeRequest(CommandSender & aCommandSender);
~RollbackInvokeRequest();

/**
* Disables rolling back to previously saved state for InvokeRequestMessage.
*/
void DisableAutomaticRollback();

private:
CommandSender & mCommandSender;
TLV::TLVWriter mBackupWriter;
State mBackupState;
bool mRollbackInDestructor = false;
};

union CallbackHandle
{
CallbackHandle(Callback * apCallback) : legacyCallback(apCallback) {}
Expand Down
64 changes: 56 additions & 8 deletions src/app/tests/TestCommandInteraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ enum class ForcedSizeBufferLengthHint
kSizeGreaterThan255,
};

struct ForcedSizeBuffer
class ForcedSizeBuffer : public app::DataModel::EncodableToTLV
tehampson marked this conversation as resolved.
Show resolved Hide resolved
{
chip::Platform::ScopedMemoryBufferWithSize<uint8_t> mBuffer;

public:
ForcedSizeBuffer(uint32_t size)
{
if (mBuffer.Alloc(size))
Expand All @@ -124,7 +123,7 @@ struct ForcedSizeBuffer

// No significance with using 0x12 as the CommandId, just using a value.
static constexpr chip::CommandId GetCommandId() { return 0x12; }
CHIP_ERROR Encode(TLV::TLVWriter & aWriter, TLV::Tag aTag) const
CHIP_ERROR EncodeTo(TLV::TLVWriter & aWriter, TLV::Tag aTag) const override
{
VerifyOrReturnError(mBuffer, CHIP_ERROR_NO_MEMORY);

Expand All @@ -133,6 +132,9 @@ struct ForcedSizeBuffer
ReturnErrorOnFailure(app::DataModel::Encode(aWriter, TLV::ContextTag(1), ByteSpan(mBuffer.Get(), mBuffer.AllocatedSize())));
return aWriter.EndContainer(outerContainerType);
}

private:
chip::Platform::ScopedMemoryBufferWithSize<uint8_t> mBuffer;
};

struct Fields
Expand Down Expand Up @@ -385,6 +387,7 @@ class TestCommandInteraction : public ::testing::Test
void TestCommandSender_WithProcessReceivedMsg();
void TestCommandSender_ExtendableApiWithProcessReceivedMsg();
void TestCommandSender_ExtendableApiWithProcessReceivedMsgContainingInvalidCommandRef();
void TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked();
void TestCommandHandler_WithoutResponderCallingAddStatus();
void TestCommandHandler_WithoutResponderCallingAddResponse();
void TestCommandHandler_WithoutResponderCallingDirectPrepareFinishCommandApis();
Expand Down Expand Up @@ -630,7 +633,8 @@ uint32_t TestCommandInteraction::GetAddResponseDataOverheadSizeForPath(const Con
// When ForcedSizeBuffer exceeds 255, an extra byte is needed for length, affecting the overhead size required by
// AddResponseData. In order to have this accounted for in overhead calculation we set the length to be 256.
uint32_t sizeOfForcedSizeBuffer = aBufferSizeHint == ForcedSizeBufferLengthHint::kSizeGreaterThan255 ? 256 : 0;
EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeOfForcedSizeBuffer)), CHIP_NO_ERROR);
ForcedSizeBuffer responseData(sizeOfForcedSizeBuffer);
EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
uint32_t remainingSizeAfter = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength();
uint32_t delta = remainingSizeBefore - remainingSizeAfter - sizeOfForcedSizeBuffer;

Expand All @@ -655,7 +659,8 @@ void TestCommandInteraction::FillCurrentInvokeResponseBuffer(CommandHandlerImpl
// Validating assumption. If this fails, it means overheadSizeNeededForAddingResponse is likely too large.
EXPECT_GE(sizeToFill, 256u);

EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
ForcedSizeBuffer responseData(sizeToFill);
EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
}

void TestCommandInteraction::ValidateCommandHandlerEncodeInvokeResponseMessage(bool aNeedStatusCode)
Expand Down Expand Up @@ -1085,6 +1090,47 @@ TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandSender_ExtendableApiWithP
EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0);
}

TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked)
{
mockCommandSenderExtendedDelegate.ResetCounter();
PendingResponseTrackerImpl pendingResponseTracker;
app::CommandSender commandSender(kCommandSenderTestOnlyMarker, &mockCommandSenderExtendedDelegate,
&mpTestContext->GetExchangeManager(), &pendingResponseTracker);

app::CommandSender::AddRequestDataParameters addRequestDataParams;

CommandSender::ConfigParameters config;
config.SetRemoteMaxPathsPerInvoke(2);
EXPECT_EQ(commandSender.SetCommandSenderConfig(config), CHIP_NO_ERROR);

// The specific values chosen here are arbitrary.
uint16_t firstCommandRef = 1;
uint16_t secondCommandRef = 2;
auto commandPathParams = MakeTestCommandPath();
SimpleTLVPayload simplePayloadWriter;
addRequestDataParams.SetCommandRef(firstCommandRef);

EXPECT_EQ(commandSender.AddRequestData(commandPathParams, simplePayloadWriter, addRequestDataParams), CHIP_NO_ERROR);

uint32_t remainingSize = commandSender.mInvokeRequestBuilder.GetWriter()->GetRemainingFreeLength();
// Because request is made of both request data and request path (commandPathParams), using
// `remainingSize` is large enough fail.
ForcedSizeBuffer requestData(remainingSize);

addRequestDataParams.SetCommandRef(secondCommandRef);
EXPECT_EQ(commandSender.AddRequestData(commandPathParams, requestData, addRequestDataParams), CHIP_ERROR_NO_MEMORY);

// Confirm that we can still send out a request with the first command.
EXPECT_EQ(commandSender.SendCommandRequest(mpTestContext->GetSessionBobToAlice()), CHIP_NO_ERROR);
EXPECT_EQ(commandSender.GetInvokeResponseMessageCount(), 0u);

mpTestContext->DrainAndServiceIO();

EXPECT_EQ(mockCommandSenderExtendedDelegate.onResponseCalledTimes, 1);
EXPECT_EQ(mockCommandSenderExtendedDelegate.onFinalCalledTimes, 1);
EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0);
}

TEST_F(TestCommandInteraction, TestCommandHandlerEncodeSimpleCommandData)
{
// Send response which has simple command data and command path
Expand Down Expand Up @@ -1186,7 +1232,8 @@ TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandHandler_WithoutResponderC
CommandHandlerImpl commandHandler(&mockCommandHandlerDelegate);

uint32_t sizeToFill = 50; // This is an arbitrary number, we need to select a non-zero value.
EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
ForcedSizeBuffer responseData(sizeToFill);
EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);

// Since calling AddResponseData is supposed to be a no-operation when there is no responder, it is
// hard to validate. Best way is to check that we are still in an Idle state afterwards
Expand Down Expand Up @@ -1811,7 +1858,8 @@ TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandHandler_FillUpInvokeRespo
EXPECT_EQ(remainingSize, sizeToLeave);

uint32_t sizeToFill = 50;
EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
ForcedSizeBuffer responseData(sizeToFill);
EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);

remainingSize = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength();
EXPECT_GT(remainingSize, sizeToLeave);
Expand Down
Loading