Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rollback InvokeRequestMessage when AddResponseData fails #33849

Merged
51 changes: 45 additions & 6 deletions src/app/CommandSender.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,13 +540,27 @@ CHIP_ERROR CommandSender::FinishCommand(FinishCommandParameters & aFinishCommand
CHIP_ERROR CommandSender::AddRequestData(const CommandPathParams & aCommandPath, const DataModel::EncodableToTLV & aEncodable,
AddRequestDataParameters & aAddRequestDataParams)
{
ReturnErrorOnFailure(AllocateBuffer());

RollbackData rollbackData;
rollbackData.Checkpoint(*this);
PrepareCommandParameters prepareCommandParams(aAddRequestDataParams);
ReturnErrorOnFailure(PrepareCommand(aCommandPath, prepareCommandParams));
TLV::TLVWriter * writer = GetCommandDataIBTLVWriter();
VerifyOrReturnError(writer != nullptr, CHIP_ERROR_INCORRECT_STATE);
ReturnErrorOnFailure(aEncodable.EncodeTo(*writer, TLV::ContextTag(CommandDataIB::Tag::kFields)));
FinishCommandParameters finishCommandParams(aAddRequestDataParams);
return FinishCommand(finishCommandParams);
TLV::TLVWriter * writer = nullptr;
CHIP_ERROR err = CHIP_NO_ERROR;
SuccessOrExit(err = PrepareCommand(aCommandPath, prepareCommandParams));
writer = GetCommandDataIBTLVWriter();
VerifyOrExit(writer != nullptr, err = CHIP_ERROR_INCORRECT_STATE);
SuccessOrExit(err = aEncodable.EncodeTo(*writer, TLV::ContextTag(CommandDataIB::Tag::kFields)));
{
FinishCommandParameters finishCommandParams(aAddRequestDataParams);
SuccessOrExit(err = FinishCommand(finishCommandParams));
}
exit:
if (err != CHIP_NO_ERROR)
{
LogErrorOnFailure(rollbackData.Rollback(*this));
bzbarsky-apple marked this conversation as resolved.
Show resolved Hide resolved
}
return err;
}

CHIP_ERROR CommandSender::FinishCommandInternal(FinishCommandParameters & aFinishCommandParams)
Expand Down Expand Up @@ -657,5 +671,30 @@ void CommandSender::MoveToState(const State aTargetState)
ChipLogDetail(DataManagement, "ICR moving to [%10.10s]", GetStateStr());
}

void CommandSender::RollbackData::Checkpoint(CommandSender & aCommandSender)
{
VerifyOrReturn(aCommandSender.mBufferAllocated);
VerifyOrReturn(aCommandSender.mState == State::Idle || aCommandSender.mState == State::AddedCommand);
VerifyOrReturn(aCommandSender.mInvokeRequestBuilder.GetInvokeRequests().GetError() == CHIP_NO_ERROR);
VerifyOrReturn(aCommandSender.mInvokeRequestBuilder.GetError() == CHIP_NO_ERROR);
aCommandSender.mInvokeRequestBuilder.Checkpoint(mBackupWriter);
mBackupState = aCommandSender.mState;
mRollbackIsValid = true;
}

CHIP_ERROR CommandSender::RollbackData::Rollback(CommandSender & aCommandSender)
{
VerifyOrReturnError(mRollbackIsValid, CHIP_ERROR_INCORRECT_STATE);
VerifyOrReturnError(aCommandSender.mState == State::AddingCommand, CHIP_ERROR_INCORRECT_STATE);
ChipLogDetail(DataManagement, "Rolling back response");
// TODO(#30453): Rollback of mInvokeRequestBuilder should handle resetting
tehampson marked this conversation as resolved.
Show resolved Hide resolved
// InvokeResponses.
tehampson marked this conversation as resolved.
Show resolved Hide resolved
aCommandSender.mInvokeRequestBuilder.GetInvokeRequests().ResetError();
aCommandSender.mInvokeRequestBuilder.Rollback(mBackupWriter);
aCommandSender.MoveToState(mBackupState);
mRollbackIsValid = false;
return CHIP_NO_ERROR;
}

} // namespace app
} // namespace chip
34 changes: 34 additions & 0 deletions src/app/CommandSender.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ class CommandSender final : public Messaging::ExchangeDelegate

AddRequestDataParameters(const Optional<uint16_t> & aTimedInvokeTimeoutMs) : timedInvokeTimeoutMs(aTimedInvokeTimeoutMs) {}

