Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ThreadOperationalDataset: various bug fixes #34331

Merged
merged 2 commits into from
Jul 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 65 additions & 171 deletions src/lib/support/ThreadOperationalDataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
* limitations under the License.
*/

#include <assert.h>
#include <string.h>

#include <lib/support/ThreadOperationalDataset.h>

#include <lib/core/CHIPEncoding.h>

#include <cassert>
#include <cstring>

namespace chip {
namespace Thread {

Expand Down Expand Up @@ -70,7 +72,7 @@ class ThreadTLV final
mLength = aLength;
}

const void * GetValue() const
const uint8_t * GetValue() const
{
assert(mLength != kLengthEscape);

Expand All @@ -79,75 +81,44 @@ class ThreadTLV final
return reinterpret_cast<const uint8_t *>(this) + sizeof(*this);
}

void * GetValue() { return const_cast<void *>(const_cast<const ThreadTLV *>(this)->GetValue()); }
uint8_t * GetValue() { return const_cast<uint8_t *>(const_cast<const ThreadTLV *>(this)->GetValue()); }

ByteSpan GetValueAsSpan() const { return ByteSpan(static_cast<const uint8_t *>(GetValue()), GetLength()); }

void Get64(uint64_t & aValue) const
{
assert(GetLength() >= sizeof(aValue));

const uint8_t * p = reinterpret_cast<const uint8_t *>(GetValue());
aValue = //
(static_cast<uint64_t>(p[0]) << 56) | //
(static_cast<uint64_t>(p[1]) << 48) | //
(static_cast<uint64_t>(p[2]) << 40) | //
(static_cast<uint64_t>(p[3]) << 32) | //
(static_cast<uint64_t>(p[4]) << 24) | //
(static_cast<uint64_t>(p[5]) << 16) | //
(static_cast<uint64_t>(p[6]) << 8) | //
(static_cast<uint64_t>(p[7]));
aValue = Encoding::BigEndian::Get64(GetValue());
}

void Get16(uint16_t & aValue) const
void Get32(uint32_t & aValue) const
{
assert(GetLength() >= sizeof(aValue));

const uint8_t * p = static_cast<const uint8_t *>(GetValue());

aValue = static_cast<uint16_t>(p[0] << 8 | p[1]);
aValue = Encoding::BigEndian::Get32(GetValue());
}

void Get8(uint8_t & aValue) const
void Get16(uint16_t & aValue) const
{
assert(GetLength() >= sizeof(aValue));
aValue = *static_cast<const uint8_t *>(GetValue());
aValue = Encoding::BigEndian::Get16(GetValue());
}

void Set64(uint64_t aValue)
{
uint8_t * value = static_cast<uint8_t *>(GetValue());

SetLength(sizeof(aValue));

value[0] = static_cast<uint8_t>((aValue >> 56) & 0xff);
value[1] = static_cast<uint8_t>((aValue >> 48) & 0xff);
value[2] = static_cast<uint8_t>((aValue >> 40) & 0xff);
value[3] = static_cast<uint8_t>((aValue >> 32) & 0xff);
value[4] = static_cast<uint8_t>((aValue >> 24) & 0xff);
value[5] = static_cast<uint8_t>((aValue >> 16) & 0xff);
value[6] = static_cast<uint8_t>((aValue >> 8) & 0xff);
value[7] = static_cast<uint8_t>(aValue & 0xff);
Encoding::BigEndian::Put64(GetValue(), aValue);
}

void Set16(uint16_t aValue)
void Set32(uint32_t aValue)
{
uint8_t * value = static_cast<uint8_t *>(GetValue());

SetLength(sizeof(aValue));

value[0] = static_cast<uint8_t>(aValue >> 8);
value[1] = static_cast<uint8_t>(aValue & 0xff);
}

void Set8(uint8_t aValue)
{
SetLength(sizeof(aValue));
*static_cast<uint8_t *>(GetValue()) = aValue;
Encoding::BigEndian::Put32(GetValue(), aValue);
}

void Set8(int8_t aValue)
void Set16(uint16_t aValue)
{
SetLength(sizeof(aValue));
*static_cast<int8_t *>(GetValue()) = aValue;
Encoding::BigEndian::Put16(GetValue(), aValue);
}

void SetValue(const void * aValue, uint8_t aLength)
Expand Down Expand Up @@ -218,24 +189,16 @@ CHIP_ERROR OperationalDataset::Init(ByteSpan aData)
CHIP_ERROR OperationalDataset::GetActiveTimestamp(uint64_t & aActiveTimestamp) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kActiveTimestamp);

