Skip to content

Commit

Permalink
Merge pull request #3970 from vector-im/feature/ons/fallback_keys
Browse files Browse the repository at this point in the history
Fallback keys implementation
  • Loading branch information
bmarty authored Dec 13, 2021
2 parents 48fa411 + 76960f8 commit fa06005
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 38 deletions.
1 change: 1 addition & 0 deletions changelog.d/3473.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
MSC2732: Olm fallback keys
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,14 @@ data class SyncResponse(
@Json(name = "device_one_time_keys_count")
val deviceOneTimeKeysCount: DeviceOneTimeKeysCountSyncResponse? = null,

/**
* The key algorithms for which the server has an unused fallback key for the device.
* If the client wants the server to have a fallback key for a given key algorithm,
* but that algorithm is not listed in device_unused_fallback_key_types, the client will upload a new key.
*/
@Json(name = "org.matrix.msc2732.device_unused_fallback_key_types")
val deviceUnusedFallbackKeyTypes: List<String>? = null,

/**
* List of groups.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ import org.matrix.android.sdk.internal.crypto.model.CryptoDeviceInfo
import org.matrix.android.sdk.internal.crypto.model.ImportRoomKeysResult
import org.matrix.android.sdk.internal.crypto.model.MXDeviceInfo
import org.matrix.android.sdk.internal.crypto.model.MXEncryptEventContentResult
import org.matrix.android.sdk.internal.crypto.model.MXKey.Companion.KEY_SIGNED_CURVE_25519_TYPE
import org.matrix.android.sdk.internal.crypto.model.MXUsersDevicesMap
import org.matrix.android.sdk.internal.crypto.model.event.EncryptedEventContent
import org.matrix.android.sdk.internal.crypto.model.event.RoomKeyContent
Expand Down Expand Up @@ -431,6 +432,14 @@ internal class DefaultCryptoService @Inject constructor(
if (isStarted()) {
// Make sure we process to-device messages before generating new one-time-keys #2782
deviceListManager.refreshOutdatedDeviceLists()
// The presence of device_unused_fallback_key_types indicates that the server supports fallback keys.
// If there's no unused signed_curve25519 fallback key we need a new one.
if (syncResponse.deviceUnusedFallbackKeyTypes != null &&
// Generate a fallback key only if the server does not already have an unused fallback key.
!syncResponse.deviceUnusedFallbackKeyTypes.contains(KEY_SIGNED_CURVE_25519_TYPE)) {
oneTimeKeysUploader.needsNewFallback()
}

oneTimeKeysUploader.maybeUploadOneTimeKeys()
incomingGossipingRequestManager.processReceivedGossipingRequests()
}
Expand Down Expand Up @@ -928,7 +937,7 @@ internal class DefaultCryptoService @Inject constructor(
signatures = objectSigner.signObject(canonicalJson)
)

val uploadDeviceKeysParams = UploadKeysTask.Params(rest, null)
val uploadDeviceKeysParams = UploadKeysTask.Params(rest, null, null)
uploadKeysTask.execute(uploadDeviceKeysParams)

cryptoStore.setDeviceKeysUploaded(true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,51 @@ internal class MXOlmDevice @Inject constructor(
return store.getOlmAccount().maxOneTimeKeys()
}

/**
* Returns an unpublished fallback key
* A call to markKeysAsPublished will mark it as published and this
* call will return null (until a call to generateFallbackKey is made)
*/
fun getFallbackKey(): MutableMap<String, MutableMap<String, String>>? {
try {
return store.getOlmAccount().fallbackKey()
} catch (e: Exception) {
Timber.e("## getFallbackKey() : failed")
}
return null
}

/**
* Generates a new fallback key if there is not already
* an unpublished one.
* @return true if a new key was generated
*/
fun generateFallbackKeyIfNeeded(): Boolean {
try {
if (!hasUnpublishedFallbackKey()) {
store.getOlmAccount().generateFallbackKey()
store.saveOlmAccount()
return true
}
} catch (e: Exception) {
Timber.e("## generateFallbackKey() : failed")
}
return false
}

internal fun hasUnpublishedFallbackKey(): Boolean {
return getFallbackKey()?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty().isNotEmpty()
}

fun forgetFallbackKey() {
try {
store.getOlmAccount().forgetFallbackKey()
store.saveOlmAccount()
} catch (e: Exception) {
Timber.e("## forgetFallbackKey() : failed")
}
}

/**
* Release the instance
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package org.matrix.android.sdk.internal.crypto

import android.content.Context
import org.matrix.android.sdk.api.extensions.tryOrNull
import org.matrix.android.sdk.internal.crypto.model.MXKey
import org.matrix.android.sdk.internal.crypto.model.rest.KeysUploadResponse
Expand All @@ -28,11 +29,16 @@ import javax.inject.Inject
import kotlin.math.floor
import kotlin.math.min

// The spec recommend a 5mn delay, but due to federation
// or server downtime we give it a bit more time (1 hour)
const val FALLBACK_KEY_FORGET_DELAY = 60 * 60_000L

@SessionScope
internal class OneTimeKeysUploader @Inject constructor(
private val olmDevice: MXOlmDevice,
private val objectSigner: ObjectSigner,
private val uploadKeysTask: UploadKeysTask
private val uploadKeysTask: UploadKeysTask,
context: Context
) {
// tell if there is a OTK check in progress
private var oneTimeKeyCheckInProgress = false
Expand All @@ -41,6 +47,9 @@ internal class OneTimeKeysUploader @Inject constructor(
private var lastOneTimeKeyCheck: Long = 0
private var oneTimeKeyCount: Int? = null

// Simple storage to remember when was uploaded the last fallback key
private val storage = context.getSharedPreferences("OneTimeKeysUploader_${olmDevice.deviceEd25519Key.hashCode()}", Context.MODE_PRIVATE)

/**
* Stores the current one_time_key count which will be handled later (in a call of
* _onSyncCompleted). The count is e.g. coming from a /sync response.
Expand All @@ -51,6 +60,15 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeKeyCount = currentCount
}

fun needsNewFallback() {
if (olmDevice.generateFallbackKeyIfNeeded()) {
// As we generated a new one, it's already forgetting one
// so we can clear the last publish time
// (in case the network calls fails after to avoid calling forgetKey)
saveLastFallbackKeyPublishTime(0L)
}
}

/**
* Check if the OTK must be uploaded.
*/
Expand All @@ -65,9 +83,19 @@ internal class OneTimeKeysUploader @Inject constructor(
return
}

lastOneTimeKeyCheck = System.currentTimeMillis()
oneTimeKeyCheckInProgress = true

val oneTimeKeyCountFromSync = oneTimeKeyCount
?: fetchOtkCount() // we don't have count from sync so get from server
?: return Unit.also {
oneTimeKeyCheckInProgress = false
Timber.w("maybeUploadOneTimeKeys: Failed to get otk count from server")
}

Timber.d("maybeUploadOneTimeKeys: otk count $oneTimeKeyCountFromSync , unpublished fallback key ${olmDevice.hasUnpublishedFallbackKey()}")

lastOneTimeKeyCheck = System.currentTimeMillis()

// We then check how many keys we can store in the Account object.
val maxOneTimeKeys = olmDevice.getMaxNumberOfOneTimeKeys()

Expand All @@ -78,37 +106,37 @@ internal class OneTimeKeysUploader @Inject constructor(
// discard the oldest private keys first. This will eventually clean
// out stale private keys that won't receive a message.
val keyLimit = floor(maxOneTimeKeys / 2.0).toInt()
if (oneTimeKeyCount == null) {
// Ask the server how many otk he has
oneTimeKeyCount = fetchOtkCount()
}
val oneTimeKeyCountFromSync = oneTimeKeyCount
if (oneTimeKeyCountFromSync != null) {
// We need to keep a pool of one time public keys on the server so that
// other devices can start conversations with us. But we can only store
// a finite number of private keys in the olm Account object.
// To complicate things further then can be a delay between a device
// claiming a public one time key from the server and it sending us a
// message. We need to keep the corresponding private key locally until
// we receive the message.
// But that message might never arrive leaving us stuck with duff
// private keys clogging up our local storage.
// So we need some kind of engineering compromise to balance all of
// these factors.
tryOrNull("Unable to upload OTK") {
val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit)
Timber.v("## uploadKeys() : success, $uploadedKeys key(s) sent")
}
} else {
Timber.w("maybeUploadOneTimeKeys: waiting to know the number of OTK from the sync")
lastOneTimeKeyCheck = 0

// We need to keep a pool of one time public keys on the server so that
// other devices can start conversations with us. But we can only store
// a finite number of private keys in the olm Account object.
// To complicate things further then can be a delay between a device
// claiming a public one time key from the server and it sending us a
// message. We need to keep the corresponding private key locally until
// we receive the message.
// But that message might never arrive leaving us stuck with duff
// private keys clogging up our local storage.
// So we need some kind of engineering compromise to balance all of
// these factors.
tryOrNull("Unable to upload OTK") {
val uploadedKeys = uploadOTK(oneTimeKeyCountFromSync, keyLimit)
Timber.v("## uploadKeys() : success, $uploadedKeys key(s) sent")
}
oneTimeKeyCheckInProgress = false

// Check if we need to forget a fallback key
val latestPublishedTime = getLastFallbackKeyPublishTime()
if (latestPublishedTime != 0L && System.currentTimeMillis() - latestPublishedTime > FALLBACK_KEY_FORGET_DELAY) {
// This should be called once you are reasonably certain that you will not receive any more messages
// that use the old fallback key
Timber.d("## forgetFallbackKey()")
olmDevice.forgetFallbackKey()
}
}

private suspend fun fetchOtkCount(): Int? {
return tryOrNull("Unable to get OTK count") {
val result = uploadKeysTask.execute(UploadKeysTask.Params(null, null))
val result = uploadKeysTask.execute(UploadKeysTask.Params(null, null, null))
result.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE)
}
}
Expand All @@ -121,24 +149,47 @@ internal class OneTimeKeysUploader @Inject constructor(
* @return the number of uploaded keys
*/
private suspend fun uploadOTK(keyCount: Int, keyLimit: Int): Int {
if (keyLimit <= keyCount) {
if (keyLimit <= keyCount && !olmDevice.hasUnpublishedFallbackKey()) {
// If we don't need to generate any more keys then we are done.
return 0
}
val keysThisLoop = min(keyLimit - keyCount, ONE_TIME_KEY_GENERATION_MAX_NUMBER)
olmDevice.generateOneTimeKeys(keysThisLoop)
var keysThisLoop = 0
if (keyLimit > keyCount) {
// Creating keys can be an expensive operation so we limit the
// number we generate in one go to avoid blocking the application
// for too long.
keysThisLoop = min(keyLimit - keyCount, ONE_TIME_KEY_GENERATION_MAX_NUMBER)
olmDevice.generateOneTimeKeys(keysThisLoop)
}

// We check before sending if there is an unpublished key in order to saveLastFallbackKeyPublishTime if needed
val hadUnpublishedFallbackKey = olmDevice.hasUnpublishedFallbackKey()
val response = uploadOneTimeKeys(olmDevice.getOneTimeKeys())
olmDevice.markKeysAsPublished()
if (hadUnpublishedFallbackKey) {
// It had an unpublished fallback key that was published just now
saveLastFallbackKeyPublishTime(System.currentTimeMillis())
}

if (response.hasOneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE)) {
// Maybe upload other keys
return keysThisLoop + uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit)
return keysThisLoop +
uploadOTK(response.oneTimeKeyCountsForAlgorithm(MXKey.KEY_SIGNED_CURVE_25519_TYPE), keyLimit) +
(if (hadUnpublishedFallbackKey) 1 else 0)
} else {
Timber.e("## uploadOTK() : response for uploading keys does not contain one_time_key_counts.signed_curve25519")
throw Exception("response for uploading keys does not contain one_time_key_counts.signed_curve25519")
}
}

