Skip to content

Commit

Permalink
ThreadOperationalDataset: various bug fixes
Browse files Browse the repository at this point in the history
- Ensure TLVs read have the correct length
- Default construct as empty (mLength == 0)
- Change MakeRoom() size argument to size_t to avoid chance of truncation
- Check for null in SetNetworkName
- Use Encoding::BigEndian instead of hand-rolled math
- Use ReturnErrorOnFailure / VerifyOrReturnError consistently
- Make tests independent from each other (non-static dataset member)
- Add more tests
  • Loading branch information
ksperling-apple committed Jul 15, 2024
1 parent 05e4c10 commit c4847c3
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 240 deletions.
226 changes: 65 additions & 161 deletions src/lib/support/ThreadOperationalDataset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class ThreadTLV final
mLength = aLength;
}

const void * GetValue() const
const uint8_t * GetValue() const
{
assert(mLength != kLengthEscape);

Expand All @@ -79,75 +79,56 @@ class ThreadTLV final
return reinterpret_cast<const uint8_t *>(this) + sizeof(*this);
}

void * GetValue() { return const_cast<void *>(const_cast<const ThreadTLV *>(this)->GetValue()); }
uint8_t * GetValue() { return const_cast<uint8_t *>(const_cast<const ThreadTLV *>(this)->GetValue()); }

ByteSpan GetValueAsSpan() const { return ByteSpan(static_cast<const uint8_t *>(GetValue()), GetLength()); }

void Get64(uint64_t & aValue) const
{
assert(GetLength() >= sizeof(aValue));
aValue = Encoding::BigEndian::Get64(GetValue());
}

const uint8_t * p = reinterpret_cast<const uint8_t *>(GetValue());
aValue = //
(static_cast<uint64_t>(p[0]) << 56) | //
(static_cast<uint64_t>(p[1]) << 48) | //
(static_cast<uint64_t>(p[2]) << 40) | //
(static_cast<uint64_t>(p[3]) << 32) | //
(static_cast<uint64_t>(p[4]) << 24) | //
(static_cast<uint64_t>(p[5]) << 16) | //
(static_cast<uint64_t>(p[6]) << 8) | //
(static_cast<uint64_t>(p[7]));
void Get32(uint32_t & aValue) const
{
assert(GetLength() >= sizeof(aValue));
aValue = Encoding::BigEndian::Get32(GetValue());
}

void Get16(uint16_t & aValue) const
{
assert(GetLength() >= sizeof(aValue));

const uint8_t * p = static_cast<const uint8_t *>(GetValue());

aValue = static_cast<uint16_t>(p[0] << 8 | p[1]);
aValue = Encoding::BigEndian::Get16(GetValue());
}

void Get8(uint8_t & aValue) const
{
assert(GetLength() >= sizeof(aValue));
aValue = *static_cast<const uint8_t *>(GetValue());
aValue = *GetValue();
}

void Set64(uint64_t aValue)
{
uint8_t * value = static_cast<uint8_t *>(GetValue());

SetLength(sizeof(aValue));

value[0] = static_cast<uint8_t>((aValue >> 56) & 0xff);
value[1] = static_cast<uint8_t>((aValue >> 48) & 0xff);
value[2] = static_cast<uint8_t>((aValue >> 40) & 0xff);
value[3] = static_cast<uint8_t>((aValue >> 32) & 0xff);
value[4] = static_cast<uint8_t>((aValue >> 24) & 0xff);
value[5] = static_cast<uint8_t>((aValue >> 16) & 0xff);
value[6] = static_cast<uint8_t>((aValue >> 8) & 0xff);
value[7] = static_cast<uint8_t>(aValue & 0xff);
Encoding::BigEndian::Put64(GetValue(), aValue);
}

void Set16(uint16_t aValue)
void Set32(uint32_t aValue)
{
uint8_t * value = static_cast<uint8_t *>(GetValue());

SetLength(sizeof(aValue));

value[0] = static_cast<uint8_t>(aValue >> 8);
value[1] = static_cast<uint8_t>(aValue & 0xff);
Encoding::BigEndian::Put32(GetValue(), aValue);
}

void Set8(uint8_t aValue)
void Set16(uint16_t aValue)
{
SetLength(sizeof(aValue));
*static_cast<uint8_t *>(GetValue()) = aValue;
Encoding::BigEndian::Put16(GetValue(), aValue);
}

void Set8(int8_t aValue)
void Set8(uint8_t aValue)
{
SetLength(sizeof(aValue));
*static_cast<int8_t *>(GetValue()) = aValue;
*GetValue() = aValue;
}

