diff --git a/network/p2p/acp118/handler.go b/network/p2p/acp118/handler.go index fd7ac8dc4f36..971f29af9b82 100644 --- a/network/p2p/acp118/handler.go +++ b/network/p2p/acp118/handler.go @@ -39,16 +39,16 @@ func NewHandler(verifier Verifier, signer warp.Signer) *Handler { } // NewCachedHandler returns an instance of Handler that caches successful -// requests. +// signatures. func NewCachedHandler( cacher cache.Cacher[ids.ID, []byte], verifier Verifier, signer warp.Signer, ) *Handler { return &Handler{ - cacher: cacher, - verifier: verifier, - signer: signer, + signatureCache: cacher, + verifier: verifier, + signer: signer, } } @@ -56,9 +56,9 @@ func NewCachedHandler( type Handler struct { p2p.NoOpHandler - cacher cache.Cacher[ids.ID, []byte] - verifier Verifier - signer warp.Signer + signatureCache cache.Cacher[ids.ID, []byte] + verifier Verifier + signer warp.Signer } func (h *Handler) AppRequest( @@ -84,8 +84,8 @@ func (h *Handler) AppRequest( } msgID := msg.ID() - if responseBytes, ok := h.cacher.Get(msgID); ok { - return responseBytes, nil + if signatureBytes, ok := h.signatureCache.Get(msgID); ok { + return signatureToResponse(signatureBytes) } if err := h.verifier.Verify(ctx, msg, request.Justification); err != nil { @@ -100,6 +100,11 @@ func (h *Handler) AppRequest( } } + h.signatureCache.Put(msgID, signature) + return signatureToResponse(signature) +} + +func signatureToResponse(signature []byte) ([]byte, *common.AppError) { response := &sdk.SignatureResponse{ Signature: signature, } @@ -111,7 +116,5 @@ func (h *Handler) AppRequest( Message: fmt.Sprintf("failed to marshal response: %s", err), } } - - h.cacher.Put(msgID, responseBytes) return responseBytes, nil } diff --git a/network/p2p/acp118/handler_test.go b/network/p2p/acp118/handler_test.go index 081ecd1f0351..e58d61a8f6a0 100644 --- a/network/p2p/acp118/handler_test.go +++ b/network/p2p/acp118/handler_test.go @@ -125,6 +125,12 @@ func TestHandler(t *testing.T) { require.NoError(err) require.True(bls.Verify(pk, signature, request.Message)) + + // Ensure the cache is populated with correct signature + sig, ok := tt.cacher.Get(unsignedMessage.ID()) + if ok { + require.Equal(sig, response.Signature) + } } for _, expectedErr = range tt.expectedErrs {