Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/aris/crypto replay attack #6077

Merged
merged 11 commits into from
May 25, 2022
1 change: 1 addition & 0 deletions changelog.d/6077.sdk
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improve replay attacks and reduce duplicate message index errors
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright (c) 2022 New Vector Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.matrix.android.sdk.internal.crypto.replayattack

import androidx.test.filters.LargeTest
import org.amshove.kluent.internal.assertFailsWith
import org.junit.Assert
import org.junit.Assert.assertEquals
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.junit.runners.MethodSorters
import org.matrix.android.sdk.InstrumentedTest
import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.common.CommonTestHelper
import org.matrix.android.sdk.common.CryptoTestHelper

@RunWith(JUnit4::class)
@FixMethodOrder(MethodSorters.JVM)
@LargeTest
class ReplayAttackTest : InstrumentedTest {

@Test
fun replayAttackAlreadyDecryptedEventTest() {
val testHelper = CommonTestHelper(context())
val cryptoTestHelper = CryptoTestHelper(testHelper)
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)

val e2eRoomID = cryptoTestData.roomId

// Alice
val aliceSession = cryptoTestData.firstSession
val aliceRoomPOV = aliceSession.roomService().getRoom(e2eRoomID)!!

// Bob
val bobSession = cryptoTestData.secondSession
val bobRoomPOV = bobSession!!.roomService().getRoom(e2eRoomID)!!
assertEquals(bobRoomPOV.roomSummary()?.joinedMembersCount, 2)

// Alice will send a message
val sentEvents = testHelper.sendTextMessage(aliceRoomPOV, "Hello I will be decrypted twice", 1)
Assert.assertTrue("Message should be sent", sentEvents.size == 1)

val fakeEventId = sentEvents[0].eventId + "_fake"
val fakeEventWithTheSameIndex =
sentEvents[0].copy(eventId = fakeEventId, root = sentEvents[0].root.copy(eventId = fakeEventId))

testHelper.runBlockingTest {
// Lets assume we are from the main timelineId
val timelineId = "timelineId"
// Lets decrypt the original event
aliceSession.cryptoService().decryptEvent(sentEvents[0].root, timelineId)
// Lets decrypt the fake event that will have the same message index
val exception = assertFailsWith<MXCryptoError.Base> {
// An exception should be thrown while the same index would have been used for the previous decryption
aliceSession.cryptoService().decryptEvent(fakeEventWithTheSameIndex.root, timelineId)
}
assertEquals(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, exception.errorType)
}
cryptoTestData.cleanUp(testHelper)
}