void SetValue(const void * aValue, uint8_t aLength)
Expand Down Expand Up @@ -218,24 +199,16 @@ CHIP_ERROR OperationalDataset::Init(ByteSpan aData)
CHIP_ERROR OperationalDataset::GetActiveTimestamp(uint64_t & aActiveTimestamp) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kActiveTimestamp);

if (tlv != nullptr)
{
tlv->Get64(aActiveTimestamp);
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aActiveTimestamp), CHIP_ERROR_INVALID_TLV_ELEMENT);
tlv->Get64(aActiveTimestamp);
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetActiveTimestamp(uint64_t aActiveTimestamp)
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kActiveTimestamp, sizeof(*tlv) + sizeof(aActiveTimestamp));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->Set64(aActiveTimestamp);

Expand All @@ -247,26 +220,19 @@ CHIP_ERROR OperationalDataset::SetActiveTimestamp(uint64_t aActiveTimestamp)
CHIP_ERROR OperationalDataset::GetChannel(uint16_t & aChannel) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kChannel);

if (tlv != nullptr)
{
const uint8_t * value = reinterpret_cast<const uint8_t *>(tlv->GetValue());
aChannel = static_cast<uint16_t>((value[1] << 8) | value[2]);
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == 3, CHIP_ERROR_INVALID_TLV_ELEMENT);
// Note: The channel page (byte 0) is not returned
const uint8_t * value = tlv->GetValue();
aChannel = static_cast<uint16_t>((value[1] << 8) | value[2]);
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetChannel(uint16_t aChannel)
{
uint8_t value[] = { 0, static_cast<uint8_t>(aChannel >> 8), static_cast<uint8_t>(aChannel & 0xff) };
ThreadTLV * tlv = MakeRoom(ThreadTLV::kChannel, sizeof(*tlv) + sizeof(value));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(value, sizeof(value));

Expand All @@ -278,43 +244,24 @@ CHIP_ERROR OperationalDataset::SetChannel(uint16_t aChannel)
CHIP_ERROR OperationalDataset::GetExtendedPanId(uint8_t (&aExtendedPanId)[kSizeExtendedPanId]) const
{
ByteSpan extPanIdSpan;
CHIP_ERROR error = GetExtendedPanIdAsByteSpan(extPanIdSpan);

if (error != CHIP_NO_ERROR)
{
return error;
}

ReturnErrorOnFailure(GetExtendedPanIdAsByteSpan(extPanIdSpan));
memcpy(aExtendedPanId, extPanIdSpan.data(), extPanIdSpan.size());
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::GetExtendedPanIdAsByteSpan(ByteSpan & span) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kExtendedPanId);

if (tlv == nullptr)
{
return CHIP_ERROR_TLV_TAG_NOT_FOUND;
}

if (tlv->GetLength() != kSizeExtendedPanId)
{
return CHIP_ERROR_INVALID_TLV_ELEMENT;
}

VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == kSizeExtendedPanId, CHIP_ERROR_INVALID_TLV_ELEMENT);
span = ByteSpan(static_cast<const uint8_t *>(tlv->GetValue()), tlv->GetLength());
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetExtendedPanId(const uint8_t (&aExtendedPanId)[kSizeExtendedPanId])
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kExtendedPanId, sizeof(*tlv) + sizeof(aExtendedPanId));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aExtendedPanId, sizeof(aExtendedPanId));

Expand All @@ -328,24 +275,16 @@ CHIP_ERROR OperationalDataset::SetExtendedPanId(const uint8_t (&aExtendedPanId)[
CHIP_ERROR OperationalDataset::GetMasterKey(uint8_t (&aMasterKey)[kSizeMasterKey]) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kMasterKey);

if (tlv != nullptr)
{
memcpy(aMasterKey, tlv->GetValue(), sizeof(aMasterKey));
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aMasterKey), CHIP_ERROR_INVALID_TLV_ELEMENT);
memcpy(aMasterKey, tlv->GetValue(), sizeof(aMasterKey));
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetMasterKey(const uint8_t (&aMasterKey)[kSizeMasterKey])
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kMasterKey, sizeof(*tlv) + sizeof(aMasterKey));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aMasterKey, sizeof(aMasterKey));

