Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Return the current OTK count on an empty upload request #1774

Merged
merged 9 commits into from
Mar 2, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion clientapi/routing/keys.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ func UploadKeys(req *http.Request, keyAPI api.KeyInternalAPI, device *userapi.De
return *resErr
}

uploadReq := &api.PerformUploadKeysRequest{}
uploadReq := &api.PerformUploadKeysRequest{
DeviceID: device.ID,
UserID: device.UserID,
}
if r.DeviceKeys != nil {
uploadReq.DeviceKeys = []api.DeviceKeys{
{
Expand Down
2 changes: 2 additions & 0 deletions keyserver/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ type OneTimeKeysCount struct {

// PerformUploadKeysRequest is the request to PerformUploadKeys
type PerformUploadKeysRequest struct {
UserID string // User performing the request
DeviceID string // Device performing the request
DeviceKeys []DeviceKeys
OneTimeKeys []OneTimeKeys
// OnlyDisplayNameUpdates should be `true` if ALL the DeviceKeys are present to update
Expand Down
24 changes: 18 additions & 6 deletions keyserver/internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -513,6 +513,18 @@ func (a *KeyInternalAPI) uploadLocalDeviceKeys(ctx context.Context, req *api.Per
}

func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.PerformUploadKeysRequest, res *api.PerformUploadKeysResponse) {
if len(req.OneTimeKeys) == 0 {
neilalexander marked this conversation as resolved.
Show resolved Hide resolved
counts, err := a.DB.OneTimeKeysCount(ctx, req.UserID, req.DeviceID)
if err != nil {
res.Error = &api.KeyError{
Err: fmt.Sprintf("a.DB.OneTimeKeysCount: %s", err),
}
}
if counts != nil {
res.OneTimeKeyCounts = append(res.OneTimeKeyCounts, *counts)
}
return
}
for _, key := range req.OneTimeKeys {
// grab existing keys based on (user/device/algorithm/key ID)
keyIDsWithAlgorithms := make([]string, len(key.KeyJSON))
Expand All @@ -521,27 +533,27 @@ func (a *KeyInternalAPI) uploadOneTimeKeys(ctx context.Context, req *api.Perform
keyIDsWithAlgorithms[i] = keyIDWithAlgo
i++
}
existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, key.UserID, key.DeviceID, keyIDsWithAlgorithms)
existingKeys, err := a.DB.ExistingOneTimeKeys(ctx, req.UserID, req.DeviceID, keyIDsWithAlgorithms)
if err != nil {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: "failed to query existing one-time keys: " + err.Error(),
})
continue
}
for keyIDWithAlgo := range existingKeys {
// if keys exist and the JSON doesn't match, error out as the key already exists
if !bytes.Equal(existingKeys[keyIDWithAlgo], key.KeyJSON[keyIDWithAlgo]) {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", key.UserID, key.DeviceID, keyIDWithAlgo),
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s: algorithm / key ID %s one-time key already exists", req.UserID, req.DeviceID, keyIDWithAlgo),
})
continue
}
}
// store one-time keys
counts, err := a.DB.StoreOneTimeKeys(ctx, key)
if err != nil {
res.KeyError(key.UserID, key.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", key.UserID, key.DeviceID, err.Error()),
res.KeyError(req.UserID, req.DeviceID, &api.KeyError{
Err: fmt.Sprintf("%s device %s : failed to store one-time keys: %s", req.UserID, req.DeviceID, err.Error()),
})
continue
}
Expand Down