Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement strict decoding for JetStream API requests #5858

Merged
merged 7 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions server/jetstream.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ type JetStreamConfig struct {
Domain string `json:"domain,omitempty"`
CompressOK bool `json:"compress_ok,omitempty"`
UniqueTag string `json:"unique_tag,omitempty"`
Strict bool `json:"strict,omitempty"`
}

// Statistics about JetStream for this server.
Expand Down Expand Up @@ -462,6 +463,11 @@ func (s *Server) enableJetStream(cfg JetStreamConfig) error {
s.Noticef("")
}
s.Noticef("---------------- JETSTREAM ----------------")

if cfg.Strict {
s.Noticef(" Strict: %t", cfg.Strict)
}

s.Noticef(" Max Memory: %s", friendlyBytes(cfg.MaxMemory))
s.Noticef(" Max Storage: %s", friendlyBytes(cfg.MaxStore))
s.Noticef(" Store Directory: \"%s\"", cfg.StoreDir)
Expand Down Expand Up @@ -554,6 +560,7 @@ func (s *Server) restartJetStream() error {
MaxMemory: opts.JetStreamMaxMemory,
MaxStore: opts.JetStreamMaxStore,
Domain: opts.JetStreamDomain,
Strict: opts.JetStreamStrict,
}
s.Noticef("Restarting JetStream")
err := s.EnableJetStream(&cfg)
Expand Down Expand Up @@ -2527,6 +2534,9 @@ func (s *Server) dynJetStreamConfig(storeDir string, maxStore, maxMem int64) *Je

opts := s.getOpts()

// Strict mode.
jsc.Strict = opts.JetStreamStrict

// Sync options.
jsc.SyncInterval = opts.SyncInterval
jsc.SyncAlways = opts.SyncAlways
Expand Down
71 changes: 49 additions & 22 deletions server/jetstream_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"os"
"path/filepath"
Expand Down Expand Up @@ -1066,6 +1067,32 @@ func (s *Server) getRequestInfo(c *client, raw []byte) (pci *ClientInfo, acc *Ac
return &ci, acc, hdr, msg, nil
}

