diff --git a/ts/LibSignalStores.ts b/ts/LibSignalStores.ts index 89acecc03..56332c49a 100644 --- a/ts/LibSignalStores.ts +++ b/ts/LibSignalStores.ts @@ -214,15 +214,19 @@ export class PreKeys extends PreKeyStore { } export type SenderKeysOptions = Readonly<{ - ourUuid: UUID; + readonly ourUuid: UUID; + readonly zone: Zone | undefined; }>; export class SenderKeys extends SenderKeyStore { private readonly ourUuid: UUID; - constructor({ ourUuid }: SenderKeysOptions) { + readonly zone: Zone | undefined; + + constructor({ ourUuid, zone }: SenderKeysOptions) { super(); this.ourUuid = ourUuid; + this.zone = zone; } async saveSenderKey( @@ -235,7 +239,8 @@ export class SenderKeys extends SenderKeyStore { await window.textsecure.storage.protocol.saveSenderKey( encodedAddress, distributionId, - record + record, + { zone: this.zone } ); } @@ -247,7 +252,8 @@ export class SenderKeys extends SenderKeyStore { const senderKey = await window.textsecure.storage.protocol.getSenderKey( encodedAddress, - distributionId + distributionId, + { zone: this.zone } ); return senderKey || null; diff --git a/ts/SignalProtocolStore.ts b/ts/SignalProtocolStore.ts index 02e7922e4..3ac3fd654 100644 --- a/ts/SignalProtocolStore.ts +++ b/ts/SignalProtocolStore.ts @@ -191,6 +191,7 @@ const EventsMixin = function EventsMixin(this: unknown) { } as any as typeof window.Backbone.EventsMixin; type SessionCacheEntry = CacheEntryType; +type SenderKeyCacheEntry = CacheEntryType; type ZoneQueueEntryType = Readonly<{ zone: Zone; @@ -213,10 +214,7 @@ export class SignalProtocolStore extends EventsMixin { CacheEntryType >; - senderKeys?: Map< - SenderKeyIdType, - CacheEntryType - >; + senderKeys?: Map; sessions?: Map; @@ -239,6 +237,8 @@ export class SignalProtocolStore extends EventsMixin { private pendingSessions = new Map(); + private pendingSenderKeys = new Map(); + private pendingUnprocessed = new Map(); async hydrateCaches(): Promise { @@ -501,7 +501,21 @@ export class SignalProtocolStore extends EventsMixin { await window.Signal.Data.removeAllSignedPreKeys(); } - // Sender Key Queue + // Sender Key + + // Re-entrant sender key transaction routine. Only one sender key transaction could + // be running at the same time. + // + // While in transaction: + // + // - `saveSenderKey()` adds the updated session to the `pendingSenderKeys` + // - `getSenderKey()` looks up the session first in `pendingSenderKeys` and only + // then in the main `senderKeys` store + // + // When transaction ends: + // + // - successfully: pending sender key stores are batched into the database + // - with an error: pending sender key stores are reverted async enqueueSenderKeyJob( qualifiedAddress: QualifiedAddress, @@ -534,8 +548,6 @@ export class SignalProtocolStore extends EventsMixin { return freshQueue; } - // Sender Keys - private getSenderKeyId( senderKeyId: QualifiedAddress, distributionId: string @@ -546,79 +558,94 @@ export class SignalProtocolStore extends EventsMixin { async saveSenderKey( qualifiedAddress: QualifiedAddress, distributionId: string, - record: SenderKeyRecord + record: SenderKeyRecord, + { zone = GLOBAL_ZONE }: SessionTransactionOptions = {} ): Promise { - if (!this.senderKeys) { - throw new Error('saveSenderKey: this.senderKeys not yet cached!'); - } + await this.withZone(zone, 'saveSenderKey', async () => { + if (!this.senderKeys) { + throw new Error('saveSenderKey: this.senderKeys not yet cached!'); + } - const senderId = qualifiedAddress.toString(); + const senderId = qualifiedAddress.toString(); - try { - const id = this.getSenderKeyId(qualifiedAddress, distributionId); + try { + const id = this.getSenderKeyId(qualifiedAddress, distributionId); - const fromDB: SenderKeyType = { - id, - senderId, - distributionId, - data: record.serialize(), - lastUpdatedDate: Date.now(), - }; + const fromDB: SenderKeyType = { + id, + senderId, + distributionId, + data: record.serialize(), + lastUpdatedDate: Date.now(), + }; - await window.Signal.Data.createOrUpdateSenderKey(fromDB); + this.pendingSenderKeys.set(id, { + hydrated: true, + fromDB, + item: record, + }); - this.senderKeys.set(id, { - hydrated: true, - fromDB, - item: record, - }); - } catch (error) { - const errorString = error && error.stack ? error.stack : error; - log.error( - `saveSenderKey: failed to save senderKey ${senderId}/${distributionId}: ${errorString}` - ); - } + // Current zone doesn't support pending sessions - commit immediately + if (!zone.supportsPendingSenderKeys()) { + await this.commitZoneChanges('saveSenderKey'); + } + } catch (error) { + const errorString = error && error.stack ? error.stack : error; + log.error( + `saveSenderKey: failed to save senderKey ${senderId}/${distributionId}: ${errorString}` + ); + } + }); } async getSenderKey( qualifiedAddress: QualifiedAddress, - distributionId: string + distributionId: string, + { zone = GLOBAL_ZONE }: SessionTransactionOptions = {} ): Promise { - if (!this.senderKeys) { - throw new Error('getSenderKey: this.senderKeys not yet cached!'); - } + return this.withZone(zone, 'getSenderKey', async () => { + if (!this.senderKeys) { + throw new Error('getSenderKey: this.senderKeys not yet cached!'); + } - const senderId = qualifiedAddress.toString(); + const senderId = qualifiedAddress.toString(); - try { - const id = this.getSenderKeyId(qualifiedAddress, distributionId); + try { + const id = this.getSenderKeyId(qualifiedAddress, distributionId); - const entry = this.senderKeys.get(id); - if (!entry) { - log.error('Failed to fetch sender key:', id); + const map = this.pendingSenderKeys.has(id) + ? this.pendingSenderKeys + : this.senderKeys; + const entry = map.get(id); + + if (!entry) { + log.error('Failed to fetch sender key:', id); + return undefined; + } + + if (entry.hydrated) { + log.info('Successfully fetched sender key (cache hit):', id); + return entry.item; + } + + const item = SenderKeyRecord.deserialize( + Buffer.from(entry.fromDB.data) + ); + this.senderKeys.set(id, { + hydrated: true, + item, + fromDB: entry.fromDB, + }); + log.info('Successfully fetched sender key(cache miss):', id); + return item; + } catch (error) { + const errorString = error && error.stack ? error.stack : error; + log.error( + `getSenderKey: failed to load sender key ${senderId}/${distributionId}: ${errorString}` + ); return undefined; } - - if (entry.hydrated) { - log.info('Successfully fetched sender key (cache hit):', id); - return entry.item; - } - - const item = SenderKeyRecord.deserialize(Buffer.from(entry.fromDB.data)); - this.senderKeys.set(id, { - hydrated: true, - item, - fromDB: entry.fromDB, - }); - log.info('Successfully fetched sender key(cache miss):', id); - return item; - } catch (error) { - const errorString = error && error.stack ? error.stack : error; - log.error( - `getSenderKey: failed to load sender key ${senderId}/${distributionId}: ${errorString}` - ); - return undefined; - } + }); } async removeSenderKey( @@ -645,11 +672,16 @@ export class SignalProtocolStore extends EventsMixin { } } - async clearSenderKeyStore(): Promise { - if (this.senderKeys) { - this.senderKeys.clear(); - } - await window.Signal.Data.removeAllSenderKeys(); + async removeAllSenderKeys(): Promise { + return this.withZone(GLOBAL_ZONE, 'removeAllSenderKeys', async () => { + if (this.senderKeys) { + this.senderKeys.clear(); + } + if (this.pendingSenderKeys) { + this.pendingSenderKeys.clear(); + } + await window.Signal.Data.removeAllSenderKeys(); + }); } // Session Queue @@ -700,6 +732,7 @@ export class SignalProtocolStore extends EventsMixin { // // - successfully: pending session stores are batched into the database // - with an error: pending session stores are reverted + public async withZone( zone: Zone, name: string, @@ -753,45 +786,66 @@ export class SignalProtocolStore extends EventsMixin { } private async commitZoneChanges(name: string): Promise { - const { pendingSessions, pendingUnprocessed } = this; + const { pendingSenderKeys, pendingSessions, pendingUnprocessed } = this; - if (pendingSessions.size === 0 && pendingUnprocessed.size === 0) { + if ( + pendingSenderKeys.size === 0 && + pendingSessions.size === 0 && + pendingUnprocessed.size === 0 + ) { return; } log.info( - `commitZoneChanges(${name}): pending sessions ${pendingSessions.size} ` + + `commitZoneChanges(${name}): ` + + `pending sender keys ${pendingSenderKeys.size}, ` + + `pending sessions ${pendingSessions.size}, ` + `pending unprocessed ${pendingUnprocessed.size}` ); + this.pendingSenderKeys = new Map(); this.pendingSessions = new Map(); this.pendingUnprocessed = new Map(); - // Commit both unprocessed and sessions in the same database transaction - // to unroll both on error. - await window.Signal.Data.commitSessionsAndUnprocessed({ + // Commit both sender keys, sessions and unprocessed in the same database transaction + // to unroll both on error. + await window.Signal.Data.commitDecryptResult({ + senderKeys: Array.from(pendingSenderKeys.values()).map( + ({ fromDB }) => fromDB + ), sessions: Array.from(pendingSessions.values()).map( ({ fromDB }) => fromDB ), unprocessed: Array.from(pendingUnprocessed.values()), }); - const { sessions } = this; - assert(sessions !== undefined, "Can't commit unhydrated storage"); - // Apply changes to in-memory storage after successful DB write. + + const { sessions } = this; + assert(sessions !== undefined, "Can't commit unhydrated session storage"); pendingSessions.forEach((value, key) => { sessions.set(key, value); }); + + const { senderKeys } = this; + assert( + senderKeys !== undefined, + "Can't commit unhydrated sender key storage" + ); + pendingSenderKeys.forEach((value, key) => { + senderKeys.set(key, value); + }); } private async revertZoneChanges(name: string, error: Error): Promise { log.info( `revertZoneChanges(${name}): ` + - `pending sessions size ${this.pendingSessions.size} ` + + `pending sender keys size ${this.pendingSenderKeys.size}, ` + + `pending sessions size ${this.pendingSessions.size}, ` + `pending unprocessed size ${this.pendingUnprocessed.size}`, error && error.stack ); + this.pendingSenderKeys.clear(); this.pendingSessions.clear(); this.pendingUnprocessed.clear(); } diff --git a/ts/sql/Client.ts b/ts/sql/Client.ts index 99f0e7a78..ba90d5e0f 100644 --- a/ts/sql/Client.ts +++ b/ts/sql/Client.ts @@ -186,7 +186,7 @@ const dataInterface: ClientInterface = { createOrUpdateSession, createOrUpdateSessions, - commitSessionsAndUnprocessed, + commitDecryptResult, bulkAddSessions, removeSessionById, removeSessionsByConversation, @@ -921,11 +921,12 @@ async function createOrUpdateSession(data: SessionType) { async function createOrUpdateSessions(array: Array) { await channels.createOrUpdateSessions(array); } -async function commitSessionsAndUnprocessed(options: { +async function commitDecryptResult(options: { + senderKeys: Array; sessions: Array; unprocessed: Array; }) { - await channels.commitSessionsAndUnprocessed(options); + await channels.commitDecryptResult(options); } async function bulkAddSessions(array: Array) { await channels.bulkAddSessions(array); diff --git a/ts/sql/Interface.ts b/ts/sql/Interface.ts index 6f73ea884..6d849d263 100644 --- a/ts/sql/Interface.ts +++ b/ts/sql/Interface.ts @@ -329,7 +329,8 @@ export type DataInterface = { createOrUpdateSession: (data: SessionType) => Promise; createOrUpdateSessions: (array: Array) => Promise; - commitSessionsAndUnprocessed(options: { + commitDecryptResult(options: { + senderKeys: Array; sessions: Array; unprocessed: Array; }): Promise; diff --git a/ts/sql/Server.ts b/ts/sql/Server.ts index dfb66c15d..91a213023 100644 --- a/ts/sql/Server.ts +++ b/ts/sql/Server.ts @@ -182,7 +182,7 @@ const dataInterface: ServerInterface = { createOrUpdateSession, createOrUpdateSessions, - commitSessionsAndUnprocessed, + commitDecryptResult, bulkAddSessions, removeSessionById, removeSessionsByConversation, @@ -757,6 +757,10 @@ async function removeAllItems(): Promise { } async function createOrUpdateSenderKey(key: SenderKeyType): Promise { + createOrUpdateSenderKeySync(key); +} + +function createOrUpdateSenderKeySync(key: SenderKeyType): void { const db = getInstance(); prepare( @@ -1175,16 +1179,22 @@ async function createOrUpdateSessions( })(); } -async function commitSessionsAndUnprocessed({ +async function commitDecryptResult({ + senderKeys, sessions, unprocessed, }: { + senderKeys: Array; sessions: Array; unprocessed: Array; }): Promise { const db = getInstance(); db.transaction(() => { + for (const item of senderKeys) { + assertSync(createOrUpdateSenderKeySync(item)); + } + for (const item of sessions) { assertSync(createOrUpdateSessionSync(item)); } diff --git a/ts/test-electron/SignalProtocolStore_test.ts b/ts/test-electron/SignalProtocolStore_test.ts index 4b30f1dac..b47deaed0 100644 --- a/ts/test-electron/SignalProtocolStore_test.ts +++ b/ts/test-electron/SignalProtocolStore_test.ts @@ -1442,7 +1442,9 @@ describe('SignalProtocolStore', () => { }); describe('zones', () => { + const distributionId = UUID.generate().toString(); const zone = new Zone('zone', { + pendingSenderKeys: true, pendingSessions: true, pendingUnprocessed: true, }); @@ -1450,6 +1452,7 @@ describe('SignalProtocolStore', () => { beforeEach(async () => { await store.removeAllUnprocessed(); await store.removeAllSessions(theirUuid.toString()); + await store.removeAllSenderKeys(); }); it('should not store pending sessions in global zone', async () => { @@ -1467,12 +1470,29 @@ describe('SignalProtocolStore', () => { assert.equal(await store.loadSession(id), testRecord); }); - it('commits session stores and unprocessed on success', async () => { + it('should not store pending sender keys in global zone', async () => { const id = new QualifiedAddress(ourUuid, new Address(theirUuid, 1)); - const testRecord = getSessionRecord(); + const testRecord = getSenderKeyRecord(); + + await assert.isRejected( + store.withZone(GLOBAL_ZONE, 'test', async () => { + await store.saveSenderKey(id, distributionId, testRecord); + throw new Error('Failure'); + }), + 'Failure' + ); + + assert.equal(await store.getSenderKey(id, distributionId), testRecord); + }); + + it('commits sender keys, sessions and unprocessed on success', async () => { + const id = new QualifiedAddress(ourUuid, new Address(theirUuid, 1)); + const testSession = getSessionRecord(); + const testSenderKey = getSenderKeyRecord(); await store.withZone(zone, 'test', async () => { - await store.storeSession(id, testRecord, { zone }); + await store.storeSession(id, testSession, { zone }); + await store.saveSenderKey(id, distributionId, testSenderKey, { zone }); await store.addUnprocessed( { @@ -1484,10 +1504,16 @@ describe('SignalProtocolStore', () => { }, { zone } ); - assert.equal(await store.loadSession(id, { zone }), testRecord); + + assert.equal(await store.loadSession(id, { zone }), testSession); + assert.equal( + await store.getSenderKey(id, distributionId, { zone }), + testSenderKey + ); }); - assert.equal(await store.loadSession(id), testRecord); + assert.equal(await store.loadSession(id), testSession); + assert.equal(await store.getSenderKey(id, distributionId), testSenderKey); const allUnprocessed = await store.getAllUnprocessed(); assert.deepEqual( @@ -1496,18 +1522,31 @@ describe('SignalProtocolStore', () => { ); }); - it('reverts session stores and unprocessed on error', async () => { + it('reverts sender keys, sessions and unprocessed on error', async () => { const id = new QualifiedAddress(ourUuid, new Address(theirUuid, 1)); - const testRecord = getSessionRecord(); - const failedRecord = getSessionRecord(); + const testSession = getSessionRecord(); + const failedSession = getSessionRecord(); + const testSenderKey = getSenderKeyRecord(); + const failedSenderKey = getSenderKeyRecord(); - await store.storeSession(id, testRecord); - assert.equal(await store.loadSession(id), testRecord); + await store.storeSession(id, testSession); + assert.equal(await store.loadSession(id), testSession); + + await store.saveSenderKey(id, distributionId, testSenderKey); + assert.equal(await store.getSenderKey(id, distributionId), testSenderKey); await assert.isRejected( store.withZone(zone, 'test', async () => { - await store.storeSession(id, failedRecord, { zone }); - assert.equal(await store.loadSession(id, { zone }), failedRecord); + await store.storeSession(id, failedSession, { zone }); + assert.equal(await store.loadSession(id, { zone }), failedSession); + + await store.saveSenderKey(id, distributionId, failedSenderKey, { + zone, + }); + assert.equal( + await store.getSenderKey(id, distributionId, { zone }), + failedSenderKey + ); await store.addUnprocessed( { @@ -1525,7 +1564,8 @@ describe('SignalProtocolStore', () => { 'Failure' ); - assert.equal(await store.loadSession(id), testRecord); + assert.equal(await store.loadSession(id), testSession); + assert.equal(await store.getSenderKey(id, distributionId), testSenderKey); assert.deepEqual(await store.getAllUnprocessed(), []); }); diff --git a/ts/textsecure/MessageReceiver.ts b/ts/textsecure/MessageReceiver.ts index 256223b1c..0365fa012 100644 --- a/ts/textsecure/MessageReceiver.ts +++ b/ts/textsecure/MessageReceiver.ts @@ -138,6 +138,7 @@ type CacheAddItemType = { }; type LockedStores = { + readonly senderKeyStore: SenderKeys; readonly sessionStore: Sessions; readonly identityKeyStore: IdentityKeys; readonly zone?: Zone; @@ -779,6 +780,7 @@ export default class MessageReceiver try { const zone = new Zone('decryptAndCacheBatch', { + pendingSenderKeys: true, pendingSessions: true, pendingUnprocessed: true, }); @@ -814,6 +816,10 @@ export default class MessageReceiver let stores = storesMap.get(destinationUuid.toString()); if (!stores) { stores = { + senderKeyStore: new SenderKeys({ + ourUuid: destinationUuid, + zone, + }), sessionStore: new Sessions({ zone, ourUuid: destinationUuid, @@ -1301,7 +1307,7 @@ export default class MessageReceiver } private async decryptSealedSender( - { sessionStore, identityKeyStore, zone }: LockedStores, + { senderKeyStore, sessionStore, identityKeyStore, zone }: LockedStores, envelope: UnsealedEnvelope, ciphertext: Uint8Array ): Promise { @@ -1352,7 +1358,6 @@ export default class MessageReceiver ); const sealedSenderIdentifier = certificate.senderUuid(); const sealedSenderSourceDevice = certificate.senderDeviceId(); - const senderKeyStore = new SenderKeys({ ourUuid: destinationUuid }); const address = new QualifiedAddress( destinationUuid, @@ -2000,7 +2005,6 @@ export default class MessageReceiver Buffer.from(distributionMessage) ); const { destinationUuid } = envelope; - const senderKeyStore = new SenderKeys({ ourUuid: destinationUuid }); const address = new QualifiedAddress( destinationUuid, Address.create(identifier, sourceDevice) @@ -2012,7 +2016,7 @@ export default class MessageReceiver processSenderKeyDistributionMessage( sender, senderKeyDistributionMessage, - senderKeyStore + stores.senderKeyStore ), stores.zone ); diff --git a/ts/textsecure/SendMessage.ts b/ts/textsecure/SendMessage.ts index 952f6dbc9..888895325 100644 --- a/ts/textsecure/SendMessage.ts +++ b/ts/textsecure/SendMessage.ts @@ -15,6 +15,7 @@ import { SenderKeyDistributionMessage, } from '@signalapp/signal-client'; +import { GLOBAL_ZONE } from '../SignalProtocolStore'; import { assert } from '../util/assert'; import { parseIntOrThrow } from '../util/parseIntOrThrow'; import { Address } from '../types/Address'; @@ -1874,7 +1875,7 @@ export default class MessageSender { ourUuid, new Address(ourUuid, ourDeviceId) ); - const senderKeyStore = new SenderKeys({ ourUuid }); + const senderKeyStore = new SenderKeys({ ourUuid, zone: GLOBAL_ZONE }); return window.textsecure.storage.protocol.enqueueSenderKeyJob( address, diff --git a/ts/util/Zone.ts b/ts/util/Zone.ts index 1f2c0ee64..7d0379856 100644 --- a/ts/util/Zone.ts +++ b/ts/util/Zone.ts @@ -2,6 +2,7 @@ // SPDX-License-Identifier: AGPL-3.0-only export type ZoneOptions = { + readonly pendingSenderKeys?: boolean; readonly pendingSessions?: boolean; readonly pendingUnprocessed?: boolean; }; @@ -12,6 +13,10 @@ export class Zone { private readonly options: ZoneOptions = {} ) {} + public supportsPendingSenderKeys(): boolean { + return this.options.pendingSenderKeys === true; + } + public supportsPendingSessions(): boolean { return this.options.pendingSessions === true; } diff --git a/ts/util/sendToGroup.ts b/ts/util/sendToGroup.ts index 8085a7424..dec1224bd 100644 --- a/ts/util/sendToGroup.ts +++ b/ts/util/sendToGroup.ts @@ -55,6 +55,7 @@ import * as RemoteConfig from '../RemoteConfig'; import { strictAssert } from './assert'; import * as log from '../logging/log'; +import { GLOBAL_ZONE } from '../SignalProtocolStore'; const ERROR_EXPIRED_OR_MISSING_DEVICES = 409; const ERROR_STALE_DEVICES = 410; @@ -861,7 +862,7 @@ async function encryptForSenderKey({ parseIntOrThrow(ourDeviceId, 'encryptForSenderKey, ourDeviceId') ); const ourAddress = getOurAddress(); - const senderKeyStore = new SenderKeys({ ourUuid }); + const senderKeyStore = new SenderKeys({ ourUuid, zone: GLOBAL_ZONE }); const message = Buffer.from(padMessage(contentMessage)); const ciphertextMessage =