Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fallback keys implementation #3970

Merged
merged 9 commits into from
Dec 13, 2021
1 change: 1 addition & 0 deletions changelog.d/3473.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
MSC2732: Olm fallback keys
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ data class SyncResponse(
@Json(name = "device_one_time_keys_count")
val deviceOneTimeKeysCount: DeviceOneTimeKeysCountSyncResponse? = null,

/**
* The key algorithms for which the server has an unused fallback key for the device.
* If the client wants the server to have a fallback key for a given key algorithm,
* but that algorithm is not listed in device_unused_fallback_key_types, the client will upload a new key.
*/
@Json(name = "org.matrix.msc2732.device_unused_fallback_key_types")
val deviceUnusedFallbackKeyTypes: List<String>? = null,

/**
* List of groups.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo
import org.matrix.android.sdk.internal.crypto.model.ImportRoomKeysResult
import org.matrix.android.sdk.internal.crypto.model.MXDeviceInfo
import org.matrix.android.sdk.internal.crypto.model.MXEncryptEventContentResult
import org.matrix.android.sdk.internal.crypto.model.MXKey.Companion.KEY_SIGNED_CURVE_25519_TYPE
import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap
import org.matrix.android.sdk.internal.crypto.model.event.EncryptedEventContent
import org.matrix.android.sdk.internal.crypto.model.event.RoomKeyContent
Expand Down Expand Up @@ -431,6 +432,14 @@ internal class DefaultCryptoService @Inject constructor(
if (isStarted()) {
// Make sure we process to-device messages before generating new one-time-keys #2782
deviceListManager.refreshOutdatedDeviceLists()
// The presence of device_unused_fallback_key_types indicates that the server supports fallback keys.
// If there's no unused signed_curve25519 fallback key we need a new one.
if (syncResponse.deviceUnusedFallbackKeyTypes != null &&
// Generate a fallback key only if the server does not already have an unused fallback key.
!syncResponse.deviceUnusedFallbackKeyTypes.contains(KEY_SIGNED_CURVE_25519_TYPE)) {
oneTimeKeysUploader.needsNewFallback()
}

oneTimeKeysUploader.maybeUploadOneTimeKeys()
incomingGossipingRequestManager.processReceivedGossipingRequests()
}
Expand Down Expand Up @@ -928,7 +937,7 @@ internal class DefaultCryptoService @Inject constructor(
signatures = objectSigner.signObject(canonicalJson)
)

val uploadDeviceKeysParams = UploadKeysTask.Params(rest, null)
val uploadDeviceKeysParams = UploadKeysTask.Params(rest, null, null)
uploadKeysTask.execute(uploadDeviceKeysParams)

cryptoStore.setDeviceKeysUploaded(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,51 @@ internal class MXOlmDevice @Inject constructor(
return store.getOlmAccount().maxOneTimeKeys()
}

/**
* Returns an unpublished fallback key
* A call to markKeysAsPublished will mark it as published and this
* call will return null (until a call to generateFallbackKey is made)
*/
fun getFallbackKey(): MutableMap<String, MutableMap<String, String>>? {
try {
return store.getOlmAccount().fallbackKey()
} catch (e: Exception) {
Timber.e("## getFallbackKey() : failed")
}
return null
}

/**
* Generates a new fallback key if there is not already
* an unpublished one.
* @return true if a new key was generated
*/
fun generateFallbackKeyIfNeeded(): Boolean {
try {
if (!hasUnpublishedFallbackKey()) {
store.getOlmAccount().generateFallbackKey()
store.saveOlmAccount()
return true
}
} catch (e: Exception) {
Timber.e("## generateFallbackKey() : failed")
}
return false
}

internal fun hasUnpublishedFallbackKey(): Boolean {
return getFallbackKey()?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty().isNotEmpty()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

orEmpty().isNotEmpty() is a bit weird :/

}

fun forgetFallbackKey() {
try {
store.getOlmAccount().forgetFallbackKey()
store.saveOlmAccount()
} catch (e: Exception) {
Timber.e("## forgetFallbackKey() : failed")
}
}

/**
* Release the instance
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.matrix.android.sdk.internal.crypto

import android.content.Context
import org.matrix.android.sdk.api.extensions.tryOrNull
import org.matrix.android.sdk.internal.crypto.model.MXKey
import org.matrix.android.sdk.internal.crypto.model.rest.KeysUploadResponse
Expand All @@ -28,11 +29,16 @@ import javax.inject.Inject
import kotlin.math.floor
import kotlin.math.min

// The spec recommend a 5mn delay, but due to federation
// or server downtime we give it a bit more time (1 hour)
const val FALLBACK_KEY_FORGET_DELAY = 60 * 60_000L

@SessionScope
internal class OneTimeKeysUploader @Inject constructor(
private val olmDevice: MXOlmDevice,
private val objectSigner: ObjectSigner,
private val uploadKeysTask: UploadKeysTask
private val uploadKeysTask: UploadKeysTask,
context: Context
) {
// tell if there is a OTK check in progress
private var oneTimeKeyCheckInProgress = false
Expand All @@ -41,6 +47,9 @@ internal class OneTimeKeysUploader @Inject constructor(
private var lastOneTimeKeyCheck: Long = 0
private var oneTimeKeyCount: Int? = null

// Simple storage to remember when was uploaded the last fallback key
private val storage = context.getSharedPreferences("OneTimeKeysUploader_${olmDevice.deviceEd25519Key.hashCode()}", Context.MODE_PRIVATE)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A bit dirty... A proper store should be injected in the constructor...


/**
* Stores the current one_time_key count which will be handled later (in a call of
* _onSyncCompleted). The count is e.g. coming from a /sync response.
Expand All @@ -51,6 +60,15 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeKeyCount = currentCount
}

fun needsNewFallback() {
if (olmDevice.generateFallbackKeyIfNeeded()) {
// As we generated a new one, it's already forgetting one
// so we can clear the last publish time
// (in case the network calls fails after to avoid calling forgetKey)
saveLastFallbackKeyPublishTime(0L)
}
}

/**
* Check if the OTK must be uploaded.
*/
Expand All @@ -65,9 +83,19 @@ internal class OneTimeKeysUploader @Inject constructor(
return
}

lastOneTimeKeyCheck = System.currentTimeMillis()
oneTimeKeyCheckInProgress = true

val oneTimeKeyCountFromSync = oneTimeKeyCount
?: fetchOtkCount() // we don't have count from sync so get from server
?: return Unit.also {
oneTimeKeyCheckInProgress = false
Timber.w("maybeUploadOneTimeKeys: Failed to get otk count from server")
}

Timber.d("maybeUploadOneTimeKeys: otk count $oneTimeKeyCountFromSync , unpublished fallback key ${olmDevice.hasUnpublishedFallbackKey()}")

lastOneTimeKeyCheck = System.currentTimeMillis()

// We then check how many keys we can store in the Account object.
val maxOneTimeKeys = olmDevice.getMaxNumberOfOneTimeKeys()

Expand All @@ -78,37 +106,37 @@ internal class OneTimeKeysUploader @Inject constructor(
// discard the oldest private keys first. This will eventually clean
// out stale private keys that won't receive a message.
val keyLimit = floor(maxOneTimeKeys / 2.0).toInt()
if (oneTimeKeyCount == null) {
// Ask the server how many otk he has
oneTimeKeyCount = fetchOtkCount()
}
val oneTimeKeyCountFromSync = oneTimeKeyCount
if (oneTimeKeyCountFromSync != null) {
// We need to keep a pool of one time public keys on the server so that
// other devices can start conversations with us. But we can only store
// a finite number of private keys in the olm Account object.
// To complicate things further then can be a delay between a device
// claiming a public one time key from the server and it sending us a
// message. We need to keep the corresponding private key locally until
// we receive the message.
// But that message might never arrive leaving us stuck with duff
// private keys clogging up our local storage.
// So we need some kind of engineering compromise to balance all of
// these factors.
tryOrNull("Unable to upload OTK") {
val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit)
Timber.v("## uploadKeys() : success, $uploadedKeys key(s) sent")
}
} else {
Timber.w("maybeUploadOneTimeKeys: waiting to know the number of OTK from the sync")
lastOneTimeKeyCheck = 0

// We need to keep a pool of one time public keys on the server so that
// other devices can start conversations with us. But we can only store
// a finite number of private keys in the olm Account object.
// To complicate things further then can be a delay between a device
// claiming a public one time key from the server and it sending us a
// message. We need to keep the corresponding private key locally until
// we receive the message.
// But that message might never arrive leaving us stuck with duff
// private keys clogging up our local storage.
// So we need some kind of engineering compromise to balance all of
// these factors.
tryOrNull("Unable to upload OTK") {
val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit)
Timber.v("## uploadKeys() : success, $uploadedKeys key(s) sent")
}
oneTimeKeyCheckInProgress = false

// Check if we need to forget a fallback key
val latestPublishedTime = getLastFallbackKeyPublishTime()
if (latestPublishedTime != 0L && System.currentTimeMillis() - latestPublishedTime > FALLBACK_KEY_FORGET_DELAY) {
// This should be called once you are reasonably certain that you will not receive any more messages
// that use the old fallback key
Timber.d("## forgetFallbackKey()")
olmDevice.forgetFallbackKey()
}
}