func (s *Server) unmarshalRequest(c *client, acc *Account, subject string, msg []byte, v any) error {
decoder := json.NewDecoder(bytes.NewReader(msg))
decoder.DisallowUnknownFields()

for {
if err := decoder.Decode(v); err != nil {
if err == io.EOF {
return nil
}

var syntaxErr *json.SyntaxError
if errors.As(err, &syntaxErr) {
err = fmt.Errorf("%w at offset %d", err, syntaxErr.Offset)
}

c.RateLimitWarnf("Invalid JetStream request '%s > %s': %s", acc, subject, err)

if s.JetStreamConfig().Strict {
return err
caspervonb marked this conversation as resolved.
Show resolved Hide resolved
}

return json.Unmarshal(msg, v)
derekcollison marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

func (a *Account) trackAPI() {
a.mu.RLock()
jsa := a.js
Expand Down Expand Up @@ -1195,7 +1222,7 @@ func (s *Server) jsTemplateCreateRequest(sub *subscription, c *client, _ *Accoun
}

var cfg StreamTemplateConfig
if err := json.Unmarshal(msg, &cfg); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1252,7 +1279,7 @@ func (s *Server) jsTemplateNamesRequest(sub *subscription, c *client, _ *Account
var offset int
if isJSONObjectOrArray(msg) {
var req JSApiStreamTemplatesRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1435,7 +1462,7 @@ func (s *Server) jsStreamCreateRequest(sub *subscription, c *client, _ *Account,
}

var cfg StreamConfigRequest
if err := json.Unmarshal(msg, &cfg); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &cfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1546,7 +1573,7 @@ func (s *Server) jsStreamUpdateRequest(sub *subscription, c *client, _ *Account,
return
}
var ncfg StreamConfigRequest
if err := json.Unmarshal(msg, &ncfg); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &ncfg); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1645,7 +1672,7 @@ func (s *Server) jsStreamNamesRequest(sub *subscription, c *client, _ *Account,

if isJSONObjectOrArray(msg) {
var req JSApiStreamNamesRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1775,7 +1802,7 @@ func (s *Server) jsStreamListRequest(sub *subscription, c *client, _ *Account, s

if isJSONObjectOrArray(msg) {
var req JSApiStreamListRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -1945,7 +1972,7 @@ func (s *Server) jsStreamInfoRequest(sub *subscription, c *client, a *Account, s
var offset int
if isJSONObjectOrArray(msg) {
var req JSApiStreamInfoRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2302,7 +2329,7 @@ func (s *Server) jsStreamRemovePeerRequest(sub *subscription, c *client, _ *Acco
}

var req JSApiStreamRemovePeerRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2382,7 +2409,7 @@ func (s *Server) jsLeaderServerRemoveRequest(sub *subscription, c *client, _ *Ac
}

var req JSApiMetaServerRemoveRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -2485,7 +2512,7 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _
var resp = JSApiStreamUpdateResponse{ApiResponse: ApiResponse{Type: JSApiStreamUpdateResponseType}}

var req JSApiMetaServerStreamMoveRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand All @@ -2512,7 +2539,7 @@ func (s *Server) jsLeaderServerStreamMoveRequest(sub *subscription, c *client, _
if ok {
sa, ok := streams[streamName]
if ok {
cfg = *sa.Config.clone()
cfg = *sa.Config
streamFound = true
currPeers = sa.Group.Peers
currCluster = sa.Group.Cluster
Expand Down Expand Up @@ -2654,7 +2681,7 @@ func (s *Server) jsLeaderServerStreamCancelMoveRequest(sub *subscription, c *cli
if ok {
sa, ok := streams[streamName]
if ok {
cfg = *sa.Config.clone()
cfg = *sa.Config
streamFound = true
currPeers = sa.Group.Peers
}
Expand Down Expand Up @@ -2835,7 +2862,7 @@ func (s *Server) jsLeaderStepDownRequest(sub *subscription, c *client, _ *Accoun

if isJSONObjectOrArray(msg) {
var req JSApiLeaderStepdownRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3062,7 +3089,7 @@ func (s *Server) jsMsgDeleteRequest(sub *subscription, c *client, _ *Account, su
return
}
var req JSApiMsgDeleteRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3181,7 +3208,7 @@ func (s *Server) jsMsgGetRequest(sub *subscription, c *client, _ *Account, subje
return
}
var req JSApiMsgGetRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3324,7 +3351,7 @@ func (s *Server) jsStreamPurgeRequest(sub *subscription, c *client, _ *Account,
var purgeRequest *JSApiStreamPurgeRequest
if isJSONObjectOrArray(msg) {
var req JSApiStreamPurgeRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3414,7 +3441,7 @@ func (s *Server) jsStreamRestoreRequest(sub *subscription, c *client, _ *Account
}

var req JSApiStreamRestoreRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3717,7 +3744,7 @@ func (s *Server) jsStreamSnapshotRequest(sub *subscription, c *client, _ *Accoun
}

var req JSApiStreamSnapshotRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, smsg, s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -3915,7 +3942,7 @@ func (s *Server) jsConsumerCreateRequest(sub *subscription, c *client, a *Accoun
var resp = JSApiConsumerCreateResponse{ApiResponse: ApiResponse{Type: JSApiConsumerCreateResponseType}}

var req CreateConsumerRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4157,7 +4184,7 @@ func (s *Server) jsConsumerNamesRequest(sub *subscription, c *client, _ *Account
var offset int
if isJSONObjectOrArray(msg) {
var req JSApiConsumersRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4279,7 +4306,7 @@ func (s *Server) jsConsumerListRequest(sub *subscription, c *client, _ *Account,
var offset int
if isJSONObjectOrArray(msg) {
var req JSApiConsumersRequest
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down Expand Up @@ -4590,7 +4617,7 @@ func (s *Server) jsConsumerPauseRequest(sub *subscription, c *client, _ *Account
var resp = JSApiConsumerPauseResponse{ApiResponse: ApiResponse{Type: JSApiConsumerPauseResponseType}}

if isJSONObjectOrArray(msg) {
if err := json.Unmarshal(msg, &req); err != nil {
if err := s.unmarshalRequest(c, acc, subject, msg, &req); err != nil {
resp.Error = NewJSInvalidJSONError(err)
s.sendAPIErrResponse(ci, acc, subject, reply, string(msg), s.jsonResponse(&resp))
return
Expand Down
101 changes: 101 additions & 0 deletions server/jetstream_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24183,6 +24183,107 @@ func TestJetStreamStreamCreatePedanticMode(t *testing.T) {
}
}

func TestJetStreamStrictMode(t *testing.T) {
cfgFmt := []byte(fmt.Sprintf(`
jetstream: {
strict: true
enabled: true
max_file_store: 100MB
store_dir: %s
limits: {duplicate_window: "1m", max_request_batch: 250}
}
`, t.TempDir()))
conf := createConfFile(t, cfgFmt)
s, _ := RunServerWithConfig(conf)
defer s.Shutdown()

nc, err := nats.Connect(s.ClientURL())
if err != nil {
t.Fatalf("Error connecting to NATS: %v", err)
}
defer nc.Close()

tests := []struct {
name string
subject string
payload []byte
expectedErr string
}{
{
name: "Stream Create",
subject: "$JS.API.STREAM.CREATE.TEST_STREAM",
payload: []byte(`{"name":"TEST_STREAM","subjects":["test.>"],"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Stream Update",
subject: "$JS.API.STREAM.UPDATE.TEST_STREAM",
payload: []byte(`{"name":"TEST_STREAM","subjects":["test.>"],"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Stream Delete",
subject: "$JS.API.STREAM.DELETE.TEST_STREAM",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "expected an empty request payload",
},
{
name: "Stream Info",
subject: "$JS.API.STREAM.INFO.TEST_STREAM",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Consumer Create",
subject: "$JS.API.CONSUMER.CREATE.TEST_STREAM.TEST_CONSUMER",
payload: []byte(`{"durable_name":"TEST_CONSUMER","ack_policy":"explicit","extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Consumer Delete",
subject: "$JS.API.CONSUMER.DELETE.TEST_STREAM.TEST_CONSUMER",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "expected an empty request payload",
},
{
name: "Consumer Info",
subject: "$JS.API.CONSUMER.INFO.TEST_STREAM.TEST_CONSUMER",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "expected an empty request payload",
},
{
name: "Stream List",
subject: "$JS.API.STREAM.LIST",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
{
name: "Consumer List",
subject: "$JS.API.CONSUMER.LIST.TEST_STREAM",
payload: []byte(`{"extra_field":"unexpected"}`),
expectedErr: "invalid JSON",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := nc.Request(tt.subject, tt.payload, time.Second*10)
if err != nil {
t.Fatalf("Request failed: %v", err)
}

var apiResp ApiResponse = ApiResponse{}

if err := json.Unmarshal(resp.Data, &apiResp); err != nil {
t.Fatalf("Error unmarshalling response: %v", err)
}

require_NotNil(t, apiResp.Error.Description)
require_Contains(t, apiResp.Error.Description, tt.expectedErr)
})
}
}

func addConsumerWithError(t *testing.T, nc *nats.Conn, cfg *CreateConsumerRequest) (*ConsumerInfo, *ApiError) {
t.Helper()
req, err := json.Marshal(cfg)
Expand Down
Loading