Expand All @@ -359,24 +298,16 @@ CHIP_ERROR OperationalDataset::SetMasterKey(const uint8_t (&aMasterKey)[kSizeMas
CHIP_ERROR OperationalDataset::GetMeshLocalPrefix(uint8_t (&aMeshLocalPrefix)[kSizeMeshLocalPrefix]) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kMeshLocalPrefix);

if (tlv != nullptr)
{
memcpy(aMeshLocalPrefix, tlv->GetValue(), sizeof(aMeshLocalPrefix));
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aMeshLocalPrefix), CHIP_ERROR_INVALID_TLV_ELEMENT);
memcpy(aMeshLocalPrefix, tlv->GetValue(), sizeof(aMeshLocalPrefix));
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetMeshLocalPrefix(const uint8_t (&aMeshLocalPrefix)[kSizeMeshLocalPrefix])
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kMeshLocalPrefix, sizeof(*tlv) + sizeof(aMeshLocalPrefix));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aMeshLocalPrefix, sizeof(aMeshLocalPrefix));

Expand All @@ -388,32 +319,21 @@ CHIP_ERROR OperationalDataset::SetMeshLocalPrefix(const uint8_t (&aMeshLocalPref
CHIP_ERROR OperationalDataset::GetNetworkName(char (&aNetworkName)[kSizeNetworkName + 1]) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kNetworkName);

if (tlv != nullptr)
{
memcpy(aNetworkName, tlv->GetValue(), tlv->GetLength());
aNetworkName[tlv->GetLength()] = '\0';
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() <= kSizeNetworkName, CHIP_ERROR_INVALID_TLV_ELEMENT);
memcpy(aNetworkName, tlv->GetValue(), tlv->GetLength());
aNetworkName[tlv->GetLength()] = '\0';
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetNetworkName(const char * aNetworkName)
{
VerifyOrReturnError(aNetworkName != nullptr, CHIP_ERROR_INVALID_ARGUMENT);
size_t len = strlen(aNetworkName);
VerifyOrReturnError(0 < len && len <= kSizeNetworkName, CHIP_ERROR_INVALID_STRING_LENGTH);

if (len > kSizeNetworkName || len == 0)
{
return CHIP_ERROR_INVALID_STRING_LENGTH;
}

ThreadTLV * tlv = MakeRoom(ThreadTLV::kNetworkName, static_cast<uint8_t>(sizeof(*tlv) + static_cast<uint8_t>(len)));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
ThreadTLV * tlv = MakeRoom(ThreadTLV::kNetworkName, sizeof(*tlv) + len);
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aNetworkName, static_cast<uint8_t>(len));

Expand All @@ -425,24 +345,16 @@ CHIP_ERROR OperationalDataset::SetNetworkName(const char * aNetworkName)
CHIP_ERROR OperationalDataset::GetPanId(uint16_t & aPanId) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kPanId);

if (tlv != nullptr)
{
tlv->Get16(aPanId);
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aPanId), CHIP_ERROR_INVALID_TLV_ELEMENT);
tlv->Get16(aPanId);
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetPanId(uint16_t aPanId)
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kPanId, sizeof(*tlv) + sizeof(aPanId));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->Set16(aPanId);

Expand All @@ -454,24 +366,16 @@ CHIP_ERROR OperationalDataset::SetPanId(uint16_t aPanId)
CHIP_ERROR OperationalDataset::GetPSKc(uint8_t (&aPSKc)[kSizePSKc]) const
{
const ThreadTLV * tlv = Locate(ThreadTLV::kPSKc);

if (tlv != nullptr)
{
memcpy(aPSKc, tlv->GetValue(), sizeof(aPSKc));
return CHIP_NO_ERROR;
}

return CHIP_ERROR_TLV_TAG_NOT_FOUND;
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_TLV_TAG_NOT_FOUND);
VerifyOrReturnError(tlv->GetLength() == sizeof(aPSKc), CHIP_ERROR_INVALID_TLV_ELEMENT);
memcpy(aPSKc, tlv->GetValue(), sizeof(aPSKc));
return CHIP_NO_ERROR;
}

CHIP_ERROR OperationalDataset::SetPSKc(const uint8_t (&aPSKc)[kSizePSKc])
{
ThreadTLV * tlv = MakeRoom(ThreadTLV::kPSKc, sizeof(*tlv) + sizeof(aPSKc));

if (tlv == nullptr)
{
return CHIP_ERROR_NO_MEMORY;
}
VerifyOrReturnError(tlv != nullptr, CHIP_ERROR_NO_MEMORY);

tlv->SetValue(aPSKc, sizeof(aPSKc));

Expand Down Expand Up @@ -533,7 +437,7 @@ void OperationalDataset::Remove(uint8_t aType)
}
}

ThreadTLV * OperationalDataset::MakeRoom(uint8_t aType, uint8_t aSize)
ThreadTLV * OperationalDataset::MakeRoom(uint8_t aType, size_t aSize)
{
ThreadTLV * tlv = Locate(aType);

Expand Down
Loading

0 comments on commit c4847c3

Please sign in to comment.