private suspend fun fetchOtkCount(): Int? {
return tryOrNull("Unable to get OTK count") {
val result = uploadKeysTask.execute(UploadKeysTask.Params(null, null))
val result = uploadKeysTask.execute(UploadKeysTask.Params(null, null, null))
result.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE)
}
}
Expand All @@ -121,24 +149,47 @@ internal class OneTimeKeysUploader @Inject constructor(
* @return the number of uploaded keys
*/
private suspend fun uploadOTK(keyCount: Int, keyLimit: Int): Int {
if (keyLimit <= keyCount) {
if (keyLimit <= keyCount && !olmDevice.hasUnpublishedFallbackKey()) {
// If we don't need to generate any more keys then we are done.
return 0
}
val keysThisLoop = min(keyLimit - keyCount, ONE_TIME_KEY_GENERATION_MAX_NUMBER)
olmDevice.generateOneTimeKeys(keysThisLoop)
var keysThisLoop = 0
if (keyLimit > keyCount) {
// Creating keys can be an expensive operation so we limit the
// number we generate in one go to avoid blocking the application
// for too long.
keysThisLoop = min(keyLimit - keyCount, ONE_TIME_KEY_GENERATION_MAX_NUMBER)
olmDevice.generateOneTimeKeys(keysThisLoop)
}

// We check before sending if there is an unpublished key in order to saveLastFallbackKeyPublishTime if needed
val hadUnpublishedFallbackKey = olmDevice.hasUnpublishedFallbackKey()
val response = uploadOneTimeKeys(olmDevice.getOneTimeKeys())
olmDevice.markKeysAsPublished()
if (hadUnpublishedFallbackKey) {
// It had an unpublished fallback key that was published just now
saveLastFallbackKeyPublishTime(System.currentTimeMillis())
}

if (response.hasOneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE)) {
// Maybe upload other keys
return keysThisLoop + uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit)
return keysThisLoop +
uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit) +
(if (hadUnpublishedFallbackKey) 1 else 0)
} else {
Timber.e("## uploadOTK() : response for uploading keys does not contain one_time_key_counts.signed_curve25519")
throw Exception("response for uploading keys does not contain one_time_key_counts.signed_curve25519")
}
}