if (tlv != nullptr)
{
tlv->Get64(aActiveTimestamp);
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aActiveTimestamp), CHIP_ERROR_INVALID_TLV_ELEMENT);
tlv->Get64(aActiveTimestamp);
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetActiveTimestamp(uint64_t aActiveTimestamp)
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kActiveTimestamp, sizeof(*tlv) + sizeof(aActiveTimestamp));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->Set64(aActiveTimestamp);

Expand All @@ -247,26 +210,19 @@ CHIP_ERROR OperationalDataset::SetActiveTimestamp(uint64_t aActiveTimestamp)
CHIP_ERROR OperationalDataset::GetChannel(uint16_t & aChannel) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kChannel);

if (tlv != nullptr)
{
const uint8_t * value = reinterpret_cast<const uint8_t *>(tlv->GetValue());
aChannel = static_cast<uint16_t>((value[1] << 8) | value[2]);
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == 3, CHIP_ERROR_INVALID_TLV_ELEMENT);
// Note: The channel page (byte 0) is not returned
const uint8_t * value = tlv->GetValue();
aChannel = static_cast<uint16_t>((value[1] << 8) | value[2]);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could just BigEndian::Get16 starting at the right byte?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah yes, I'll include that change in my other PR.

return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetChannel(uint16_t aChannel)
{
uint8_t value[] = { 0, static_cast<uint8_t>(aChannel >> 8), static_cast<uint8_t>(aChannel & 0xff) };
ThreadTLV * tlv = MakeRoom(ThreadTLV::kChannel, sizeof(*tlv) + sizeof(value));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(value, sizeof(value));

Expand All @@ -278,43 +234,24 @@ CHIP_ERROR OperationalDataset::SetChannel(uint16_t aChannel)
CHIP_ERROR OperationalDataset::GetExtendedPanId(uint8_t (&aExtendedPanId)[kSizeExtendedPanId]) const
{
ByteSpan extPanIdSpan;
CHIP_ERROR error = GetExtendedPanIdAsByteSpan(extPanIdSpan);

if (error != CHIP_NO_ERROR)
{
return error;
}

ReturnErrorOnFailure(GetExtendedPanIdAsByteSpan(extPanIdSpan));
memcpy(aExtendedPanId, extPanIdSpan.data(), extPanIdSpan.size());
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::GetExtendedPanIdAsByteSpan(ByteSpan & span) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kExtendedPanId);

if (tlv == nullptr)
{
return CHIP_ERROR_TLV_TAG_NOT_FOUND;
}

if (tlv->GetLength() != kSizeExtendedPanId)
{
return CHIP_ERROR_INVALID_TLV_ELEMENT;
}

VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == kSizeExtendedPanId, CHIP_ERROR_INVALID_TLV_ELEMENT);
span = ByteSpan(static_cast<const uint8_t *>(tlv->GetValue()), tlv->GetLength());
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetExtendedPanId(const uint8_t (&aExtendedPanId)[kSizeExtendedPanId])
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kExtendedPanId, sizeof(*tlv) + sizeof(aExtendedPanId));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aExtendedPanId, sizeof(aExtendedPanId));

Expand All @@ -328,24 +265,16 @@ CHIP_ERROR OperationalDataset::SetExtendedPanId(const uint8_t (&aExtendedPanId)[
CHIP_ERROR OperationalDataset::GetMasterKey(uint8_t (&aMasterKey)[kSizeMasterKey]) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kMasterKey);

if (tlv != nullptr)
{
memcpy(aMasterKey, tlv->GetValue(), sizeof(aMasterKey));
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aMasterKey), CHIP_ERROR_INVALID_TLV_ELEMENT);
memcpy(aMasterKey, tlv->GetValue(), sizeof(aMasterKey));
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetMasterKey(const uint8_t (&aMasterKey)[kSizeMasterKey])
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kMasterKey, sizeof(*tlv) + sizeof(aMasterKey));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aMasterKey, sizeof(aMasterKey));