private fun saveLastFallbackKeyPublishTime(timeMillis: Long) {
storage.edit().putLong("last_fb_key_publish", timeMillis).apply()
}

private fun getLastFallbackKeyPublishTime(): Long {
return storage.getLong("last_fb_key_publish", 0)
}

/**
* Upload curve25519 one time keys.
*/
Expand All @@ -159,10 +210,26 @@ internal class OneTimeKeysUploader @Inject constructor(
oneTimeJson["signed_curve25519:$key_id"] = k
}

val fallbackJson = mutableMapOf<String, Any>()
val fallbackCurve25519Map = olmDevice.getFallbackKey()?.get(OlmAccount.JSON_KEY_ONE_TIME_KEY).orEmpty()
fallbackCurve25519Map.forEach { (key_id, key) ->
val k = mutableMapOf<String, Any>()
k["key"] = key
k["fallback"] = true
val canonicalJson = JsonCanonicalizer.getCanonicalJson(Map::class.java, k)
k["signatures"] = objectSigner.signObject(canonicalJson)

fallbackJson["signed_curve25519:$key_id"] = k
}

// For now, we set the device id explicitly, as we may not be using the
// same one as used in login.
val uploadParams = UploadKeysTask.Params(null, oneTimeJson)
return uploadKeysTask.execute(uploadParams)
val uploadParams = UploadKeysTask.Params(
deviceKeys = null,
oneTimeKeys = oneTimeJson,
fallbackKeys = fallbackJson.takeIf { fallbackJson.isNotEmpty() }
)
return uploadKeysTask.executeRetry(uploadParams, 3)
}