private fun saveLastFallbackKeyPublishTime(timeMillis: Long) {
storage.edit().putLong("last_fb_key_publish", timeMillis).apply()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: use edit {}

}

private fun getLastFallbackKeyPublishTime(): Long {
return storage.getLong("last_fb_key_publish", 0)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Defining a const should be better (in a store...)

}

/**
* Upload curve25519 one time keys.
*/
Expand All @@ -159,10 +210,26 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeJson["signed_curve25519:$key_id"] = k
}

val fallbackJson = mutableMapOf<String, Any>()
val fallbackCurve25519Map = olmDevice.getFallbackKey()?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty()
fallbackCurve25519Map.forEach { (key_id, key) ->
val k = mutableMapOf<String, Any>()
k["key"] = key
k["fallback"] = true
val canonicalJson = JsonCanonicalizer.getCanonicalJson(Map::class.java, k)
k["signatures"] = objectSigner.signObject(canonicalJson)

fallbackJson["signed_curve25519:$key_id"] = k
}

// For now, we set the device id explicitly, as we may not be using the
// same one as used in login.
val uploadParams = UploadKeysTask.Params(null, oneTimeJson)
return uploadKeysTask.execute(uploadParams)
val uploadParams = UploadKeysTask.Params(
deviceKeys = null,
oneTimeKeys = oneTimeJson,
fallbackKeys = fallbackJson.takeIf { fallbackJson.isNotEmpty() }
)
return uploadKeysTask.executeRetry(uploadParams, 3)
}

companion object {
Expand All @@ -173,6 +240,6 @@ internal class OneTimeKeysUploader @Inject constructor(
private const val ONE_TIME_KEY_GENERATION_MAX_NUMBER = 5

// frequency with which to check & upload one-time keys
private const val ONE_TIME_KEY_UPLOAD_PERIOD = (60 * 1000).toLong() // one minute
private const val ONE_TIME_KEY_UPLOAD_PERIOD = (60_000).toLong() // one minute
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,12 @@ internal data class KeysUploadBody(
* May be absent if no new one-time keys are required.
*/
@Json(name = "one_time_keys")
val oneTimeKeys: JsonDict? = null
val oneTimeKeys: JsonDict? = null,

/**
* If the user had previously uploaded a fallback key for a given algorithm, it is replaced.
* The server will only keep one fallback key per algorithm for each user.
*/
@Json(name = "org.matrix.msc2732.fallback_keys")
val fallbackKeys: JsonDict? = null
)
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ internal interface UploadKeysTask : Task<UploadKeysTask.Params, KeysUploadRespon
// the device keys to send.
val deviceKeys: DeviceKeys?,
// the one-time keys to send.
val oneTimeKeys: JsonDict?
val oneTimeKeys: JsonDict?,
val fallbackKeys: JsonDict?
)
}

Expand All @@ -44,7 +45,8 @@ internal class DefaultUploadKeysTask @Inject constructor(
override suspend fun execute(params: UploadKeysTask.Params): KeysUploadResponse {
val body = KeysUploadBody(
deviceKeys = params.deviceKeys,
oneTimeKeys = params.oneTimeKeys
oneTimeKeys = params.oneTimeKeys,
fallbackKeys = params.fallbackKeys
)

Timber.i("## Uploading device keys -> $body")
Expand Down