Expand All @@ -359,24 +288,16 @@ CHIP_ERROR OperationalDataset::SetMasterKey(const uint8_t (&aMasterKey)[kSizeMas
CHIP_ERROR OperationalDataset::GetMeshLocalPrefix(uint8_t (&aMeshLocalPrefix)[kSizeMeshLocalPrefix]) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kMeshLocalPrefix);

if (tlv != nullptr)
{
memcpy(aMeshLocalPrefix, tlv->GetValue(), sizeof(aMeshLocalPrefix));
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aMeshLocalPrefix), CHIP_ERROR_INVALID_TLV_ELEMENT);
memcpy(aMeshLocalPrefix, tlv->GetValue(), sizeof(aMeshLocalPrefix));
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetMeshLocalPrefix(const uint8_t (&aMeshLocalPrefix)[kSizeMeshLocalPrefix])
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kMeshLocalPrefix, sizeof(*tlv) + sizeof(aMeshLocalPrefix));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aMeshLocalPrefix, sizeof(aMeshLocalPrefix));

Expand All @@ -388,32 +309,21 @@ CHIP_ERROR OperationalDataset::SetMeshLocalPrefix(const uint8_t (&aMeshLocalPref
CHIP_ERROR OperationalDataset::GetNetworkName(char (&aNetworkName)[kSizeNetworkName + 1]) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kNetworkName);

if (tlv != nullptr)
{
memcpy(aNetworkName, tlv->GetValue(), tlv->GetLength());
aNetworkName[tlv->GetLength()] = '\0';
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() <= kSizeNetworkName, CHIP_ERROR_INVALID_TLV_ELEMENT);
memcpy(aNetworkName, tlv->GetValue(), tlv->GetLength());
aNetworkName[tlv->GetLength()] = '\0';
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetNetworkName(const char * aNetworkName)
{
VerifyOrReturnError(aNetworkName != nullptr, CHIP_ERROR_INVALID_ARGUMENT);
size_t len = strlen(aNetworkName);
VerifyOrReturnError(0 < len && len <= kSizeNetworkName, CHIP_ERROR_INVALID_STRING_LENGTH);

if (len > kSizeNetworkName || len == 0)
{
return CHIP_ERROR_INVALID_STRING_LENGTH;
}

ThreadTLV * tlv = MakeRoom(ThreadTLV::kNetworkName, static_cast<uint8_t>(sizeof(*tlv) + static_cast<uint8_t>(len)));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
ThreadTLV * tlv = MakeRoom(ThreadTLV::kNetworkName, sizeof(*tlv) + len);
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aNetworkName, static_cast<uint8_t>(len));

Expand All @@ -425,24 +335,16 @@ CHIP_ERROR OperationalDataset::SetNetworkName(const char * aNetworkName)
CHIP_ERROR OperationalDataset::GetPanId(uint16_t & aPanId) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kPanId);

if (tlv != nullptr)
{
tlv->Get16(aPanId);
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aPanId), CHIP_ERROR_INVALID_TLV_ELEMENT);
tlv->Get16(aPanId);
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetPanId(uint16_t aPanId)
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kPanId, sizeof(*tlv) + sizeof(aPanId));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->Set16(aPanId);

Expand All @@ -454,24 +356,16 @@ CHIP_ERROR OperationalDataset::SetPanId(uint16_t aPanId)
CHIP_ERROR OperationalDataset::GetPSKc(uint8_t (&aPSKc)[kSizePSKc]) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kPSKc);

if (tlv != nullptr)
{
memcpy(aPSKc, tlv->GetValue(), sizeof(aPSKc));
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aPSKc), CHIP_ERROR_INVALID_TLV_ELEMENT);
memcpy(aPSKc, tlv->GetValue(), sizeof(aPSKc));
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetPSKc(const uint8_t (&aPSKc)[kSizePSKc])
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kPSKc, sizeof(*tlv) + sizeof(aPSKc));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aPSKc, sizeof(aPSKc));

Expand Down Expand Up @@ -533,7 +427,7 @@ void OperationalDataset::Remove(uint8_t aType)
}
}

ThreadTLV * OperationalDataset::MakeRoom(uint8_t aType, uint8_t aSize)
ThreadTLV * OperationalDataset::MakeRoom(uint8_t aType, size_t aSize)
{
ThreadTLV * tlv = Locate(aType);

Expand Down
Loading
Loading