companion object {
Expand All @@ -173,6 +240,6 @@ internal class OneTimeKeysUploader @Inject constructor(
private const val ONE_TIME_KEY_GENERATION_MAX_NUMBER = 5

// frequency with which to check & upload one-time keys
private const val ONE_TIME_KEY_UPLOAD_PERIOD = (60 * 1000).toLong() // one minute
private const val ONE_TIME_KEY_UPLOAD_PERIOD = (60_000).toLong() // one minute
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,5 +40,12 @@ internal data class KeysUploadBody(
* May be absent if no new one-time keys are required.
*/
@Json(name = "one_time_keys")
val oneTimeKeys: JsonDict? = null
val oneTimeKeys: JsonDict? = null,

/**
* If the user had previously uploaded a fallback key for a given algorithm, it is replaced.
* The server will only keep one fallback key per algorithm for each user.
*/
@Json(name = "org.matrix.msc2732.fallback_keys")
val fallbackKeys: JsonDict? = null
)
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ internal interface UploadKeysTask : Task<UploadKeysTask.Params, KeysUploadRespon
// the device keys to send.
val deviceKeys: DeviceKeys?,
// the one-time keys to send.
val oneTimeKeys: JsonDict?
val oneTimeKeys: JsonDict?,
val fallbackKeys: JsonDict?
)
}

Expand All @@ -44,7 +45,8 @@ internal class DefaultUploadKeysTask @Inject constructor(
override suspend fun execute(params: UploadKeysTask.Params): KeysUploadResponse {
val body = KeysUploadBody(
deviceKeys = params.deviceKeys,
oneTimeKeys = params.oneTimeKeys
oneTimeKeys = params.oneTimeKeys,
fallbackKeys = params.fallbackKeys
)

Timber.i("## Uploading device keys -> $body")
Expand Down

0 comments on commit fa06005

Please sign in to comment.