@Test
fun replayAttackSameEventTest() {
val testHelper = CommonTestHelper(context())
val cryptoTestHelper = CryptoTestHelper(testHelper)
val cryptoTestData = cryptoTestHelper.doE2ETestWithAliceAndBobInARoom(true)

val e2eRoomID = cryptoTestData.roomId

// Alice
val aliceSession = cryptoTestData.firstSession
val aliceRoomPOV = aliceSession.roomService().getRoom(e2eRoomID)!!

// Bob
val bobSession = cryptoTestData.secondSession
val bobRoomPOV = bobSession!!.roomService().getRoom(e2eRoomID)!!
assertEquals(bobRoomPOV.roomSummary()?.joinedMembersCount, 2)

// Alice will send a message
val sentEvents = testHelper.sendTextMessage(aliceRoomPOV, "Hello I will be decrypted twice", 1)
Assert.assertTrue("Message should be sent", sentEvents.size == 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can be better to use assertEqual for better test failure report (for instance "expected 1, actual 2", rather than "expected true, actual false" when using assertTrue

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated


testHelper.runBlockingTest {
// Lets assume we are from the main timelineId
val timelineId = "timelineId"
// Lets decrypt the original event
aliceSession.cryptoService().decryptEvent(sentEvents[0].root, timelineId)
// Lets try to decrypt the same event
aliceSession.cryptoService().decryptEvent(sentEvents[0].root, timelineId)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd do a try {.. decrypt } catch (error) { fail("Shouldn't throw a decryption error for same event") }

}
cryptoTestData.cleanUp(testHelper)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ internal class MXOlmDevice @Inject constructor(
// The second level keys are strings of form "<senderKey>|<session_id>|<message_index>"
private val inboundGroupSessionMessageIndexes: MutableMap<String, MutableSet<String>> = HashMap()

private val replayAttackMap: MutableMap<String, String> = HashMap()

init {
// Retrieve the account from the store
try {
Expand Down Expand Up @@ -763,59 +765,71 @@ internal class MXOlmDevice @Inject constructor(
suspend fun decryptGroupMessage(body: String,
roomId: String,
timeline: String?,
eventId: String,
sessionId: String,
senderKey: String): OlmDecryptionResult {
val sessionHolder = getInboundGroupSession(sessionId, senderKey, roomId)
val wrapper = sessionHolder.wrapper
val inboundGroupSession = wrapper.olmInboundGroupSession
?: throw MXCryptoError.Base(MXCryptoError.ErrorType.UNABLE_TO_DECRYPT, "Session is null")
// Check that the room id matches the original one for the session. This stops
// the HS pretending a message was targeting a different room.
if (roomId == wrapper.roomId) {
val decryptResult = try {
sessionHolder.mutex.withLock {
inboundGroupSession.decryptMessage(body)
}
} catch (e: OlmException) {
Timber.tag(loggerTag.value).e(e, "## decryptGroupMessage () : decryptMessage failed")
throw MXCryptoError.OlmError(e)
}

if (timeline?.isNotBlank() == true) {
val timelineSet = inboundGroupSessionMessageIndexes.getOrPut(timeline) { mutableSetOf() }

val messageIndexKey = senderKey + "|" + sessionId + "|" + decryptResult.mIndex

if (timelineSet.contains(messageIndexKey)) {
val reason = String.format(MXCryptoError.DUPLICATE_MESSAGE_INDEX_REASON, decryptResult.mIndex)
Timber.tag(loggerTag.value).e("## decryptGroupMessage() timelineId=$timeline: $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, reason)
}

timelineSet.add(messageIndexKey)
}

inboundGroupSessionStore.storeInBoundGroupSession(sessionHolder, sessionId, senderKey)
val payload = try {
val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE)
val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage)
adapter.fromJson(payloadString)
} catch (e: Exception) {
Timber.tag(loggerTag.value).e("## decryptGroupMessage() : fails to parse the payload")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON)
}

return OlmDecryptionResult(
payload,
wrapper.keysClaimed,
senderKey,
wrapper.forwardingCurve25519KeyChain
)
} else {
if (roomId != wrapper.roomId) {
// Check that the room id matches the original one for the session. This stops
// the HS pretending a message was targeting a different room.
val reason = String.format(MXCryptoError.INBOUND_SESSION_MISMATCH_ROOM_ID_REASON, roomId, wrapper.roomId)
Timber.tag(loggerTag.value).e("## decryptGroupMessage() : $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.INBOUND_SESSION_MISMATCH_ROOM_ID, reason)
}
val decryptResult = try {
sessionHolder.mutex.withLock {
inboundGroupSession.decryptMessage(body)
}
} catch (e: OlmException) {
Timber.tag(loggerTag.value).e(e, "## decryptGroupMessage () : decryptMessage failed")
throw MXCryptoError.OlmError(e)
}

val messageIndexKey = senderKey + "|" + sessionId + "|" + roomId + "|" + decryptResult.mIndex
Timber.tag(loggerTag.value).d("##########################################################")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reduce log visibility to verbose!

Timber.tag(loggerTag.value).d("## decryptGroupMessage() timeline: $timeline")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() senderKey: $senderKey")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() sessionId: $sessionId")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() roomId: $roomId")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() eventId: $eventId")
Timber.tag(loggerTag.value).d("## decryptGroupMessage() mIndex: ${decryptResult.mIndex}")

if (timeline?.isNotBlank() == true) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we might need some concurrency issues when accessing the timelineset. Maybe use a mutex for that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Its the way the current code works, I will have a look

val timelineSet = inboundGroupSessionMessageIndexes.getOrPut(timeline) { mutableSetOf() }
if (timelineSet.contains(messageIndexKey) && messageIndexKey.alreadyUsed(eventId)) {
val reason = String.format(MXCryptoError.DUPLICATE_MESSAGE_INDEX_REASON, decryptResult.mIndex)
Timber.tag(loggerTag.value).e("## decryptGroupMessage() timelineId=$timeline: $reason")
throw MXCryptoError.Base(MXCryptoError.ErrorType.DUPLICATED_MESSAGE_INDEX, reason)
}
timelineSet.add(messageIndexKey)
}
replayAttackMap[messageIndexKey] = eventId
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It feels strange to have the replayAttackMap global and shared by everyone.
I would have made inboundGroupSessionMessageIndexes a Map<timelineId, replayAttackMap>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeap, its a good point will refactor it a bit

inboundGroupSessionStore.storeInBoundGroupSession(sessionHolder, sessionId, senderKey)
val payload = try {
val adapter = MoshiProvider.providesMoshi().adapter<JsonDict>(JSON_DICT_PARAMETERIZED_TYPE)
val payloadString = convertFromUTF8(decryptResult.mDecryptedMessage)
adapter.fromJson(payloadString)
} catch (e: Exception) {
Timber.tag(loggerTag.value).e("## decryptGroupMessage() : fails to parse the payload")
throw MXCryptoError.Base(MXCryptoError.ErrorType.BAD_DECRYPTED_FORMAT, MXCryptoError.BAD_DECRYPTED_FORMAT_TEXT_REASON)
}

return OlmDecryptionResult(
payload,
wrapper.keysClaimed,
senderKey,
wrapper.forwardingCurve25519KeyChain
)
}

/**
* Determines whether or not the messageKey has already been used to decrypt another eventId
*/
private fun String.alreadyUsed(eventId: String): Boolean {
return replayAttackMap[this] != null && replayAttackMap[this] != eventId
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ internal class MXMegolmDecryption(
encryptedEventContent.ciphertext,
event.roomId,
timeline,
eventId = event.eventId.orEmpty(),
encryptedEventContent.sessionId,
encryptedEventContent.senderKey
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import org.matrix.android.sdk.api.session.crypto.MXCryptoError
import org.matrix.android.sdk.api.session.crypto.model.OlmDecryptionResult
import org.matrix.android.sdk.api.session.events.model.Event
import org.matrix.android.sdk.api.session.events.model.EventType
import org.matrix.android.sdk.api.session.events.model.isThread
import org.matrix.android.sdk.api.session.events.model.toModel
import org.matrix.android.sdk.api.session.homeserver.HomeServerCapabilitiesService
import org.matrix.android.sdk.api.session.initsync.InitSyncStep
Expand Down Expand Up @@ -520,9 +521,10 @@ internal class RoomSyncHandler @Inject constructor(

private fun decryptIfNeeded(event: Event, roomId: String) {
try {
val timelineId = generateTimelineId(roomId, event)
// Event from sync does not have roomId, so add it to the event first
// note: runBlocking should be used here while we are in realm single thread executor, to avoid thread switching
val result = runBlocking { cryptoService.decryptEvent(event.copy(roomId = roomId), "") }
val result = runBlocking { cryptoService.decryptEvent(event.copy(roomId = roomId), timelineId) }
event.mxDecryptionResult = OlmDecryptionResult(
payload = result.clearEvent,
senderKey = result.senderCurve25519Key,
Expand All @@ -537,6 +539,11 @@ internal class RoomSyncHandler @Inject constructor(
}
}

private fun generateTimelineId(roomId: String, event: Event): String {
val threadIndicator = if (event.isThread()) "_thread_" else "_"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we make a different timeline id between thread and normal?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the fact that it's technically a different timeline, that can contain the same events with the main timeline in parallel. But maybe re-using the main timelineId is enough, will update it

return "${RoomSyncHandler::class.java.simpleName}$threadIndicator$roomId"
}

data class EphemeralResult(
val typingUserIds: List<String> = emptyList()
)
Expand Down