Skip to content

Commit

Permalink
Merge pull request #6077 from vector-im/feature/aris/crypto_replay_at…
Browse files Browse the repository at this point in the history
…tack

Feature/aris/crypto replay attack
  • Loading branch information
BillCarsonFr authored May 25, 2022
2 parents f5b4e89 + 85f3592 commit 52eb48d
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 47 deletions.
1 change: 1 addition & 0 deletions changelog.d/6077.sdk
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve replay attacks and reduce duplicate message index errors
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
* Copyright 2022 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.matrix.android.sdk.internal.crypto.replayattack

import androidx.test.filters.LargeTest
import org.amshove.kluent.internal.assertFailsWith
import org.junit.Assert
import org.junit.Assert.assertEquals
import org.junit.Assert.fail
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.runners.MethodSorters
import org.matrix.android.sdk.InstrumentedTest
import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.common.CommonTestHelper
import org.matrix.android.sdk.common.CryptoTestHelper

@RunWith(JUnit4::class)
@FixMethodOrder(MethodSorters.JVM)
@LargeTest
class ReplayAttackTest : InstrumentedTest {

@Test
fun replayAttackAlreadyDecryptedEventTest() {
val testHelper = CommonTestHelper(context())
val cryptoTestHelper = CryptoTestHelper(testHelper)
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)

val e2eRoomID = cryptoTestData.roomId

// Alice
val aliceSession = cryptoTestData.firstSession
val aliceRoomPOV = aliceSession.roomService().getRoom(e2eRoomID)!!

// Bob
val bobSession = cryptoTestData.secondSession
val bobRoomPOV = bobSession!!.roomService().getRoom(e2eRoomID)!!
assertEquals(bobRoomPOV.roomSummary()?.joinedMembersCount, 2)

// Alice will send a message
val sentEvents = testHelper.sendTextMessage(aliceRoomPOV, "Hello I will be decrypted twice", 1)
assertEquals(1, sentEvents.size)

val fakeEventId = sentEvents[0].eventId + "_fake"
val fakeEventWithTheSameIndex =
sentEvents[0].copy(eventId = fakeEventId, root = sentEvents[0].root.copy(eventId = fakeEventId))

testHelper.runBlockingTest {
// Lets assume we are from the main timelineId
val timelineId = "timelineId"
// Lets decrypt the original event
aliceSession.cryptoService().decryptEvent(sentEvents[0].root, timelineId)
// Lets decrypt the fake event that will have the same message index
val exception = assertFailsWith<MXCryptoError.Base> {
// An exception should be thrown while the same index would have been used for the previous decryption
aliceSession.cryptoService().decryptEvent(fakeEventWithTheSameIndex.root, timelineId)
}
assertEquals(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, exception.errorType)
}
cryptoTestData.cleanUp(testHelper)
}

@Test
fun replayAttackSameEventTest() {
val testHelper = CommonTestHelper(context())
val cryptoTestHelper = CryptoTestHelper(testHelper)
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)

val e2eRoomID = cryptoTestData.roomId

// Alice
val aliceSession = cryptoTestData.firstSession
val aliceRoomPOV = aliceSession.roomService().getRoom(e2eRoomID)!!

// Bob
val bobSession = cryptoTestData.secondSession
val bobRoomPOV = bobSession!!.roomService().getRoom(e2eRoomID)!!
assertEquals(bobRoomPOV.roomSummary()?.joinedMembersCount, 2)

// Alice will send a message
val sentEvents = testHelper.sendTextMessage(aliceRoomPOV, "Hello I will be decrypted twice", 1)
Assert.assertTrue("Message should be sent", sentEvents.size == 1)
assertEquals(sentEvents.size, 1)

