Skip to content

Commit

Permalink
Fix wrong PictureID rolling over in VP9 and VP8 (#984)
Browse files Browse the repository at this point in the history
  • Loading branch information
jcague authored Jan 19, 2023
1 parent daca24c commit f60845b
Show file tree
Hide file tree
Showing 12 changed files with 354 additions and 49 deletions.
2 changes: 1 addition & 1 deletion worker/include/RTC/Codecs/H264_SVC.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ namespace RTC
}

public:
RTC::SeqManager<uint16_t> pictureIdManager;
RTC::SeqManager<uint16_t, 15> pictureIdManager;
bool syncRequired{ false };
};

Expand Down
2 changes: 1 addition & 1 deletion worker/include/RTC/Codecs/VP8.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ namespace RTC
}

public:
RTC::SeqManager<uint16_t> pictureIdManager;
RTC::SeqManager<uint16_t, 15> pictureIdManager;
RTC::SeqManager<uint8_t> tl0PictureIndexManager;
bool syncRequired{ false };
};
Expand Down
2 changes: 1 addition & 1 deletion worker/include/RTC/Codecs/VP9.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ namespace RTC
}

public:
RTC::SeqManager<uint16_t> pictureIdManager;
RTC::SeqManager<uint16_t, 15> pictureIdManager;
bool syncRequired{ false };
};

Expand Down
7 changes: 5 additions & 2 deletions worker/include/RTC/SeqManager.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@

namespace RTC
{
template<typename T>
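// T is the unsigned base type that stores the values (uint8_t, uint16_t, uint32_t, ...).
// N is the number of significant bits used in T. 0 (the default) means the full
// width of T; otherwise MaxValue is (1 << N) - 1.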
template<typename T, uint8_t N = 0>
class SeqManager
{
public:
static constexpr T MaxValue = std::numeric_limits<T>::max();
static constexpr T MaxValue = (N == 0) ? std::numeric_limits<T>::max() : ((1 << N) - 1);

public:
struct SeqLowerThan
Expand All @@ -27,6 +29,7 @@ namespace RTC
private:
static const SeqLowerThan isSeqLowerThan;
static const SeqHigherThan isSeqHigherThan;
static T Delta(const T lhs, const T rhs);

public:
static bool IsSeqLowerThan(const T lhs, const T rhs);
Expand Down
1 change: 1 addition & 0 deletions worker/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ mediasoup_worker_test = executable(
'test/src/RTC/TestTrendCalculator.cpp',
'test/src/RTC/TestRtpEncodingParameters.cpp',
'test/src/RTC/Codecs/TestVP8.cpp',
'test/src/RTC/Codecs/TestVP9.cpp',
'test/src/RTC/Codecs/TestH264.cpp',
'test/src/RTC/Codecs/TestH264_SVC.cpp',
'test/src/RTC/RTCP/TestFeedbackPsAfb.cpp',
Expand Down
2 changes: 1 addition & 1 deletion worker/src/RTC/Codecs/VP8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ namespace RTC
this->payloadDescriptor->hasPictureId &&
this->payloadDescriptor->hasTlIndex &&
this->payloadDescriptor->hasTl0PictureIndex &&
!RTC::SeqManager<uint16_t>::IsSeqLowerThan(
!RTC::SeqManager<uint16_t, 15>::IsSeqLowerThan(
this->payloadDescriptor->pictureId,
context->pictureIdManager.GetMaxInput())
)
Expand Down
2 changes: 1 addition & 1 deletion worker/src/RTC/Codecs/VP9.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ namespace RTC
// clang-format off
bool isOldPacket = (
this->payloadDescriptor->hasPictureId &&
RTC::SeqManager<uint16_t>::IsSeqLowerThan(
RTC::SeqManager<uint16_t, 15>::IsSeqLowerThan(
this->payloadDescriptor->pictureId,
context->pictureIdManager.GetMaxInput())
);
Expand Down
77 changes: 43 additions & 34 deletions worker/src/RTC/SeqManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,43 +7,51 @@

namespace RTC
{
template<typename T>
bool SeqManager<T>::SeqLowerThan::operator()(const T lhs, const T rhs) const
template<typename T, uint8_t N>
bool SeqManager<T, N>::SeqLowerThan::operator()(const T lhs, const T rhs) const
{
return ((rhs > lhs) && (rhs - lhs <= MaxValue / 2)) ||
((lhs > rhs) && (lhs - rhs > MaxValue / 2));
}

template<typename T>
bool SeqManager<T>::SeqHigherThan::operator()(const T lhs, const T rhs) const
template<typename T, uint8_t N>
bool SeqManager<T, N>::SeqHigherThan::operator()(const T lhs, const T rhs) const
{
return ((lhs > rhs) && (lhs - rhs <= MaxValue / 2)) ||
((rhs > lhs) && (rhs - lhs > MaxValue / 2));
}

template<typename T>
const typename SeqManager<T>::SeqLowerThan SeqManager<T>::isSeqLowerThan{};
template<typename T, uint8_t N>
const typename SeqManager<T, N>::SeqLowerThan SeqManager<T, N>::isSeqLowerThan{};

template<typename T>
const typename SeqManager<T>::SeqHigherThan SeqManager<T>::isSeqHigherThan{};
template<typename T, uint8_t N>
const typename SeqManager<T, N>::SeqHigherThan SeqManager<T, N>::isSeqHigherThan{};

template<typename T>
bool SeqManager<T>::IsSeqLowerThan(const T lhs, const T rhs)
template<typename T, uint8_t N>
bool SeqManager<T, N>::IsSeqLowerThan(const T lhs, const T rhs)
{
return isSeqLowerThan(lhs, rhs);
}

template<typename T>
bool SeqManager<T>::IsSeqHigherThan(const T lhs, const T rhs)
template<typename T, uint8_t N>
bool SeqManager<T, N>::IsSeqHigherThan(const T lhs, const T rhs)
{
return isSeqHigherThan(lhs, rhs);
}

template<typename T>
void SeqManager<T>::Sync(T input)
// Returns how far `lhs` is ahead of `rhs` in the N-bit sequence space, i.e.
// (lhs - rhs) modulo (MaxValue + 1). The result is always in [0, MaxValue],
// and Delta(x, x) is 0.
template<typename T, uint8_t N>
T SeqManager<T, N>::Delta(const T lhs, const T rhs)
{
	// For narrow T the subtraction is done in int after integral promotion and
	// may be negative. Converting back to the unsigned T wraps it modulo
	// 2^(bits of T); masking with MaxValue then reduces it modulo 2^N. This is
	// exact because MaxValue + 1 is a power of two dividing 2^(bits of T).
	const T wrapped = static_cast<T>(lhs - rhs);

	return static_cast<T>(wrapped & MaxValue);
}

template<typename T, uint8_t N>
void SeqManager<T, N>::Sync(T input)
{
// Update base.
this->base = this->maxOutput - input;
this->base = (this->maxOutput - input) & MaxValue;

// Update maxInput.
this->maxInput = input;
Expand All @@ -52,24 +60,24 @@ namespace RTC
this->dropped.clear();
}

template<typename T>
void SeqManager<T>::Drop(T input)
template<typename T, uint8_t N>
void SeqManager<T, N>::Drop(T input)
{
// Mark as dropped if 'input' is higher than anyone already processed.
if (SeqManager<T>::IsSeqHigherThan(input, this->maxInput))
if (SeqManager<T, N>::IsSeqHigherThan(input, this->maxInput))
{
this->dropped.insert(input);
}
}

template<typename T>
void SeqManager<T>::Offset(T offset)
template<typename T, uint8_t N>
void SeqManager<T, N>::Offset(T offset)
{
this->base += offset;
this->base = (this->base + offset) & MaxValue;
}

template<typename T>
bool SeqManager<T>::Input(const T input, T& output)
template<typename T, uint8_t N>
bool SeqManager<T, N>::Input(const T input, T& output)
{
auto base = this->base;

Expand All @@ -78,10 +86,10 @@ namespace RTC
{
// Delete dropped inputs older than input - MaxValue/2.
size_t droppedCount = this->dropped.size();
auto it = this->dropped.lower_bound(input - MaxValue / 2);

size_t threshold = (input - MaxValue / 2) & MaxValue;
auto it = this->dropped.lower_bound(threshold);
this->dropped.erase(this->dropped.begin(), it);
this->base -= (droppedCount - this->dropped.size());
this->base = (this->base - (droppedCount - this->dropped.size())) & MaxValue;

// Count dropped entries before 'input' in order to adapt the base.
droppedCount = this->dropped.size();
Expand All @@ -100,13 +108,13 @@ namespace RTC
droppedCount -= std::distance(it, this->dropped.end());
}

base = this->base - droppedCount;
base = (this->base - droppedCount) & MaxValue;
}

output = input + base;
output = (input + base) & MaxValue;

T idelta = input - this->maxInput;
T odelta = output - this->maxOutput;
T idelta = SeqManager<T, N>::Delta(input, this->maxInput);
T odelta = SeqManager<T, N>::Delta(output, this->maxOutput);

// New input is higher than the maximum seen. But less than acceptable units higher.
// Keep it as the maximum seen. See Drop().
Expand All @@ -121,21 +129,22 @@ namespace RTC
return true;
}

template<typename T>
T SeqManager<T>::GetMaxInput() const
template<typename T, uint8_t N>
T SeqManager<T, N>::GetMaxInput() const
{
return this->maxInput;
}

template<typename T>
T SeqManager<T>::GetMaxOutput() const
template<typename T, uint8_t N>
T SeqManager<T, N>::GetMaxOutput() const
{
return this->maxOutput;
}

// Explicit instantiation to have all SeqManager definitions in this file.
template class SeqManager<uint8_t>;
template class SeqManager<uint16_t>;
template class SeqManager<uint16_t, 15>; // For PictureID (15 bits).
template class SeqManager<uint32_t>;

} // namespace RTC
1 change: 1 addition & 0 deletions worker/src/RTC/SvcConsumer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -586,6 +586,7 @@ namespace RTC
MS_DEBUG_TAG(rtp, "sync key frame received");

this->rtpSeqManager.Sync(packet->GetSequenceNumber() - 1);
this->encodingContext->SyncRequired();

this->syncRequired = false;
}
Expand Down
38 changes: 34 additions & 4 deletions worker/test/src/RTC/Codecs/TestVP8.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

using namespace RTC;

constexpr uint16_t MaxPictureId = (1 << 15) - 1;

SCENARIO("parse VP8 payload descriptor", "[codecs][vp8]")
{
SECTION("parse payload descriptor")
Expand Down Expand Up @@ -274,22 +276,50 @@ SCENARIO("process VP8 payload descriptor", "[codecs][vp8]")
context.SetCurrentTemporalLayer(0);
context.SetTargetTemporalLayer(0);

// Frame 1
// Frame 1.
auto forwarded = ProcessPacket(context, 0, 0, 0);
REQUIRE(forwarded);
REQUIRE(forwarded->pictureId == 0);
REQUIRE(forwarded->tl0PictureIndex == 0);

// Frame 2 gets lost
// Frame 2 gets lost.

// Frame 3
// Frame 3.
forwarded = ProcessPacket(context, 2, 1, 1);
REQUIRE_FALSE(forwarded);

// Frame 2 retransmitted
// Frame 2 retransmitted.
forwarded = ProcessPacket(context, 1, 1, 0);
REQUIRE(forwarded);
REQUIRE(forwarded->pictureId == 1);
REQUIRE(forwarded->tl0PictureIndex == 1);
}

SECTION("drop packets that belong to other temporal layers after rolling over pictureID")
{
	// Context with two temporal layers where only layer 0 is current and targeted.
	RTC::Codecs::EncodingContext::Params layerParams;
	layerParams.spatialLayers  = 0;
	layerParams.temporalLayers = 2;
	Codecs::VP8::EncodingContext context(layerParams);
	context.SyncRequired();

	context.SetCurrentTemporalLayer(0);
	context.SetTargetTemporalLayer(0);

	// NOTE(review): ProcessPacket() arguments are presumably
	// (context, pictureId, tl0PictureIndex, tlIndex) - confirm against its definition.

	// Frame 1: sent with the highest 15-bit picture ID; must be forwarded.
	auto result = ProcessPacket(context, MaxPictureId, 0, 0);
	REQUIRE(result);
	REQUIRE(result->pictureId == 1);
	REQUIRE(result->tl0PictureIndex == 1);

	// Frame 2: the incoming picture ID has rolled over to 0; the forwarded
	// picture ID must keep increasing.
	result = ProcessPacket(context, 0, 0, 0);
	REQUIRE(result);
	REQUIRE(result->pictureId == 2);
	REQUIRE(result->tl0PictureIndex == 1);

	// Frame 3: last argument is 1 instead of 0; must not be forwarded.
	result = ProcessPacket(context, 1, 0, 1);
	REQUIRE_FALSE(result);
}
}
78 changes: 78 additions & 0 deletions worker/test/src/RTC/Codecs/TestVP9.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
#include "common.hpp"
#include "RTC/Codecs/VP9.hpp"
#include <catch2/catch.hpp>
#include <cstring> // std::memcmp()

using namespace RTC;

constexpr uint16_t MaxPictureId = (1 << 15) - 1;

// Writes a minimal VP9 payload descriptor into `buffer` and parses it.
//
// Bytes written:
//   buffer[0]    flags 0xAD (binary 1010 1101): I (picture ID present) and
//                L (layer indices present), plus the B, E and Z bits.
//   buffer[1..2] `pictureId` in network byte order, with the top bit of
//                buffer[1] forced to 1 (M bit: 15-bit picture ID).
//   buffer[3]    layer indices byte with the temporal layer ID in its three
//                most significant bits; the remaining bits are left at 0.
//
// `buffer` must hold at least 4 bytes. Any bytes after buffer[3] are left as
// provided by the caller.
//
// Returns the descriptor produced by Codecs::VP9::Parse(); the test fails if
// parsing returns null.
Codecs::VP9::PayloadDescriptor* CreateVP9Packet(
  uint8_t* buffer, size_t bufferLen, uint16_t pictureId, uint8_t tlIndex)
{
	// Bytes 0..3 are written unconditionally below.
	REQUIRE(bufferLen >= 4);

	buffer[0] = 0xAD;

	const uint16_t netPictureId = htons(pictureId);
	std::memcpy(buffer + 1, &netPictureId, 2);
	buffer[1] |= 0x80;

	// TID is a 3-bit field in bits 7..5 of the layer indices byte, so the shift
	// is 5. Shifting by 6 would encode tlIndex 1 as TID 2.
	buffer[3] = static_cast<uint8_t>((tlIndex & 0x07) << 5);

	auto* payloadDescriptor = Codecs::VP9::Parse(buffer, bufferLen);

	REQUIRE(payloadDescriptor);

	return payloadDescriptor;
}

// Builds a VP9 packet with the given picture ID and temporal layer index, runs
// it through PayloadDescriptorHandler::Process() against `context`, and
// re-parses the (possibly rewritten) buffer.
//
// Returns the re-parsed descriptor when Process() returns true, or nullptr
// when Process() returns false.
std::unique_ptr<Codecs::VP9::PayloadDescriptor> ProcessVP9Packet(
  Codecs::VP9::EncodingContext& context, uint16_t pictureId, uint8_t tlIndex)
{
	// clang-format off
	uint8_t buffer[] =
	{
		0xAD, 0x80, 0x00, 0x00, 0x00, 0x00
	};
	// clang-format on

	// Initialised so Process() never receives an indeterminate value through
	// the reference.
	bool marker{ false };

	auto* payloadDescriptor = CreateVP9Packet(buffer, sizeof(buffer), pictureId, tlIndex);

	// NOTE(review): the handler presumably takes ownership of `payloadDescriptor`
	// - confirm; otherwise the descriptor leaks here.
	std::unique_ptr<Codecs::VP9::PayloadDescriptorHandler> payloadDescriptorHandler(
	  new Codecs::VP9::PayloadDescriptorHandler(payloadDescriptor));

	if (!payloadDescriptorHandler->Process(&context, buffer, marker))
	{
		return nullptr;
	}

	// Parse again so the caller sees the descriptor as it stands in the buffer
	// after Process().
	return std::unique_ptr<Codecs::VP9::PayloadDescriptor>(
	  Codecs::VP9::Parse(buffer, sizeof(buffer)));
}

SCENARIO("process VP9 payload descriptor", "[codecs][vp9]")
{
	SECTION("drop packets that belong to other temporal layers after rolling over pictureID")
	{
		// One spatial layer and three temporal layers; only layer 0 of each is
		// current and targeted.
		RTC::Codecs::EncodingContext::Params layerParams;
		layerParams.spatialLayers  = 1;
		layerParams.temporalLayers = 3;

		Codecs::VP9::EncodingContext context(layerParams);
		context.SyncRequired();
		context.SetCurrentTemporalLayer(0);
		context.SetTargetTemporalLayer(0);

		context.SetCurrentSpatialLayer(0);
		context.SetTargetSpatialLayer(0);

		// Frame 1: highest 15-bit picture ID on temporal layer 0; must be
		// forwarded with its picture ID unchanged.
		auto result = ProcessVP9Packet(context, MaxPictureId, 0);
		REQUIRE(result);
		REQUIRE(result->pictureId == MaxPictureId);

		// Frame 2: the picture ID has rolled over to 0, still on temporal layer 0;
		// must be forwarded rather than treated as an old packet.
		result = ProcessVP9Packet(context, 0, 0);
		REQUIRE(result);
		REQUIRE(result->pictureId == 0);

		// Frame 3: temporal layer 1 is above the target layer 0; must be dropped.
		result = ProcessVP9Packet(context, 1, 1);
		REQUIRE_FALSE(result);
	}
}
Loading

0 comments on commit f60845b

Please sign in to comment.