diff --git a/node/src/DataConsumer.ts b/node/src/DataConsumer.ts index 92dee2141d..87d856e823 100644 --- a/node/src/DataConsumer.ts +++ b/node/src/DataConsumer.ts @@ -8,6 +8,7 @@ import { Event, Notification } from './fbs/notification'; import * as FbsTransport from './fbs/transport'; import * as FbsRequest from './fbs/request'; import * as FbsDataConsumer from './fbs/data-consumer'; +import * as utils from './utils'; export type DataConsumerOptions = { @@ -45,6 +46,13 @@ export type DataConsumerOptions = */ paused?: boolean; + /** + * Subchannels this data consumer initially subscribes to. + * Only used in case this data consumer receives messages from a local data + * producer that specifies subchannel(s) when calling send(). + */ + subchannels?: number[]; + /** * Custom application data. */ @@ -93,6 +101,7 @@ type DataConsumerDump = DataConsumerData & id: string; paused: boolean; dataProducerPaused: boolean; + subchannels: number[]; }; type DataConsumerInternal = TransportInternal & @@ -132,6 +141,9 @@ export class DataConsumer // Associated DataProducer paused flag. #dataProducerPaused = false; + // Subchannels subscribed to. + #subchannels: number[]; + // Custom app data. #appData: DataConsumerAppData; @@ -148,6 +160,7 @@ export class DataConsumer channel, paused, dataProducerPaused, + subchannels, appData }: { @@ -156,6 +169,7 @@ export class DataConsumer channel: Channel; paused: boolean; dataProducerPaused: boolean; + subchannels: number[]; appData?: DataConsumerAppData; } ) @@ -169,6 +183,7 @@ export class DataConsumer this.#channel = channel; this.#paused = paused; this.#dataProducerPaused = dataProducerPaused; + this.#subchannels = subchannels; this.#appData = appData || {} as DataConsumerAppData; this.handleWorkerNotifications(); @@ -246,6 +261,14 @@ export class DataConsumer return this.#dataProducerPaused; } + /** + * Get current subchannels this data consumer is subscribed to. + */ + get subchannels(): number[] + { + return Array.from(this.#subchannels); + } + /** * App custom data. */ @@ -541,6 +564,34 @@ export class DataConsumer return data.bufferedAmount(); } + /** + * Set subchannels. + */ + async setSubchannels(subchannels: number[]): Promise + { + logger.debug('setSubchannels()'); + + /* Build Request. */ + const requestOffset = new FbsDataConsumer.SetSubchannelsRequestT( + subchannels + ).pack(this.#channel.bufferBuilder); + + const response = await this.#channel.request( + FbsRequest.Method.DATACONSUMER_SET_SUBCHANNELS, + FbsRequest.Body.FBS_DataConsumer_SetSubchannelsRequest, + requestOffset, + this.#internal.dataConsumerId + ); + + /* Decode Response. */ + const data = new FbsDataConsumer.SetSubchannelsResponse(); + + response.body(data); + + // Update subchannels. + this.#subchannels = utils.parseVector(data, 'subchannels'); + } + private handleWorkerNotifications(): void { this.#channel.on(this.#internal.dataConsumerId, (event: Event, data?: Notification) => @@ -675,14 +726,14 @@ export function parseDataConsumerDumpResponse( label : data.label()!, protocol : data.protocol()!, paused : data.paused(), - dataProducerPaused : data.dataProducerPaused() - + dataProducerPaused : data.dataProducerPaused(), + subchannels : utils.parseVector(data, 'subchannels') }; } function parseDataConsumerStats( binary: FbsDataConsumer.GetStatsResponse -):DataConsumerStat +): DataConsumerStat { return { type : 'data-consumer', diff --git a/node/src/DataProducer.ts b/node/src/DataProducer.ts index 929856ad1d..467d9ae130 100644 --- a/node/src/DataProducer.ts +++ b/node/src/DataProducer.ts @@ -387,7 +387,12 @@ export class DataProducer /** * Send data (just valid for DataProducers created on a DirectTransport). */ - send(message: string | Buffer, ppid?: number): void + send( + message: string | Buffer, + ppid?: number, + subchannels?: number[], + requiredSubchannel?: number + ): void { if (typeof message !== 'string' && !Buffer.isBuffer(message)) { @@ -431,6 +436,10 @@ export class DataProducer let dataOffset = 0; + const subchannelsOffset = FbsDataProducer.SendNotification.createSubchannelsVector( + builder, subchannels ?? [] + ); + if (typeof message === 'string') { const messageOffset = builder.createString(message); @@ -450,7 +459,9 @@ export class DataProducer typeof message === 'string' ? FbsDataProducer.Data.String : FbsDataProducer.Data.Binary, - dataOffset + dataOffset, + subchannelsOffset, + requiredSubchannel ?? null ); this.#channel.notify( diff --git a/node/src/RtpParameters.ts b/node/src/RtpParameters.ts index f7b7b94554..af9a976da6 100644 --- a/node/src/RtpParameters.ts +++ b/node/src/RtpParameters.ts @@ -2,8 +2,8 @@ import * as flatbuffers from 'flatbuffers'; import { Boolean as FbsBoolean, Double as FbsDouble, - Integer as FbsInteger, - IntegerArray as FbsIntegerArray, + Integer32 as FbsInteger32, + Integer32Array as FbsInteger32Array, String as FbsString, Parameter as FbsParameter, RtcpFeedback as FbsRtcpFeedback, @@ -564,16 +564,15 @@ export function serializeParameters( builder, keyOffset, FbsValue.Boolean, value === true ? 1:0 ); } - else if (typeof value === 'number') { // Integer. if (value % 1 === 0) { - const valueOffset = FbsInteger.createInteger(builder, value); + const valueOffset = FbsInteger32.createInteger32(builder, value); parameterOffset = FbsParameter.createParameter( - builder, keyOffset, FbsValue.Integer, valueOffset + builder, keyOffset, FbsValue.Integer32, valueOffset ); } // Float. @@ -586,7 +585,6 @@ export function serializeParameters( ); } } - else if (typeof value === 'string') { const valueOffset = FbsString.createString(builder, builder.createString(value)); @@ -595,16 +593,14 @@ export function serializeParameters( builder, keyOffset, FbsValue.String, valueOffset ); } - else if (Array.isArray(value)) { - const valueOffset = FbsIntegerArray.createValueVector(builder, value); + const valueOffset = FbsInteger32Array.createValueVector(builder, value); parameterOffset = FbsParameter.createParameter( - builder, keyOffset, FbsValue.IntegerArray, valueOffset + builder, keyOffset, FbsValue.Integer32Array, valueOffset ); } - else { throw new Error(`invalid parameter type [key:'${key}', value:${value}]`); @@ -645,9 +641,9 @@ export function parseParameters(data: any): any break; } - case FbsValue.Integer: + case FbsValue.Integer32: { - const value = new FbsInteger(); + const value = new FbsInteger32(); fbsParameter.value(value); @@ -678,9 +674,9 @@ export function parseParameters(data: any): any break; } - case FbsValue.IntegerArray: + case FbsValue.Integer32Array: { - const value = new FbsIntegerArray(); + const value = new FbsInteger32Array(); fbsParameter.value(value); diff --git a/node/src/Transport.ts b/node/src/Transport.ts index 0d4b237dd6..50e53f1594 100644 --- a/node/src/Transport.ts +++ b/node/src/Transport.ts @@ -1080,6 +1080,7 @@ export class Transport maxPacketLifeTime, maxRetransmits, paused = false, + subchannels, appData }: DataConsumerOptions ): Promise> @@ -1163,7 +1164,8 @@ export class Transport sctpStreamParameters, label, protocol, - paused + paused, + subchannels }); const response = await this.channel.request( @@ -1197,6 +1199,7 @@ export class Transport }, channel : this.channel, paused : dump.paused, + subchannels : dump.subchannels, dataProducerPaused : dump.dataProducerPaused, appData }); @@ -1680,9 +1683,10 @@ function createConsumeDataRequest({ sctpStreamParameters, label, protocol, - paused + paused, + subchannels = [] } : { - builder : flatbuffers.Builder; + builder: flatbuffers.Builder; dataConsumerId: string; dataProducerId: string; type: DataConsumerType; @@ -1690,6 +1694,7 @@ function createConsumeDataRequest({ label: string; protocol: string; paused: boolean; + subchannels?: number[]; }): number { const dataConsumerIdOffset = builder.createString(dataConsumerId); @@ -1707,6 +1712,10 @@ function createConsumeDataRequest({ ); } + const subchannelsOffset = FbsTransport.ConsumeDataRequest.createSubchannelsVector( + builder, subchannels + ); + FbsTransport.ConsumeDataRequest.startConsumeDataRequest(builder); FbsTransport.ConsumeDataRequest.addDataConsumerId(builder, dataConsumerIdOffset); FbsTransport.ConsumeDataRequest.addDataProducerId(builder, dataProducerIdOffset); @@ -1722,6 +1731,7 @@ function createConsumeDataRequest({ FbsTransport.ConsumeDataRequest.addLabel(builder, labelOffset); FbsTransport.ConsumeDataRequest.addProtocol(builder, protocolOffset); FbsTransport.ConsumeDataRequest.addPaused(builder, paused); + FbsTransport.ConsumeDataRequest.addSubchannels(builder, subchannelsOffset); return FbsTransport.ConsumeDataRequest.endConsumeDataRequest(builder); } diff --git a/node/src/tests/test-DataConsumer.ts b/node/src/tests/test-DataConsumer.ts index 45973919bd..7f13b1aab8 100644 --- a/node/src/tests/test-DataConsumer.ts +++ b/node/src/tests/test-DataConsumer.ts @@ -53,6 +53,9 @@ test('transport.consumeData() succeeds', async () => { dataProducerId : dataProducer.id, maxPacketLifeTime : 4000, + // Valid values are 0...65535 so others and duplicated ones will be + // discarded. + subchannels : [ 0, 1, 1, 1, 2, 65535, 65536, 65537, 100 ], appData : { baz: 'LOL' } }); @@ -70,6 +73,7 @@ test('transport.consumeData() succeeds', async () => expect(dataConsumer1.label).toBe('foo'); expect(dataConsumer1.protocol).toBe('bar'); expect(dataConsumer1.paused).toBe(false); + expect(dataConsumer1.subchannels.sort((a, b) => a - b)).toEqual([ 0, 1, 2, 100, 65535 ]); expect(dataConsumer1.appData).toEqual({ baz: 'LOL' }); const dump = await router.dump(); @@ -128,6 +132,13 @@ test('dataConsumer.getStats() succeeds', async () => ]); }, 2000); +test('dataConsumer.setSubchannels() succeeds', async () => +{ + await dataConsumer1.setSubchannels([ 999, 999, 998, 65536 ]); + + expect(dataConsumer1.subchannels.sort((a, b) => a - b)).toEqual([ 0, 998, 999 ]); +}, 2000); + test('transport.consumeData() on a DirectTransport succeeds', async () => { const onObserverNewDataConsumer = jest.fn(); diff --git a/node/src/tests/test-DirectTransport.ts b/node/src/tests/test-DirectTransport.ts index f495b4004f..580185d96d 100644 --- a/node/src/tests/test-DirectTransport.ts +++ b/node/src/tests/test-DirectTransport.ts @@ -246,6 +246,153 @@ test('dataProducer.send() succeeds', async () => ]); }, 5000); +test('dataProducer.send() with subchannels succeeds', async () => +{ + const transport2 = await router.createDirectTransport(); + const dataProducer = await transport2.produceData(); + const dataConsumer1 = await transport2.consumeData( + { + dataProducerId : dataProducer.id, + subchannels : [ 1, 11, 666 ] + }); + const dataConsumer2 = await transport2.consumeData( + { + dataProducerId : dataProducer.id, + subchannels : [ 2, 22, 666 ] + }); + const expectedReceivedNumMessages1 = 7; + const expectedReceivedNumMessages2 = 5; + const receivedMessages1: string[] = []; + const receivedMessages2: string[] = []; + + // eslint-disable-next-line no-async-promise-executor + await new Promise(async (resolve) => + { + // Must be received by dataConsumer1 and dataConsumer2. + dataProducer.send( + 'both', + /* ppid */ undefined, + /* subchannels */ undefined, + /* requiredSubchannel */ undefined + ); + + // Must be received by dataConsumer1 and dataConsumer2. + dataProducer.send( + 'both', + /* ppid */ undefined, + /* subchannels */ [ 1, 2 ], + /* requiredSubchannel */ undefined + ); + + // Must be received by dataConsumer1 and dataConsumer2. + dataProducer.send( + 'both', + /* ppid */ undefined, + /* subchannels */ [ 11, 22, 33 ], + /* requiredSubchannel */ 666 + ); + + // Must not be received by neither dataConsumer1 nor dataConsumer2. + dataProducer.send( + 'none', + /* ppid */ undefined, + /* subchannels */ [ 3 ], + /* requiredSubchannel */ 666 + ); + + // Must not be received by neither dataConsumer1 nor dataConsumer2. + dataProducer.send( + 'none', + /* ppid */ undefined, + /* subchannels */ [ 666 ], + /* requiredSubchannel */ 3 + ); + + // Must be received by dataConsumer1. + dataProducer.send( + 'dc1', + /* ppid */ undefined, + /* subchannels */ [ 1 ], + /* requiredSubchannel */ undefined + ); + + // Must be received by dataConsumer1. + dataProducer.send( + 'dc1', + /* ppid */ undefined, + /* subchannels */ [ 11 ], + /* requiredSubchannel */ 1 + ); + + // Must be received by dataConsumer1. + dataProducer.send( + 'dc1', + /* ppid */ undefined, + /* subchannels */ [ 666 ], + /* requiredSubchannel */ 11 + ); + + // Must be received by dataConsumer2. + dataProducer.send( + 'dc2', + /* ppid */ undefined, + /* subchannels */ [ 666 ], + /* requiredSubchannel */ 2 + ); + + // Make dataConsumer2 also subscribe to subchannel 1. + // NOTE: No need to await for this call. + void dataConsumer2.setSubchannels([ ...dataConsumer2.subchannels, 1 ]); + + // Must be received by dataConsumer1 and dataConsumer2. + dataProducer.send( + 'both', + /* ppid */ undefined, + /* subchannels */ [ 1 ], + /* requiredSubchannel */ 666 + ); + + dataConsumer1.on('message', (message) => + { + receivedMessages1.push(message.toString('utf8')); + + if ( + receivedMessages1.length === expectedReceivedNumMessages1 && + receivedMessages2.length === expectedReceivedNumMessages2 + ) + { + resolve(); + } + }); + + dataConsumer2.on('message', (message) => + { + receivedMessages2.push(message.toString('utf8')); + + if ( + receivedMessages1.length === expectedReceivedNumMessages1 && + receivedMessages2.length === expectedReceivedNumMessages2 + ) + { + resolve(); + } + }); + }); + + expect(receivedMessages1.length).toBe(expectedReceivedNumMessages1); + expect(receivedMessages2.length).toBe(expectedReceivedNumMessages2); + + for (const message of receivedMessages1) + { + expect([ 'both', 'dc1' ].includes(message)).toBe(true); + } + + for (const message of receivedMessages2) + { + expect([ 'both', 'dc2' ].includes(message)).toBe(true); + } +}, 5000); + test('DirectTransport methods reject if closed', async () => { const onObserverClose = jest.fn(); diff --git a/node/src/utils.ts b/node/src/utils.ts index 77da462bc3..b67e382acd 100644 --- a/node/src/utils.ts +++ b/node/src/utils.ts @@ -3,16 +3,27 @@ import { ProducerType } from './Producer'; import { Type as FbsRtpParametersType } from './fbs/rtp-parameters'; /** - * Clones the given object/array. + * Clones the given value. */ -export function clone(data: any): any +export function clone(value: T): T { - if (typeof data !== 'object') + if (value === undefined) { - return {}; + return undefined as unknown as T; + } + else if (Number.isNaN(value)) + { + return NaN as unknown as T; + } + else if (typeof structuredClone === 'function') + { + // Available in Node >= 18. + return structuredClone(value); + } + else + { + return JSON.parse(JSON.stringify(value)); } - - return JSON.parse(JSON.stringify(data)); } /** diff --git a/worker/fbs/dataConsumer.fbs b/worker/fbs/dataConsumer.fbs index 54aa7a0904..89e29c6ad2 100644 --- a/worker/fbs/dataConsumer.fbs +++ b/worker/fbs/dataConsumer.fbs @@ -20,6 +20,7 @@ table DumpResponse { protocol:string (required); paused:bool; data_producer_paused:bool; + subchannels:[uint16]; } table GetStatsResponse { @@ -49,6 +50,14 @@ table SendRequest { data:Data (required); } +table SetSubchannelsRequest { + subchannels:[uint16]; +} + +table SetSubchannelsResponse { + subchannels:[uint16]; +} + // Notifications from Worker. table BufferedAmountLowNotification { diff --git a/worker/fbs/dataProducer.fbs b/worker/fbs/dataProducer.fbs index 969939cac9..5815733f28 100644 --- a/worker/fbs/dataProducer.fbs +++ b/worker/fbs/dataProducer.fbs @@ -36,4 +36,6 @@ union Data { table SendNotification { ppid:uint8; data:Data (required); + subchannels:[uint16]; + required_subchannel:uint16 = null; } diff --git a/worker/fbs/request.fbs b/worker/fbs/request.fbs index 7948723a00..ac1ad5fd99 100644 --- a/worker/fbs/request.fbs +++ b/worker/fbs/request.fbs @@ -71,6 +71,7 @@ enum Method: uint8 { DATACONSUMER_GET_BUFFERED_AMOUNT, DATACONSUMER_SET_BUFFERED_AMOUNT_LOW_THRESHOLD, DATACONSUMER_SEND, + DATACONSUMER_SET_SUBCHANNELS, RTPOBSERVER_PAUSE, RTPOBSERVER_RESUME, RTPOBSERVER_ADD_PRODUCER, @@ -112,6 +113,7 @@ union Body { FBS.Consumer.EnableTraceEventRequest, FBS.DataConsumer.SetBufferedAmountLowThresholdRequest, FBS.DataConsumer.SendRequest, + FBS.DataConsumer.SetSubchannelsRequest, FBS.RtpObserver.AddProducerRequest, FBS.RtpObserver.RemoveProducerRequest, } diff --git a/worker/fbs/response.fbs b/worker/fbs/response.fbs index bd94fbe2db..41e8cbe422 100644 --- a/worker/fbs/response.fbs +++ b/worker/fbs/response.fbs @@ -39,6 +39,7 @@ union Body { FBS_DataConsumer_GetBufferedAmountResponse: FBS.DataConsumer.GetBufferedAmountResponse, FBS_DataConsumer_DumpResponse: FBS.DataConsumer.DumpResponse, FBS_DataConsumer_GetStatsResponse: FBS.DataConsumer.GetStatsResponse, + FBS_DataConsumer_SetSubchannelsResponse: FBS.DataConsumer.SetSubchannelsResponse, } table Response { diff --git a/worker/fbs/rtpParameters.fbs b/worker/fbs/rtpParameters.fbs index 35a5ff84e0..e5e34fca6b 100644 --- a/worker/fbs/rtpParameters.fbs +++ b/worker/fbs/rtpParameters.fbs @@ -9,11 +9,11 @@ table Boolean { value:uint8; } -table Integer { +table Integer32 { value:int32; } -table IntegerArray { +table Integer32Array { value:[int32]; } @@ -27,10 +27,10 @@ table String { union Value { Boolean, - Integer, + Integer32, Double, String, - IntegerArray, + Integer32Array, } table Parameter { diff --git a/worker/fbs/transport.fbs b/worker/fbs/transport.fbs index 6e5914f366..5b7af3ef61 100644 --- a/worker/fbs/transport.fbs +++ b/worker/fbs/transport.fbs @@ -72,6 +72,7 @@ table ConsumeDataRequest { label:string; protocol:string; paused:bool = false; + subchannels:[uint16]; } table Tuple { diff --git a/worker/include/RTC/DataConsumer.hpp b/worker/include/RTC/DataConsumer.hpp index 17e0e2c3e5..c7737dbd12 100644 --- a/worker/include/RTC/DataConsumer.hpp +++ b/worker/include/RTC/DataConsumer.hpp @@ -6,6 +6,7 @@ #include "Channel/ChannelSocket.hpp" #include "RTC/SctpDictionaries.hpp" #include "RTC/Shared.hpp" +#include #include namespace RTC @@ -28,9 +29,9 @@ namespace RTC public: virtual void OnDataConsumerSendMessage( RTC::DataConsumer* dataConsumer, - uint32_t ppid, const uint8_t* msg, size_t len, + uint32_t ppid, onQueuedCallback* cb) = 0; virtual void OnDataConsumerDataProducerClosed(RTC::DataConsumer* dataConsumer) = 0; }; @@ -97,7 +98,13 @@ namespace RTC void SctpAssociationBufferedAmount(uint32_t bufferedAmount); void SctpAssociationSendBufferFull(); void DataProducerClosed(); - void SendMessage(uint32_t ppid, const uint8_t* msg, size_t len, onQueuedCallback* = nullptr); + void SendMessage( + const uint8_t* msg, + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel, + onQueuedCallback* = nullptr); /* Methods inherited from Channel::ChannelSocket::RequestHandler. */ public: @@ -120,6 +127,7 @@ namespace RTC RTC::SctpStreamParameters sctpStreamParameters; std::string label; std::string protocol; + absl::flat_hash_set subchannels; bool transportConnected{ false }; bool sctpAssociationConnected{ false }; bool paused{ false }; diff --git a/worker/include/RTC/DataProducer.hpp b/worker/include/RTC/DataProducer.hpp index cdff1842d5..e024c6b4d3 100644 --- a/worker/include/RTC/DataProducer.hpp +++ b/worker/include/RTC/DataProducer.hpp @@ -8,6 +8,7 @@ #include "RTC/SctpDictionaries.hpp" #include "RTC/Shared.hpp" #include +#include namespace RTC { @@ -23,9 +24,14 @@ namespace RTC public: virtual void OnDataProducerReceiveData(RTC::DataProducer* producer, size_t len) = 0; virtual void OnDataProducerMessageReceived( - RTC::DataProducer* dataProducer, uint32_t ppid, const uint8_t* msg, size_t len) = 0; - virtual void OnDataProducerPaused(RTC::DataProducer* dataProducer) = 0; - virtual void OnDataProducerResumed(RTC::DataProducer* dataProducer) = 0; + RTC::DataProducer* dataProducer, + const uint8_t* msg, + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel) = 0; + virtual void OnDataProducerPaused(RTC::DataProducer* dataProducer) = 0; + virtual void OnDataProducerResumed(RTC::DataProducer* dataProducer) = 0; }; public: @@ -61,7 +67,12 @@ namespace RTC { return this->paused; } - void ReceiveMessage(uint32_t ppid, const uint8_t* msg, size_t len); + void ReceiveMessage( + const uint8_t* msg, + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel); /* Methods inherited from Channel::ChannelSocket::RequestHandler. */ public: diff --git a/worker/include/RTC/DirectTransport.hpp b/worker/include/RTC/DirectTransport.hpp index 9b687c6086..7352f0ad75 100644 --- a/worker/include/RTC/DirectTransport.hpp +++ b/worker/include/RTC/DirectTransport.hpp @@ -32,9 +32,9 @@ namespace RTC void SendRtcpCompoundPacket(RTC::RTCP::CompoundPacket* packet) override; void SendMessage( RTC::DataConsumer* dataConsumer, - uint32_t ppid, const uint8_t* msg, size_t len, + uint32_t ppid, onQueuedCallback* cb = nullptr) override; void SendSctpData(const uint8_t* data, size_t len) override; void RecvStreamClosed(uint32_t ssrc) override; diff --git a/worker/include/RTC/PipeTransport.hpp b/worker/include/RTC/PipeTransport.hpp index 07ccbc10d5..6a2d803648 100644 --- a/worker/include/RTC/PipeTransport.hpp +++ b/worker/include/RTC/PipeTransport.hpp @@ -50,9 +50,9 @@ namespace RTC void SendRtcpCompoundPacket(RTC::RTCP::CompoundPacket* packet) override; void SendMessage( RTC::DataConsumer* dataConsumer, - uint32_t ppid, const uint8_t* msg, size_t len, + uint32_t ppid, onQueuedCallback* cb = nullptr) override; void SendSctpData(const uint8_t* data, size_t len) override; void RecvStreamClosed(uint32_t ssrc) override; diff --git a/worker/include/RTC/PlainTransport.hpp b/worker/include/RTC/PlainTransport.hpp index 202dcc6a21..c0cd6f5901 100644 --- a/worker/include/RTC/PlainTransport.hpp +++ b/worker/include/RTC/PlainTransport.hpp @@ -48,9 +48,9 @@ namespace RTC void SendRtcpCompoundPacket(RTC::RTCP::CompoundPacket* packet) override; void SendMessage( RTC::DataConsumer* dataConsumer, - uint32_t ppid, const uint8_t* msg, size_t len, + uint32_t ppid, onQueuedCallback* cb = nullptr) override; void SendSctpData(const uint8_t* data, size_t len) override; void RecvStreamClosed(uint32_t ssrc) override; diff --git a/worker/include/RTC/Router.hpp b/worker/include/RTC/Router.hpp index f74b735aca..46b96f7d65 100644 --- a/worker/include/RTC/Router.hpp +++ b/worker/include/RTC/Router.hpp @@ -15,8 +15,9 @@ #include "RTC/Transport.hpp" #include "RTC/WebRtcServer.hpp" #include +#include #include -#include +#include namespace RTC { @@ -96,9 +97,11 @@ namespace RTC void OnTransportDataProducerMessageReceived( RTC::Transport* transport, RTC::DataProducer* dataProducer, - uint32_t ppid, const uint8_t* msg, - size_t len) override; + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel) override; void OnTransportNewDataConsumer( RTC::Transport* transport, RTC::DataConsumer* dataConsumer, std::string& dataProducerId) override; void OnTransportDataConsumerClosed(RTC::Transport* transport, RTC::DataConsumer* dataConsumer) override; diff --git a/worker/include/RTC/SctpAssociation.hpp b/worker/include/RTC/SctpAssociation.hpp index 0b52a05ebe..146aca32bd 100644 --- a/worker/include/RTC/SctpAssociation.hpp +++ b/worker/include/RTC/SctpAssociation.hpp @@ -47,9 +47,9 @@ namespace RTC virtual void OnSctpAssociationMessageReceived( RTC::SctpAssociation* sctpAssociation, uint16_t streamId, - uint32_t ppid, const uint8_t* msg, - size_t len) = 0; + size_t len, + uint32_t ppid) = 0; virtual void OnSctpAssociationBufferedAmount( RTC::SctpAssociation* sctpAssociation, uint32_t len) = 0; }; @@ -92,9 +92,9 @@ namespace RTC void ProcessSctpData(const uint8_t* data, size_t len); void SendSctpMessage( RTC::DataConsumer* dataConsumer, - uint32_t ppid, const uint8_t* msg, size_t len, + uint32_t ppid, onQueuedCallback* cb = nullptr); void HandleDataConsumer(RTC::DataConsumer* dataConsumer); void DataProducerClosed(RTC::DataProducer* dataProducer); diff --git a/worker/include/RTC/Transport.hpp b/worker/include/RTC/Transport.hpp index 54c4f95dfe..eb0013677b 100644 --- a/worker/include/RTC/Transport.hpp +++ b/worker/include/RTC/Transport.hpp @@ -30,6 +30,7 @@ #include "handles/TimerHandle.hpp" #include #include +#include namespace RTC { @@ -103,9 +104,11 @@ namespace RTC virtual void OnTransportDataProducerMessageReceived( RTC::Transport* transport, RTC::DataProducer* dataProducer, - uint32_t ppid, const uint8_t* msg, - size_t len) = 0; + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel) = 0; virtual void OnTransportNewDataConsumer( RTC::Transport* transport, RTC::DataConsumer* dataConsumer, std::string& dataProducerId) = 0; virtual void OnTransportDataConsumerClosed( @@ -188,9 +191,9 @@ namespace RTC virtual void SendRtcpCompoundPacket(RTC::RTCP::CompoundPacket* packet) = 0; virtual void SendMessage( RTC::DataConsumer* dataConsumer, - uint32_t ppid, const uint8_t* msg, size_t len, + uint32_t ppid, onQueuedCallback* = nullptr) = 0; virtual void SendSctpData(const uint8_t* data, size_t len) = 0; virtual void RecvStreamClosed(uint32_t ssrc) = 0; @@ -245,7 +248,12 @@ namespace RTC this->DataReceived(len); } void OnDataProducerMessageReceived( - RTC::DataProducer* dataProducer, uint32_t ppid, const uint8_t* msg, size_t len) override; + RTC::DataProducer* dataProducer, + const uint8_t* msg, + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel) override; void OnDataProducerPaused(RTC::DataProducer* dataProducer) override; void OnDataProducerResumed(RTC::DataProducer* dataProducer) override; @@ -253,9 +261,9 @@ namespace RTC public: void OnDataConsumerSendMessage( RTC::DataConsumer* dataConsumer, - uint32_t ppid, const uint8_t* msg, size_t len, + uint32_t ppid, onQueuedCallback* = nullptr) override; void OnDataConsumerDataProducerClosed(RTC::DataConsumer* dataConsumer) override; @@ -270,9 +278,9 @@ namespace RTC void OnSctpAssociationMessageReceived( RTC::SctpAssociation* sctpAssociation, uint16_t streamId, - uint32_t ppid, const uint8_t* msg, - size_t len) override; + size_t len, + uint32_t ppid) override; void OnSctpAssociationBufferedAmount( RTC::SctpAssociation* sctpAssociation, uint32_t bufferedAmount) override; diff --git a/worker/include/RTC/WebRtcTransport.hpp b/worker/include/RTC/WebRtcTransport.hpp index 2687c4c2d6..e88047980c 100644 --- a/worker/include/RTC/WebRtcTransport.hpp +++ b/worker/include/RTC/WebRtcTransport.hpp @@ -86,9 +86,9 @@ namespace RTC void SendRtcpCompoundPacket(RTC::RTCP::CompoundPacket* packet) override; void SendMessage( RTC::DataConsumer* dataConsumer, - uint32_t ppid, const uint8_t* msg, size_t len, + uint32_t ppid, onQueuedCallback* cb = nullptr) override; void SendSctpData(const uint8_t* data, size_t len) override; void RecvStreamClosed(uint32_t ssrc) override; diff --git a/worker/include/common.hpp b/worker/include/common.hpp index 5b1ba0dedd..3934db04b8 100644 --- a/worker/include/common.hpp +++ b/worker/include/common.hpp @@ -7,6 +7,7 @@ #include // uint8_t, etc #include // std::function #include // std::addressof() +#include #ifdef _WIN32 #include // Avoid uv/win.h: error C2628 'intptr_t' followed by 'int' is illegal. diff --git a/worker/src/Channel/ChannelRequest.cpp b/worker/src/Channel/ChannelRequest.cpp index 22ec2d04c9..e69bbb3da7 100644 --- a/worker/src/Channel/ChannelRequest.cpp +++ b/worker/src/Channel/ChannelRequest.cpp @@ -74,6 +74,7 @@ namespace Channel { FBS::Request::Method::DATACONSUMER_GET_BUFFERED_AMOUNT, "dataConsumer.getBufferedAmount" }, { FBS::Request::Method::DATACONSUMER_SET_BUFFERED_AMOUNT_LOW_THRESHOLD, "dataConsumer.setBufferedAmountLowThreshold" }, { FBS::Request::Method::DATACONSUMER_SEND, "dataConsumer.send" }, + { FBS::Request::Method::DATACONSUMER_SET_SUBCHANNELS, "dataConsumer.setSubchannels" }, { FBS::Request::Method::RTPOBSERVER_PAUSE, "rtpObserver.pause" }, { FBS::Request::Method::RTPOBSERVER_RESUME, "rtpObserver.resume" }, { FBS::Request::Method::RTPOBSERVER_ADD_PRODUCER, "rtpObserver.addProducer" }, diff --git a/worker/src/RTC/DataConsumer.cpp b/worker/src/RTC/DataConsumer.cpp index 6ee9176c25..26e2c2bd70 100644 --- a/worker/src/RTC/DataConsumer.cpp +++ b/worker/src/RTC/DataConsumer.cpp @@ -66,6 +66,11 @@ namespace RTC // paused is set to false by default. this->paused = data->paused(); + for (const auto subchannel : *data->subchannels()) + { + this->subchannels.insert(subchannel); + } + // NOTE: This may throw. this->shared->channelMessageRegistrator->RegisterHandler( this->id, @@ -85,12 +90,21 @@ namespace RTC { MS_TRACE(); - flatbuffers::Offset sctpStreamParametersOffset; + flatbuffers::Offset sctpStreamParameters; // Add sctpStreamParameters. if (this->type == DataConsumer::Type::SCTP) { - sctpStreamParametersOffset = this->sctpStreamParameters.FillBuffer(builder); + sctpStreamParameters = this->sctpStreamParameters.FillBuffer(builder); + } + + std::vector subchannels; + + subchannels.reserve(this->subchannels.size()); + + for (auto subchannel : this->subchannels) + { + subchannels.emplace_back(subchannel); } return FBS::DataConsumer::CreateDumpResponseDirect( @@ -98,11 +112,12 @@ namespace RTC this->id.c_str(), this->dataProducerId.c_str(), this->typeString.c_str(), - sctpStreamParametersOffset, + sctpStreamParameters, this->label.c_str(), this->protocol.c_str(), this->paused, - this->dataProducerPaused); + this->dataProducerPaused, + std::addressof(subchannels)); } flatbuffers::Offset DataConsumer::FillBufferStats( @@ -296,7 +311,38 @@ namespace RTC } }); - SendMessage(ppid, data, len, cb); + static std::vector EmptySubchannels; + + SendMessage(data, len, ppid, EmptySubchannels, std::nullopt, cb); + + break; + } + + case Channel::ChannelRequest::Method::DATACONSUMER_SET_SUBCHANNELS: + { + const auto* body = request->data->body_as(); + + this->subchannels.clear(); + + for (const auto subchannel : *body->subchannels()) + { + this->subchannels.insert(subchannel); + } + + std::vector subchannels; + + subchannels.reserve(this->subchannels.size()); + + for (auto subchannel : this->subchannels) + { + subchannels.emplace_back(subchannel); + } + + // Create response. + auto responseOffset = FBS::DataConsumer::CreateSetSubchannelsResponseDirect( + request->GetBufferBuilder(), std::addressof(subchannels)); + + request->Accept(FBS::Response::Body::FBS_DataConsumer_SetSubchannelsResponse, responseOffset); break; } @@ -431,7 +477,13 @@ namespace RTC this->listener->OnDataConsumerDataProducerClosed(this); } - void DataConsumer::SendMessage(uint32_t ppid, const uint8_t* msg, size_t len, onQueuedCallback* cb) + void DataConsumer::SendMessage( + const uint8_t* msg, + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel, + onQueuedCallback* cb) { MS_TRACE(); @@ -440,6 +492,37 @@ namespace RTC return; } + // If a required subchannel is given, verify that this data consumer is + // subscribed to it. + if ( + requiredSubchannel.has_value() && + this->subchannels.find(requiredSubchannel.value()) == this->subchannels.end()) + { + return; + } + + // If subchannels are given, verify that this data consumer is subscribed + // to at least one of them. + if (subchannels.size() > 0) + { + bool subchannelMatched{ false }; + + for (const auto subchannel : subchannels) + { + if (this->subchannels.find(subchannel) != this->subchannels.end()) + { + subchannelMatched = true; + + break; + } + } + + if (!subchannelMatched) + { + return; + } + } + if (len > this->maxMessageSize) { MS_WARN_TAG( @@ -454,6 +537,6 @@ namespace RTC this->messagesSent++; this->bytesSent += len; - this->listener->OnDataConsumerSendMessage(this, ppid, msg, len, cb); + this->listener->OnDataConsumerSendMessage(this, msg, len, ppid, cb); } } // namespace RTC diff --git a/worker/src/RTC/DataProducer.cpp b/worker/src/RTC/DataProducer.cpp index 7cae8497c9..dc084ed25a 100644 --- a/worker/src/RTC/DataProducer.cpp +++ b/worker/src/RTC/DataProducer.cpp @@ -7,6 +7,7 @@ #include "MediaSoupErrors.hpp" #include "Utils.hpp" #include +#include namespace RTC { @@ -219,7 +220,23 @@ namespace RTC len); } - this->ReceiveMessage(body->ppid(), data, len); + std::vector subchannels; + + subchannels.reserve(body->subchannels()->size()); + + for (const auto subchannel : *body->subchannels()) + { + subchannels.emplace_back(subchannel); + } + + std::optional requiredSubchannel{ std::nullopt }; + + if (body->requiredSubchannel().has_value()) + { + requiredSubchannel = body->requiredSubchannel().value(); + } + + ReceiveMessage(data, len, body->ppid(), subchannels, requiredSubchannel); // Increase receive transmission. this->listener->OnDataProducerReceiveData(this, len); @@ -234,7 +251,12 @@ namespace RTC } } - void DataProducer::ReceiveMessage(uint32_t ppid, const uint8_t* msg, size_t len) + void DataProducer::ReceiveMessage( + const uint8_t* msg, + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel) { MS_TRACE(); @@ -247,6 +269,7 @@ namespace RTC return; } - this->listener->OnDataProducerMessageReceived(this, ppid, msg, len); + this->listener->OnDataProducerMessageReceived( + this, msg, len, ppid, subchannels, requiredSubchannel); } } // namespace RTC diff --git a/worker/src/RTC/DirectTransport.cpp b/worker/src/RTC/DirectTransport.cpp index 9078f87daa..4c3f4e108a 100644 --- a/worker/src/RTC/DirectTransport.cpp +++ b/worker/src/RTC/DirectTransport.cpp @@ -213,7 +213,7 @@ namespace RTC } void DirectTransport::SendMessage( - RTC::DataConsumer* dataConsumer, uint32_t ppid, const uint8_t* msg, size_t len, onQueuedCallback* cb) + RTC::DataConsumer* dataConsumer, const uint8_t* msg, size_t len, uint32_t ppid, onQueuedCallback* cb) { MS_TRACE(); diff --git a/worker/src/RTC/PipeTransport.cpp b/worker/src/RTC/PipeTransport.cpp index bfdf7d5585..ea17fe8b3f 100644 --- a/worker/src/RTC/PipeTransport.cpp +++ b/worker/src/RTC/PipeTransport.cpp @@ -541,11 +541,11 @@ namespace RTC } void PipeTransport::SendMessage( - RTC::DataConsumer* dataConsumer, uint32_t ppid, const uint8_t* msg, size_t len, onQueuedCallback* cb) + RTC::DataConsumer* dataConsumer, const uint8_t* msg, size_t len, uint32_t ppid, onQueuedCallback* cb) { MS_TRACE(); - this->sctpAssociation->SendSctpMessage(dataConsumer, ppid, msg, len, cb); + this->sctpAssociation->SendSctpMessage(dataConsumer, msg, len, ppid, cb); } void PipeTransport::SendSctpData(const uint8_t* data, size_t len) diff --git a/worker/src/RTC/PlainTransport.cpp b/worker/src/RTC/PlainTransport.cpp index 9310674bff..224efc4fd2 100644 --- a/worker/src/RTC/PlainTransport.cpp +++ b/worker/src/RTC/PlainTransport.cpp @@ -893,11 +893,11 @@ namespace RTC } void PlainTransport::SendMessage( - RTC::DataConsumer* dataConsumer, uint32_t ppid, const uint8_t* msg, size_t len, onQueuedCallback* cb) + RTC::DataConsumer* dataConsumer, const uint8_t* msg, size_t len, uint32_t ppid, onQueuedCallback* cb) { MS_TRACE(); - this->sctpAssociation->SendSctpMessage(dataConsumer, ppid, msg, len, cb); + this->sctpAssociation->SendSctpMessage(dataConsumer, msg, len,ppid, cb); } void PlainTransport::SendSctpData(const uint8_t* data, size_t len) diff --git a/worker/src/RTC/Router.cpp b/worker/src/RTC/Router.cpp index 117c79b0db..364f9a5393 100644 --- a/worker/src/RTC/Router.cpp +++ b/worker/src/RTC/Router.cpp @@ -883,9 +883,11 @@ namespace RTC inline void Router::OnTransportDataProducerMessageReceived( RTC::Transport* /*transport*/, RTC::DataProducer* dataProducer, - uint32_t ppid, const uint8_t* msg, - size_t len) + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel) { MS_TRACE(); @@ -893,7 +895,7 @@ namespace RTC for (auto* dataConsumer : dataConsumers) { - dataConsumer->SendMessage(ppid, msg, len); + dataConsumer->SendMessage(msg, len, ppid, subchannels, requiredSubchannel); } } diff --git a/worker/src/RTC/RtpDictionaries/Parameters.cpp b/worker/src/RTC/RtpDictionaries/Parameters.cpp index ab3ee0b776..c6ce8df665 100644 --- a/worker/src/RTC/RtpDictionaries/Parameters.cpp +++ b/worker/src/RTC/RtpDictionaries/Parameters.cpp @@ -36,10 +36,10 @@ namespace RTC case Value::Type::INTEGER: { - auto valueOffset = FBS::RtpParameters::CreateInteger(builder, value.integerValue); + auto valueOffset = FBS::RtpParameters::CreateInteger32(builder, value.integerValue); parameters.emplace_back(FBS::RtpParameters::CreateParameterDirect( - builder, key.c_str(), FBS::RtpParameters::Value::Integer, valueOffset.Union())); + builder, key.c_str(), FBS::RtpParameters::Value::Integer32, valueOffset.Union())); break; } @@ -68,10 +68,10 @@ namespace RTC case Value::Type::ARRAY_OF_INTEGERS: { auto valueOffset = - FBS::RtpParameters::CreateIntegerArrayDirect(builder, &value.arrayOfIntegers); + FBS::RtpParameters::CreateInteger32ArrayDirect(builder, &value.arrayOfIntegers); parameters.emplace_back(FBS::RtpParameters::CreateParameterDirect( - builder, key.c_str(), FBS::RtpParameters::Value::IntegerArray, valueOffset.Union())); + builder, key.c_str(), FBS::RtpParameters::Value::Integer32Array, valueOffset.Union())); break; } @@ -99,9 +99,9 @@ namespace RTC break; } - case FBS::RtpParameters::Value::Integer: + case FBS::RtpParameters::Value::Integer32: { - this->mapKeyValues.emplace(key, Value(parameter->value_as_Integer()->value())); + this->mapKeyValues.emplace(key, Value(parameter->value_as_Integer32()->value())); break; } @@ -120,9 +120,9 @@ namespace RTC break; } - case FBS::RtpParameters::Value::IntegerArray: + case FBS::RtpParameters::Value::Integer32Array: { - this->mapKeyValues.emplace(key, Value(parameter->value_as_IntegerArray()->value())); + this->mapKeyValues.emplace(key, Value(parameter->value_as_Integer32Array()->value())); break; } diff --git a/worker/src/RTC/SctpAssociation.cpp b/worker/src/RTC/SctpAssociation.cpp index 89f9bf525f..13922f00c8 100644 --- a/worker/src/RTC/SctpAssociation.cpp +++ b/worker/src/RTC/SctpAssociation.cpp @@ -376,7 +376,7 @@ namespace RTC } void SctpAssociation::SendSctpMessage( - RTC::DataConsumer* dataConsumer, uint32_t ppid, const uint8_t* msg, size_t len, onQueuedCallback* cb) + RTC::DataConsumer* dataConsumer, const uint8_t* msg, size_t len, uint32_t ppid, onQueuedCallback* cb) { MS_TRACE(); @@ -688,7 +688,7 @@ namespace RTC { MS_DEBUG_DEV("directly notifying listener [eor:1, buffer len:0]"); - this->listener->OnSctpAssociationMessageReceived(this, streamId, ppid, data, len); + this->listener->OnSctpAssociationMessageReceived(this, streamId, data, len, ppid); } // If end of message and there is buffered data, append data and notify buffer. else if (eor && this->messageBufferLen != 0) @@ -699,7 +699,7 @@ namespace RTC MS_DEBUG_DEV("notifying listener [eor:1, buffer len:%zu]", this->messageBufferLen); this->listener->OnSctpAssociationMessageReceived( - this, streamId, ppid, this->messageBuffer, this->messageBufferLen); + this, streamId, this->messageBuffer, this->messageBufferLen, ppid); this->messageBufferLen = 0; } diff --git a/worker/src/RTC/Transport.cpp b/worker/src/RTC/Transport.cpp index 78a4bb42ff..1393ddf15a 100644 --- a/worker/src/RTC/Transport.cpp +++ b/worker/src/RTC/Transport.cpp @@ -2646,11 +2646,17 @@ namespace RTC } inline void Transport::OnDataProducerMessageReceived( - RTC::DataProducer* dataProducer, uint32_t ppid, const uint8_t* msg, size_t len) + RTC::DataProducer* dataProducer, + const uint8_t* msg, + size_t len, + uint32_t ppid, + std::vector& subchannels, + std::optional requiredSubchannel) { MS_TRACE(); - this->listener->OnTransportDataProducerMessageReceived(this, dataProducer, ppid, msg, len); + this->listener->OnTransportDataProducerMessageReceived( + this, dataProducer, msg, len, ppid, subchannels, requiredSubchannel); } inline void Transport::OnDataProducerPaused(RTC::DataProducer* dataProducer) @@ -2668,11 +2674,11 @@ namespace RTC } inline void Transport::OnDataConsumerSendMessage( - RTC::DataConsumer* dataConsumer, uint32_t ppid, const uint8_t* msg, size_t len, onQueuedCallback* cb) + RTC::DataConsumer* dataConsumer, const uint8_t* msg, size_t len, uint32_t ppid, onQueuedCallback* cb) { MS_TRACE(); - SendMessage(dataConsumer, ppid, msg, len, cb); + SendMessage(dataConsumer, msg, len, ppid, cb); } inline void Transport::OnDataConsumerDataProducerClosed(RTC::DataConsumer* dataConsumer) @@ -2811,9 +2817,9 @@ namespace RTC inline void Transport::OnSctpAssociationMessageReceived( RTC::SctpAssociation* /*sctpAssociation*/, uint16_t streamId, - uint32_t ppid, const uint8_t* msg, - size_t len) + size_t len, + uint32_t ppid) { MS_TRACE(); @@ -2830,7 +2836,10 @@ namespace RTC // Pass the SCTP message to the corresponding DataProducer. try { - dataProducer->ReceiveMessage(ppid, msg, len); + static std::vector EmptySubchannels; + + dataProducer->ReceiveMessage( + msg, len, ppid, EmptySubchannels, /*requiredSubchannel*/ std::nullopt); } catch (std::exception& error) { diff --git a/worker/src/RTC/WebRtcTransport.cpp b/worker/src/RTC/WebRtcTransport.cpp index 6491ed6458..f1f083f9d2 100644 --- a/worker/src/RTC/WebRtcTransport.cpp +++ b/worker/src/RTC/WebRtcTransport.cpp @@ -899,11 +899,11 @@ namespace RTC } void WebRtcTransport::SendMessage( - RTC::DataConsumer* dataConsumer, uint32_t ppid, const uint8_t* msg, size_t len, onQueuedCallback* cb) + RTC::DataConsumer* dataConsumer, const uint8_t* msg, size_t len, uint32_t ppid, onQueuedCallback* cb) { MS_TRACE(); - this->sctpAssociation->SendSctpMessage(dataConsumer, ppid, msg, len, cb); + this->sctpAssociation->SendSctpMessage(dataConsumer, msg, len, ppid, cb); } void WebRtcTransport::SendSctpData(const uint8_t* data, size_t len)