AddRequestDataParameters & SetCommandRef(uint16_t aCommandRef)
{
commandRef.SetValue(aCommandRef);
return *this;
}

// When a value is provided for timedInvokeTimeoutMs, this invoke becomes a timed
// invoke. CommandSender will use the minimum of all provided timeouts for execution.
const Optional<uint16_t> timedInvokeTimeoutMs;
Expand Down Expand Up @@ -511,6 +517,34 @@ class CommandSender final : public Messaging::ExchangeDelegate
AwaitingDestruction, ///< The object has completed its work and is awaiting destruction by the application.
};

class RollbackData
{
public:
/**
* Creates a backup to enable rolling back CommandSender's buffer containing
* InvokeRequestMessage in case subsequent calls to add request fail.
*
* A successful backup will only be created if the InvokeRequestMessage is
* in a known good state.
*
* @param [in] aCommandSender reference to CommandSender.
*/
void Checkpoint(CommandSender & aCommandSender);
/**
* Rolls back CommandSender's buffer containing InvokeRequestMessage to a previously
* saved state. Must have previously called Checkpoint in a known good state.
*
* @param [in] aCommandSender reference to CommandSender.
*/
CHIP_ERROR Rollback(CommandSender & aCommandSender);
bool RollbackIsValid() { return mRollbackIsValid; }
tehampson marked this conversation as resolved.
Show resolved Hide resolved

private:
TLV::TLVWriter mBackupWriter;
State mBackupState;
bool mRollbackIsValid = false;
};

union CallbackHandle
{
CallbackHandle(Callback * apCallback) : legacyCallback(apCallback) {}
Expand Down
64 changes: 56 additions & 8 deletions src/app/tests/TestCommandInteraction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,9 @@ enum class ForcedSizeBufferLengthHint
kSizeGreaterThan255,
};

struct ForcedSizeBuffer
class ForcedSizeBuffer : public app::DataModel::EncodableToTLV
tehampson marked this conversation as resolved.
Show resolved Hide resolved
{
chip::Platform::ScopedMemoryBufferWithSize<uint8_t> mBuffer;

public:
ForcedSizeBuffer(uint32_t size)
{
if (mBuffer.Alloc(size))
Expand All @@ -124,7 +123,7 @@ struct ForcedSizeBuffer

// No significance with using 0x12 as the CommandId, just using a value.
static constexpr chip::CommandId GetCommandId() { return 0x12; }
CHIP_ERROR Encode(TLV::TLVWriter & aWriter, TLV::Tag aTag) const
CHIP_ERROR EncodeTo(TLV::TLVWriter & aWriter, TLV::Tag aTag) const override
{
VerifyOrReturnError(mBuffer, CHIP_ERROR_NO_MEMORY);

Expand All @@ -133,6 +132,9 @@ struct ForcedSizeBuffer
ReturnErrorOnFailure(app::DataModel::Encode(aWriter, TLV::ContextTag(1), ByteSpan(mBuffer.Get(), mBuffer.AllocatedSize())));
return aWriter.EndContainer(outerContainerType);
}

private:
chip::Platform::ScopedMemoryBufferWithSize<uint8_t> mBuffer;
};

struct Fields
Expand Down Expand Up @@ -385,6 +387,7 @@ class TestCommandInteraction : public ::testing::Test
void TestCommandSender_WithProcessReceivedMsg();
void TestCommandSender_ExtendableApiWithProcessReceivedMsg();
void TestCommandSender_ExtendableApiWithProcessReceivedMsgContainingInvalidCommandRef();
void TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked();
void TestCommandHandler_WithoutResponderCallingAddStatus();
void TestCommandHandler_WithoutResponderCallingAddResponse();
void TestCommandHandler_WithoutResponderCallingDirectPrepareFinishCommandApis();
Expand Down Expand Up @@ -630,7 +633,8 @@ uint32_t TestCommandInteraction::GetAddResponseDataOverheadSizeForPath(const Con
// When ForcedSizeBuffer exceeds 255, an extra byte is needed for length, affecting the overhead size required by
// AddResponseData. In order to have this accounted for in overhead calculation we set the length to be 256.
uint32_t sizeOfForcedSizeBuffer = aBufferSizeHint == ForcedSizeBufferLengthHint::kSizeGreaterThan255 ? 256 : 0;
EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeOfForcedSizeBuffer)), CHIP_NO_ERROR);
ForcedSizeBuffer responseData(sizeOfForcedSizeBuffer);
EXPECT_EQ(commandHandler.AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
uint32_t remainingSizeAfter = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength();
uint32_t delta = remainingSizeBefore - remainingSizeAfter - sizeOfForcedSizeBuffer;

Expand All @@ -655,7 +659,8 @@ void TestCommandInteraction::FillCurrentInvokeResponseBuffer(CommandHandlerImpl
// Validating assumption. If this fails, it means overheadSizeNeededForAddingResponse is likely too large.
EXPECT_GE(sizeToFill, 256u);

EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
ForcedSizeBuffer responseData(sizeToFill);
EXPECT_EQ(apCommandHandler->AddResponseData(aRequestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);
}

void TestCommandInteraction::ValidateCommandHandlerEncodeInvokeResponseMessage(bool aNeedStatusCode)
Expand Down Expand Up @@ -1085,6 +1090,47 @@ TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandSender_ExtendableApiWithP
EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0);
}

TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandSender_ValidateSecondLargeAddRequestDataRollbacked)
{
mockCommandSenderExtendedDelegate.ResetCounter();
PendingResponseTrackerImpl pendingResponseTracker;
app::CommandSender commandSender(kCommandSenderTestOnlyMarker, &mockCommandSenderExtendedDelegate,
&mpTestContext->GetExchangeManager(), &pendingResponseTracker);

app::CommandSender::AddRequestDataParameters addRequestDataParams;

CommandSender::ConfigParameters config;
config.SetRemoteMaxPathsPerInvoke(2);
EXPECT_EQ(commandSender.SetCommandSenderConfig(config), CHIP_NO_ERROR);

// The specific values chosen here are arbitrary.
uint16_t firstCommandRef = 1;
uint16_t secondCommandRef = 2;
auto commandPathParams = MakeTestCommandPath();
SimpleTLVPayload simplePayloadWriter;
addRequestDataParams.SetCommandRef(firstCommandRef);

EXPECT_EQ(commandSender.AddRequestData(commandPathParams, simplePayloadWriter, addRequestDataParams), CHIP_NO_ERROR);

uint32_t remainingSize = commandSender.mInvokeRequestBuilder.GetWriter()->GetRemainingFreeLength();
// Because request is made of both request data and request path (commandPathParams), using
// `remainingSize` is large enough fail.
ForcedSizeBuffer requestData(remainingSize);

addRequestDataParams.SetCommandRef(secondCommandRef);
EXPECT_EQ(commandSender.AddRequestData(commandPathParams, requestData, addRequestDataParams), CHIP_ERROR_NO_MEMORY);

// Confirm that we can still send out a request with the first command.
EXPECT_EQ(commandSender.SendCommandRequest(mpTestContext->GetSessionBobToAlice()), CHIP_NO_ERROR);
EXPECT_EQ(commandSender.GetInvokeResponseMessageCount(), 0u);

mpTestContext->DrainAndServiceIO();

EXPECT_EQ(mockCommandSenderExtendedDelegate.onResponseCalledTimes, 1);
EXPECT_EQ(mockCommandSenderExtendedDelegate.onFinalCalledTimes, 1);
EXPECT_EQ(mockCommandSenderExtendedDelegate.onErrorCalledTimes, 0);
}

TEST_F(TestCommandInteraction, TestCommandHandlerEncodeSimpleCommandData)
{
// Send response which has simple command data and command path
Expand Down Expand Up @@ -1186,7 +1232,8 @@ TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandHandler_WithoutResponderC
CommandHandlerImpl commandHandler(&mockCommandHandlerDelegate);

uint32_t sizeToFill = 50; // This is an arbitrary number, we need to select a non-zero value.
EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
ForcedSizeBuffer responseData(sizeToFill);
EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);

// Since calling AddResponseData is supposed to be a no-operation when there is no responder, it is
// hard to validate. Best way is to check that we are still in an Idle state afterwards
Expand Down Expand Up @@ -1811,7 +1858,8 @@ TEST_F_FROM_FIXTURE(TestCommandInteraction, TestCommandHandler_FillUpInvokeRespo
EXPECT_EQ(remainingSize, sizeToLeave);

uint32_t sizeToFill = 50;
EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, ForcedSizeBuffer(sizeToFill)), CHIP_NO_ERROR);
ForcedSizeBuffer responseData(sizeToFill);
EXPECT_EQ(commandHandler.AddResponseData(requestCommandPath2, responseData.GetCommandId(), responseData), CHIP_NO_ERROR);

remainingSize = commandHandler.mInvokeResponseBuilder.GetWriter()->GetRemainingFreeLength();
EXPECT_GT(remainingSize, sizeToLeave);
Expand Down
Loading