Skip to content

Commit

Permalink
Showing 7 changed files with 252 additions and 5 deletions.
1 change: 1 addition & 0 deletions build.gradle.kts
Original file line number Diff line number Diff line change
@@ -48,6 +48,7 @@ dependencies {
testImplementation("org.junit.jupiter:junit-jupiter:5.6.0")
testImplementation("org.junit.jupiter:junit-jupiter-params:5.6.0")
testImplementation("org.jetbrains.kotlin:kotlin-test-junit5")
testImplementation("org.jetbrains.kotlinx:kotlinx-coroutines-test:1.3.8")
}

java {
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
@file:JvmName("CargoMessageBatching")

package tech.relaycorp.relaynet.messages.payloads

import tech.relaycorp.relaynet.messages.InvalidMessageException
import tech.relaycorp.relaynet.ramf.EncryptedRAMFMessage
import java.time.ZonedDateTime
import java.util.Collections

private const val MAX_BATCH_LENGTH =
EncryptedRAMFMessage.MAX_PAYLOAD_PLAINTEXT_LENGTH - CargoMessage.DER_TL_OVERHEAD_OCTETS

/**
* Serialization and expiry date of a message to be encapsulated in a cargo message set.
*
* @throws InvalidMessageException if `cargoMessageSerialized` is longer than
* [CargoMessage.MAX_LENGTH]
*/
@Suppress("ArrayInDataClass")
data class CargoMessageWithExpiry(
val cargoMessageSerialized: ByteArray,
val expiryDate: ZonedDateTime
) {
init {
if (CargoMessage.MAX_LENGTH < cargoMessageSerialized.size) {
throw InvalidMessageException(
"Message must not be longer than ${CargoMessage.MAX_LENGTH} octets " +
"(got ${cargoMessageSerialized.size})"
)
}
}
}

/**
* Serialization and expiry date of a cargo message set.
*/
data class CargoMessageSetWithExpiry(
val cargoMessageSet: CargoMessageSet,
val latestMessageExpiryDate: ZonedDateTime
)

/**
* Batch as many messages together as possible without exceeding the payload length limit on
* individual cargoes.
*
* If all messages can be encapsulated in the same cargo message set, they will be. Otherwise,
* multiple cargo message sets will be generated. The output will be empty if the input is
* empty too.
*/
suspend fun Sequence<CargoMessageWithExpiry>.batch(): Sequence<CargoMessageSetWithExpiry> =
sequence {
val currentBatch = mutableListOf<ByteArray>()
var currentBatchExpiry: ZonedDateTime? = null
var currentBatchAvailableOctets = MAX_BATCH_LENGTH

this@batch.forEach { messageWithExpiry ->
val messageTlvLength =
CargoMessage.DER_TL_OVERHEAD_OCTETS + messageWithExpiry.cargoMessageSerialized.size
val messageFitsInCurrentBatch = messageTlvLength <= currentBatchAvailableOctets
if (!messageFitsInCurrentBatch) {
val cargoMessageSet = CargoMessageSet(currentBatch.toTypedArray())
yield(CargoMessageSetWithExpiry(cargoMessageSet, currentBatchExpiry!!))

currentBatch.clear()
currentBatchExpiry = null
currentBatchAvailableOctets = MAX_BATCH_LENGTH
}

currentBatch.add(messageWithExpiry.cargoMessageSerialized)
currentBatchAvailableOctets -= messageTlvLength

currentBatchExpiry = currentBatchExpiry ?: messageWithExpiry.expiryDate
currentBatchExpiry =
Collections.max(listOf(currentBatchExpiry, messageWithExpiry.expiryDate))
}

if (currentBatch.isNotEmpty()) {
val cargoMessageSet = CargoMessageSet(currentBatch.toTypedArray())
yield(CargoMessageSetWithExpiry(cargoMessageSet, currentBatchExpiry as ZonedDateTime))
}
}
Original file line number Diff line number Diff line change
@@ -2,11 +2,12 @@ package tech.relaycorp.relaynet.messages.payloads

import tech.relaycorp.relaynet.messages.PARCEL_SERIALIZER
import tech.relaycorp.relaynet.messages.ParcelCollectionAck
import tech.relaycorp.relaynet.ramf.EncryptedRAMFMessage

/**
* Message encapsulated in a cargo message set, classified with its type.
*/
class CargoMessage(val messageSerialized: ByteArray) {
class CargoMessage internal constructor(val messageSerialized: ByteArray) {
var type: Type? = null
private set

@@ -26,4 +27,20 @@ class CargoMessage(val messageSerialized: ByteArray) {
PARCEL(PARCEL_SERIALIZER.formatSignature.asList()),
PCA(ParcelCollectionAck.FORMAT_SIGNATURE.asList())
}

companion object {
/**
* Number of octets needed to represent the type and length of an 8 MiB value in DER.
*/
internal const val DER_TL_OVERHEAD_OCTETS = 5

/**
* Maximum number of octets for any serialized message to be encapsulated in a cargo.
*
* This is the result of subtracting the TLVs for the SET and OCTET STRING values from
* the maximum size of an SDU to be encrypted.
*/
internal const val MAX_LENGTH =
EncryptedRAMFMessage.MAX_PAYLOAD_PLAINTEXT_LENGTH - (DER_TL_OVERHEAD_OCTETS * 2)
}
}
Original file line number Diff line number Diff line change
@@ -43,4 +43,9 @@ abstract class EncryptedRAMFMessage<P : EncryptedPayload> internal constructor(

@Throws(RAMFException::class)
internal abstract fun deserializePayload(payloadPlaintext: ByteArray): P

companion object {
// Per the RAMF spec
internal const val MAX_PAYLOAD_PLAINTEXT_LENGTH = 8_322_048
}
}
5 changes: 4 additions & 1 deletion src/main/kotlin/tech/relaycorp/relaynet/ramf/RAMFMessage.kt
Original file line number Diff line number Diff line change
@@ -16,7 +16,6 @@ import java.util.UUID
private const val MAX_RECIPIENT_ADDRESS_LENGTH = 1024
private const val MAX_MESSAGE_ID_LENGTH = 64
private const val MAX_TTL = 15552000
private const val MAX_PAYLOAD_LENGTH = 8388608

private const val DEFAULT_TTL_MINUTES = 5
private const val DEFAULT_TTL_SECONDS = DEFAULT_TTL_MINUTES * 60
@@ -157,4 +156,8 @@ abstract class RAMFMessage<P : Payload> internal constructor(
throw RAMFException("Recipient address is invalid")
}
}

companion object {
internal const val MAX_PAYLOAD_LENGTH = 8_388_608
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package tech.relaycorp.relaynet.messages.payloads

import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.runBlockingTest
import org.junit.jupiter.api.assertThrows
import tech.relaycorp.relaynet.messages.InvalidMessageException
import java.time.ZonedDateTime
import kotlin.test.Test
import kotlin.test.assertEquals

private val expiryDate = ZonedDateTime.now().plusDays(1)

@ExperimentalCoroutinesApi
class BatchTest {
private val messageSerialized = "I'm a parcel. Pinky promise.".toByteArray()

@Test
fun `Zero messages should result in zero batches`() = runBlockingTest {
val batches = emptySequence<CargoMessageWithExpiry>().batch()

assertEquals(0, batches.count())
}

@Test
fun `A single message should result in one batch`() = runBlockingTest {
val batches = sequenceOf(CargoMessageWithExpiry(messageSerialized, expiryDate)).batch()

assertEquals(1, batches.count())
val cargoMessageSet = batches.first().cargoMessageSet
assertEquals(1, cargoMessageSet.messages.size)
assertEquals(messageSerialized.asList(), cargoMessageSet.messages.first().asList())
}

@Test
fun `Multiple small messages should be put in the same batch`() = runBlockingTest {
val message2Serialized = "I'm a PCA. *wink wink*".toByteArray()

val batches = sequenceOf(
CargoMessageWithExpiry(messageSerialized, expiryDate),
CargoMessageWithExpiry(message2Serialized, expiryDate)
).batch()

assertEquals(1, batches.count())
val cargoMessageSet = batches.first().cargoMessageSet
assertEquals(2, cargoMessageSet.messages.size)
assertEquals(messageSerialized.asList(), cargoMessageSet.messages.first().asList())
assertEquals(message2Serialized.asList(), cargoMessageSet.messages[1].asList())
}

@Test
fun `Messages should be put into as few batches as possible`() = runBlockingTest {
val octetsIn3Mib = 3145728
val messageSerialized = "a".repeat(octetsIn3Mib).toByteArray()

val batches = sequenceOf(
CargoMessageWithExpiry(messageSerialized, expiryDate),
CargoMessageWithExpiry(messageSerialized, expiryDate),
CargoMessageWithExpiry(messageSerialized, expiryDate)
).batch()

assertEquals(2, batches.count())
val cargoMessageSet1 = batches.first().cargoMessageSet
assertEquals(
listOf(messageSerialized.asList(), messageSerialized.asList()),
cargoMessageSet1.messages.map { it.asList() }
)
val cargoMessageSet2 = batches.last().cargoMessageSet
assertEquals(1, cargoMessageSet2.messages.size)
assertEquals(messageSerialized.asList(), cargoMessageSet2.messages.first().asList())
}

@Test
fun `Messages collectively reaching the max length should be placed together`() =
runBlockingTest {
val halfLimit = CargoMessage.MAX_LENGTH / 2
val message1Serialized = "a".repeat(halfLimit - 3).toByteArray()
val message2Serialized = "a".repeat(halfLimit - 2).toByteArray()

val batches = sequenceOf(
CargoMessageWithExpiry(message1Serialized, expiryDate),
CargoMessageWithExpiry(message2Serialized, expiryDate)
).batch()

assertEquals(1, batches.count())
val cargoMessageSet = batches.first().cargoMessageSet
assertEquals(2, cargoMessageSet.messages.size)
assertEquals(message1Serialized.asList(), cargoMessageSet.messages[0].asList())
assertEquals(message2Serialized.asList(), cargoMessageSet.messages[1].asList())
}

@Test
fun `Expiry date of batch should be that of its message with latest expiry`() =
runBlockingTest {
// Generate two batches where the expiry date of the former is that of its first
// message, and the expiry date of the latter batch is that of its last message
val messageSerialized = "a".repeat(CargoMessage.MAX_LENGTH / 2 - 3).toByteArray()
val now = ZonedDateTime.now()
val message1ExpiryDate = now.plusDays(2)
val message2ExpiryDate = now.plusDays(1)
val message3ExpiryDate = now.plusDays(3)
val message4ExpiryDate = now.plusDays(4)

val batches = sequenceOf(
CargoMessageWithExpiry(messageSerialized, message1ExpiryDate),
CargoMessageWithExpiry(messageSerialized, message2ExpiryDate),
CargoMessageWithExpiry(messageSerialized, message3ExpiryDate),
CargoMessageWithExpiry(messageSerialized, message4ExpiryDate)
).batch()

assertEquals(2, batches.count())
assertEquals(2, batches.first().cargoMessageSet.messages.size)
assertEquals(message1ExpiryDate, batches.first().latestMessageExpiryDate)
assertEquals(2, batches.last().cargoMessageSet.messages.size)
assertEquals(message4ExpiryDate, batches.last().latestMessageExpiryDate)
}
}

class CargoMessageWithExpiryTest {
@Test
fun `A message with the largest possible length should be accepted`() {
val messageSerialized = "a".repeat(CargoMessage.MAX_LENGTH).toByteArray()

CargoMessageWithExpiry(messageSerialized, expiryDate)
}

@Test
fun `Messages exceeding the max per-message size should be refused`() {
val messageSerialized = "a".repeat(CargoMessage.MAX_LENGTH + 1).toByteArray()

val exception = assertThrows<InvalidMessageException> {
CargoMessageWithExpiry(messageSerialized, expiryDate)
}

assertEquals(
"Message must not be longer than ${CargoMessage.MAX_LENGTH} octets " +
"(got ${messageSerialized.size})",
exception.message
)
}
}
Original file line number Diff line number Diff line change
@@ -184,8 +184,7 @@ class RAMFMessageTest {

@Test
fun `Payload should not span more than 8 MiB`() {
val octetsIn8Mib = 8388608
val longPayloadLength = octetsIn8Mib + 1
val longPayloadLength = RAMFMessage.MAX_PAYLOAD_LENGTH + 1
val longPayload = "a".repeat(longPayloadLength).toByteArray()
val exception = assertThrows<RAMFException> {
StubEncryptedRAMFMessage(
@@ -196,7 +195,8 @@ class RAMFMessageTest {
}

assertEquals(
"Payload cannot span more than $octetsIn8Mib octets (got $longPayloadLength)",
"Payload cannot span more than ${RAMFMessage.MAX_PAYLOAD_LENGTH} octets " +
"(got $longPayloadLength)",
exception.message
)
}

0 comments on commit e47a87e

Please sign in to comment.