testHelper.runBlockingTest {
// Lets assume we are from the main timelineId
val timelineId = "timelineId"
// Lets decrypt the original event
aliceSession.cryptoService().decryptEvent(sentEvents[0].root, timelineId)
try {
// Lets try to decrypt the same event
aliceSession.cryptoService().decryptEvent(sentEvents[0].root, timelineId)
} catch (ex: Throwable) {
fail("Shouldn't throw a decryption error for same event")
}
}
cryptoTestData.cleanUp(testHelper)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,9 @@ internal class MXOlmDevice @Inject constructor(
// So, store these message indexes per timeline id.
//
// The first level keys are timeline ids.
// The second level keys are strings of form "<senderKey>|<session_id>|<message_index>"
private val inboundGroupSessionMessageIndexes: MutableMap<String, MutableSet<String>> = HashMap()
// The second level values is a Map that represents:
// "<senderKey>|<session_id>|<roomId>|<message_index>" --> eventId
private val inboundGroupSessionMessageIndexes: MutableMap<String, MutableMap<String, String>> = HashMap()

init {
// Retrieve the account from the store
Expand Down Expand Up @@ -755,67 +756,72 @@ internal class MXOlmDevice @Inject constructor(
* @param body the base64-encoded body of the encrypted message.
* @param roomId the room in which the message was received.
* @param timeline the id of the timeline where the event is decrypted. It is used to prevent replay attack.
* @param eventId the eventId of the message that will be decrypted
* @param sessionId the session identifier.
* @param senderKey the base64-encoded curve25519 key of the sender.
* @return the decrypting result. Nil if the sessionId is unknown.
* @return the decrypting result. Null if the sessionId is unknown.
*/
@Throws(MXCryptoError::class)
suspend fun decryptGroupMessage(body: String,
roomId: String,
timeline: String?,
eventId: String,
sessionId: String,
senderKey: String): OlmDecryptionResult {
val sessionHolder = getInboundGroupSession(sessionId, senderKey, roomId)
val wrapper = sessionHolder.wrapper
val inboundGroupSession = wrapper.olmInboundGroupSession
?: throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, "Session is null")
// Check that the room id matches the original one for the session. This stops
// the HS pretending a message was targeting a different room.
if (roomId == wrapper.roomId) {
val decryptResult = try {
sessionHolder.mutex.withLock {
inboundGroupSession.decryptMessage(body)
}
} catch (e: OlmException) {
Timber.tag(loggerTag.value).e(e, "## decryptGroupMessage () : decryptMessage failed")
throw MXCryptoError.OlmError(e)
}

if (timeline?.isNotBlank() == true) {
val timelineSet = inboundGroupSessionMessageIndexes.getOrPut(timeline) { mutableSetOf() }

val messageIndexKey = senderKey + "|" + sessionId + "|" + decryptResult.mIndex

if (timelineSet.contains(messageIndexKey)) {
val reason = String.format(MXCryptoError.DUPLICATE_MESSAGE_INDEX_REASON, decryptResult.mIndex)
Timber.tag(loggerTag.value).e("## decryptGroupMessage() timelineId=$timeline: $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, reason)
}

timelineSet.add(messageIndexKey)
}

inboundGroupSessionStore.storeInBoundGroupSession(sessionHolder, sessionId, senderKey)
val payload = try {
val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE)
val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage)
adapter.fromJson(payloadString)
} catch (e: Exception) {
Timber.tag(loggerTag.value).e("## decryptGroupMessage() : fails to parse the payload")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON)
}

return OlmDecryptionResult(
payload,
wrapper.keysClaimed,
senderKey,
wrapper.forwardingCurve25519KeyChain
)
} else {
if (roomId != wrapper.roomId) {
// Check that the room id matches the original one for the session. This stops
// the HS pretending a message was targeting a different room.
val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, wrapper.roomId)
Timber.tag(loggerTag.value).e("## decryptGroupMessage() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, reason)
}
val decryptResult = try {
sessionHolder.mutex.withLock {
inboundGroupSession.decryptMessage(body)
}
} catch (e: OlmException) {
Timber.tag(loggerTag.value).e(e, "## decryptGroupMessage () : decryptMessage failed")
throw MXCryptoError.OlmError(e)
}

val messageIndexKey = senderKey + "|" + sessionId + "|" + roomId + "|" + decryptResult.mIndex
Timber.tag(loggerTag.value).v("##########################################################")
Timber.tag(loggerTag.value).v("## decryptGroupMessage() timeline: $timeline")
Timber.tag(loggerTag.value).v("## decryptGroupMessage() senderKey: $senderKey")
Timber.tag(loggerTag.value).v("## decryptGroupMessage() sessionId: $sessionId")
Timber.tag(loggerTag.value).v("## decryptGroupMessage() roomId: $roomId")
Timber.tag(loggerTag.value).v("## decryptGroupMessage() eventId: $eventId")
Timber.tag(loggerTag.value).v("## decryptGroupMessage() mIndex: ${decryptResult.mIndex}")

if (timeline?.isNotBlank() == true) {
val replayAttackMap = inboundGroupSessionMessageIndexes.getOrPut(timeline) { mutableMapOf() }
if (replayAttackMap.contains(messageIndexKey) && replayAttackMap[messageIndexKey] != eventId) {
val reason = String.format(MXCryptoError.DUPLICATE_MESSAGE_INDEX_REASON, decryptResult.mIndex)
Timber.tag(loggerTag.value).e("## decryptGroupMessage() timelineId=$timeline: $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, reason)
}
replayAttackMap[messageIndexKey] = eventId
}
inboundGroupSessionStore.storeInBoundGroupSession(sessionHolder, sessionId, senderKey)
val payload = try {
val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE)
val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage)
adapter.fromJson(payloadString)
} catch (e: Exception) {
Timber.tag(loggerTag.value).e("## decryptGroupMessage() : fails to parse the payload")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON)
}

return OlmDecryptionResult(
payload,
wrapper.keysClaimed,
senderKey,
wrapper.forwardingCurve25519KeyChain
)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ internal class MXMegolmDecryption(
encryptedEventContent.ciphertext,
event.roomId,
timeline,
eventId = event.eventId.orEmpty(),
encryptedEventContent.sessionId,
encryptedEventContent.senderKey
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -520,9 +520,10 @@ internal class RoomSyncHandler @Inject constructor(

private fun decryptIfNeeded(event: Event, roomId: String) {
try {
val timelineId = generateTimelineId(roomId)
// Event from sync does not have roomId, so add it to the event first
// note: runBlocking should be used here while we are in realm single thread executor, to avoid thread switching
val result = runBlocking { cryptoService.decryptEvent(event.copy(roomId = roomId), "") }
val result = runBlocking { cryptoService.decryptEvent(event.copy(roomId = roomId), timelineId) }
event.mxDecryptionResult = OlmDecryptionResult(
payload = result.clearEvent,
senderKey = result.senderCurve25519Key,
Expand All @@ -537,6 +538,10 @@ internal class RoomSyncHandler @Inject constructor(
}
}

private fun generateTimelineId(roomId: String): String {
return "RoomSyncHandler$roomId"
}

data class EphemeralResult(
val typingUserIds: List<String> = emptyList()
)
Expand Down

0 comments on commit 52eb48d

Please sign in to comment.