Skip to content

Commit

Permalink
Fix issue with prekey update
Browse files Browse the repository at this point in the history
  • Loading branch information
AsamK committed Feb 18, 2024
1 parent 7206b4d commit 2c0ad7f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,11 @@ public void checkAccountState() throws IOException {
}
try {
updateAccountAttributes();
context.getPreKeyHelper().refreshPreKeysIfNecessary();
if (account.getPreviousStorageVersion() < 9) {
context.getPreKeyHelper().forceRefreshPreKeys();
} else {
context.getPreKeyHelper().refreshPreKeysIfNecessary();
}
if (account.getAci() == null || account.getPni() == null) {
checkWhoAmiI();
}
Expand Down
50 changes: 40 additions & 10 deletions lib/src/main/java/org/asamk/signal/manager/helper/PreKeyHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ public void refreshPreKeysIfNecessary() throws IOException {
refreshPreKeysIfNecessary(ServiceIdType.PNI);
}

public void forceRefreshPreKeys() throws IOException {
forceRefreshPreKeys(ServiceIdType.ACI);
forceRefreshPreKeys(ServiceIdType.PNI);
}

public void refreshPreKeysIfNecessary(ServiceIdType serviceIdType) throws IOException {
final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
if (identityKeyPair == null) {
Expand All @@ -56,6 +61,22 @@ public void refreshPreKeysIfNecessary(ServiceIdType serviceIdType) throws IOExce
}
}

public void forceRefreshPreKeys(ServiceIdType serviceIdType) throws IOException {
final var identityKeyPair = account.getIdentityKeyPair(serviceIdType);
if (identityKeyPair == null) {
return;
}
final var accountId = account.getAccountId(serviceIdType);
if (accountId == null) {
return;
}

final var counts = new OneTimePreKeyCounts(0, 0);
if (refreshPreKeysIfNecessary(serviceIdType, identityKeyPair, counts, true)) {
refreshPreKeysIfNecessary(serviceIdType, identityKeyPair, counts, true);
}
}

private boolean refreshPreKeysIfNecessary(
final ServiceIdType serviceIdType, final IdentityKeyPair identityKeyPair
) throws IOException {
Expand All @@ -67,8 +88,17 @@ private boolean refreshPreKeysIfNecessary(
preKeyCounts = new OneTimePreKeyCounts(0, 0);
}

return refreshPreKeysIfNecessary(serviceIdType, identityKeyPair, preKeyCounts, false);
}

private boolean refreshPreKeysIfNecessary(
final ServiceIdType serviceIdType,
final IdentityKeyPair identityKeyPair,
final OneTimePreKeyCounts preKeyCounts,
final boolean force
) throws IOException {
List<PreKeyRecord> preKeyRecords = null;
if (preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
if (force || preKeyCounts.getEcCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
logger.debug("Refreshing {} ec pre keys, because only {} of min {} pre keys remain",
serviceIdType,
preKeyCounts.getEcCount(),
Expand All @@ -77,13 +107,13 @@ private boolean refreshPreKeysIfNecessary(
}

SignedPreKeyRecord signedPreKeyRecord = null;
if (signedPreKeyNeedsRefresh(serviceIdType)) {
if (force || signedPreKeyNeedsRefresh(serviceIdType)) {
logger.debug("Refreshing {} signed pre key.", serviceIdType);
signedPreKeyRecord = generateSignedPreKey(serviceIdType, identityKeyPair);
}

List<KyberPreKeyRecord> kyberPreKeyRecords = null;
if (preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
if (force || preKeyCounts.getKyberCount() < ServiceConfig.PREKEY_MINIMUM_COUNT) {
logger.debug("Refreshing {} kyber pre keys, because only {} of min {} pre keys remain",
serviceIdType,
preKeyCounts.getKyberCount(),
Expand All @@ -92,9 +122,11 @@ private boolean refreshPreKeysIfNecessary(
}

KyberPreKeyRecord lastResortKyberPreKeyRecord = null;
if (lastResortKyberPreKeyNeedsRefresh(serviceIdType)) {
if (force || lastResortKyberPreKeyNeedsRefresh(serviceIdType)) {
logger.debug("Refreshing {} last resort kyber pre key.", serviceIdType);
lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType, identityKeyPair);
lastResortKyberPreKeyRecord = generateLastResortKyberPreKey(serviceIdType,
identityKeyPair,
kyberPreKeyRecords == null ? 0 : kyberPreKeyRecords.size());
}

if (signedPreKeyRecord == null
Expand Down Expand Up @@ -157,9 +189,7 @@ private List<PreKeyRecord> generatePreKeys(ServiceIdType serviceIdType) {
final var accountData = account.getAccountData(serviceIdType);
final var offset = accountData.getPreKeyMetadata().getNextPreKeyId();

var records = KeyUtils.generatePreKeyRecords(offset);

return records;
return KeyUtils.generatePreKeyRecords(offset);
}

private boolean signedPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
Expand Down Expand Up @@ -210,10 +240,10 @@ private boolean lastResortKyberPreKeyNeedsRefresh(ServiceIdType serviceIdType) {
}

private KyberPreKeyRecord generateLastResortKyberPreKey(
ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair
ServiceIdType serviceIdType, IdentityKeyPair identityKeyPair, final int offset
) {
final var accountData = account.getAccountData(serviceIdType);
final var signedPreKeyId = accountData.getPreKeyMetadata().getNextKyberPreKeyId();
final var signedPreKeyId = accountData.getPreKeyMetadata().getNextKyberPreKeyId() + offset;

return KeyUtils.generateKyberPreKeyRecord(signedPreKeyId, identityKeyPair.getPrivateKey());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ public class SignalAccount implements Closeable {
private static final Logger logger = LoggerFactory.getLogger(SignalAccount.class);

private static final int MINIMUM_STORAGE_VERSION = 1;
private static final int CURRENT_STORAGE_VERSION = 8;
private static final int CURRENT_STORAGE_VERSION = 9;

private final Object LOCK = new Object();

Expand Down Expand Up @@ -1111,7 +1111,7 @@ public void addKyberPreKeys(ServiceIdType serviceIdType, List<KyberPreKeyRecord>
serviceIdType,
preKeyMetadata.nextKyberPreKeyId);
accountData.getSignalServiceAccountDataStore()
.markAllOneTimeEcPreKeysStaleIfNecessary(System.currentTimeMillis());
.markAllOneTimeKyberPreKeysStaleIfNecessary(System.currentTimeMillis());
for (var record : records) {
if (preKeyMetadata.nextKyberPreKeyId != record.getId()) {
logger.error("Invalid kyber pre key id {}, expected {}",
Expand Down

0 comments on commit 2c0ad7f

Please sign in to comment.