diff --git a/api/etcd/list.go b/api/etcd/list.go index 186cdeb..c5bc82d 100644 --- a/api/etcd/list.go +++ b/api/etcd/list.go @@ -3,7 +3,6 @@ package etcd import ( "context" "encoding/base64" - "strconv" "github.com/coreos/etcd/clientv3" "github.com/heetch/regula/api" @@ -45,7 +44,7 @@ func (s *RulesetService) List(ctx context.Context, opt api.ListOptions) (*api.Ru } rulesets := api.Rulesets{ - Revision: strconv.FormatInt(resp.Header.Revision, 10), + Revision: resp.Header.Revision, } rulesets.Paths = make([]string, 0, len(resp.Kvs)) diff --git a/api/etcd/rulesets_test.go b/api/etcd/rulesets_test.go index e9cdd0a..03f8465 100644 --- a/api/etcd/rulesets_test.go +++ b/api/etcd/rulesets_test.go @@ -11,6 +11,7 @@ import ( "github.com/heetch/regula" "github.com/heetch/regula/api" "github.com/heetch/regula/rule" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -24,13 +25,11 @@ var ( endpoints = []string{"localhost:2379", "etcd:2379"} ) -func Init() { - rand.Seed(time.Now().UnixNano()) -} - func newEtcdRulesetService(t *testing.T) (*RulesetService, func()) { t.Helper() + rand.Seed(time.Now().UnixNano()) + cli, err := clientv3.New(clientv3.Config{ Endpoints: endpoints, DialTimeout: dialTimeout, @@ -49,18 +48,22 @@ func newEtcdRulesetService(t *testing.T) (*RulesetService, func()) { } func createRuleset(t *testing.T, s *RulesetService, path string, rules ...*rule.Rule) *regula.Ruleset { + t.Helper() + _, err := s.Put(context.Background(), path, rules) if err != nil && err != api.ErrRulesetNotModified { require.NoError(t, err) } rs, err := s.Get(context.Background(), path, "") - require.NoError(t, err) + assert.NoError(t, err) return rs } func createBoolRuleset(t *testing.T, s *RulesetService, path string, rules ...*rule.Rule) *regula.Ruleset { + t.Helper() + err := s.Create(context.Background(), path, ®ula.Signature{ReturnType: "bool"}) - require.False(t, err != nil && err != api.ErrAlreadyExists) + assert.False(t, err != nil && err != api.ErrAlreadyExists) return createRuleset(t, s, path, rules...) } diff --git a/api/etcd/watch.go b/api/etcd/watch.go index fc40783..86687ad 100644 --- a/api/etcd/watch.go +++ b/api/etcd/watch.go @@ -2,7 +2,7 @@ package etcd import ( "context" - "strconv" + "strings" "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/mvcc/mvccpb" @@ -12,39 +12,48 @@ import ( "github.com/pkg/errors" ) -// Watch the given prefix for anything new. -func (s *RulesetService) Watch(ctx context.Context, prefix string, revision string) (*api.RulesetEvents, error) { +// Watch a list of paths for changes and return a list of events. If paths is empty or nil, +// watch all paths. If the revision is negative, watch from the latest revision. +// This method blocks until there is a change in one of the paths or until the context is canceled. +// The given context can be used to limit the watch period or to cancel any running one. +func (s *RulesetService) Watch(ctx context.Context, opt api.WatchOptions) (*api.RulesetEvents, error) { ctx, cancel := context.WithCancel(ctx) defer cancel() + revision := opt.Revision + opts := []clientv3.OpOption{clientv3.WithPrefix()} - if i, _ := strconv.ParseInt(revision, 10, 64); i > 0 { + if revision > 0 { // watch from the next revision - opts = append(opts, clientv3.WithRev(i+1)) + opts = append(opts, clientv3.WithRev(revision+1)) } - events := api.RulesetEvents{ - Revision: revision, - } + var events api.RulesetEvents - wc := s.Client.Watch(ctx, s.rulesPath(prefix, ""), opts...) + wc := s.Client.Watch(ctx, s.rulesPath("", ""), opts...) for { select { case wresp := <-wc: if err := wresp.Err(); err != nil { - return nil, errors.Wrapf(err, "failed to watch prefix: '%s'", prefix) + return nil, errors.Wrapf(err, "failed to watch paths: '%#v'", opt.Paths) } + revision = wresp.Header.Revision + if len(wresp.Events) == 0 { continue } - list := make([]api.RulesetEvent, len(wresp.Events)) - for i, ev := range wresp.Events { - switch ev.Type { - case mvccpb.PUT: - list[i].Type = api.RulesetPutEvent - default: + var list []api.RulesetEvent + for _, ev := range wresp.Events { + // filter keys that haven't been selected + if !s.shouldIncludeEvent(ev, opt.Paths) { + s.Logger.Debug().Str("type", string(ev.Type)).Str("key", string(ev.Kv.Key)).Msg("watch: ignoring event key") + continue + } + + // filter event types, keep only PUT events + if ev.Type != mvccpb.PUT { s.Logger.Debug().Str("type", string(ev.Type)).Msg("watch: ignoring event type") continue } @@ -52,21 +61,54 @@ func (s *RulesetService) Watch(ctx context.Context, prefix string, revision stri var pbrs pb.Rules err := proto.Unmarshal(ev.Kv.Value, &pbrs) if err != nil { - s.Logger.Debug().Bytes("entry", ev.Kv.Value).Msg("watch: unmarshalling failed") - return nil, errors.Wrap(err, "failed to unmarshal entry") + s.Logger.Error().Bytes("entry", ev.Kv.Value).Msg("watch: unmarshalling failed, ignoring the event") + continue } + path, version := s.pathVersionFromKey(string(ev.Kv.Key)) - list[i].Path = path - list[i].Rules = rulesFromProtobuf(&pbrs) - list[i].Version = version + + list = append(list, api.RulesetEvent{ + Type: api.RulesetPutEvent, + Path: path, + Rules: rulesFromProtobuf(&pbrs), + Version: version, + }) + } + + // None of the events matched the user selection, so continue + // waiting for more. + if len(list) == 0 { + continue } events.Events = list - events.Revision = strconv.FormatInt(wresp.Header.Revision, 10) + events.Revision = revision return &events, nil case <-ctx.Done(): events.Timeout = true + // if we received events but ignored them + // this function will go on until the context is canceled. + // we need to return the latest received revision so the + // caller can start after the filtered events. + events.Revision = revision return &events, ctx.Err() } } } + +// shouldIncludeEvent reports whether the given event should be included +// in the Watch data for the given paths. +func (s *RulesetService) shouldIncludeEvent(ev *clientv3.Event, paths []string) bool { + // detect if the event key is found in the paths list + // or that the paths list is empty + key := string(ev.Kv.Key) + key = key[:strings.Index(key, versionSeparator)] + ok := len(paths) == 0 + for i := 0; i < len(paths) && !ok; i++ { + if key == s.rulesPath(paths[i], "") { + ok = true + } + } + + return ok +} diff --git a/api/etcd/watch_test.go b/api/etcd/watch_test.go index 565f05a..75cf62c 100644 --- a/api/etcd/watch_test.go +++ b/api/etcd/watch_test.go @@ -8,55 +8,93 @@ import ( "github.com/heetch/regula/api" "github.com/heetch/regula/rule" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestWatch(t *testing.T) { t.Parallel() - s, cleanup := newEtcdRulesetService(t) - defer cleanup() + tests := []struct { + name string + paths []string + expected []string + }{ + {"no paths", nil, []string{"a", "b", "c"}}, + {"existing paths", []string{"a", "c"}, []string{"a", "c"}}, + } - var wg sync.WaitGroup + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s, cleanup := newEtcdRulesetService(t) + defer cleanup() - wg.Add(1) - go func() { - defer wg.Done() + var wg sync.WaitGroup - time.Sleep(time.Second) + wg.Add(1) + go func() { + defer wg.Done() - r := rule.New(rule.True(), rule.BoolValue(true)) + // wait enought time so that the other goroutine had the time to run the watch method + // before writing data to the database. + time.Sleep(time.Second) - createBoolRuleset(t, s, "aa", r) - createBoolRuleset(t, s, "ab", r) - createBoolRuleset(t, s, "a/1", r) - }() + r := rule.New(rule.True(), rule.BoolValue(true)) - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + createBoolRuleset(t, s, "a", r) + createBoolRuleset(t, s, "b", r) + createBoolRuleset(t, s, "c", r) + }() - events, err := s.Watch(ctx, "a", "") - require.NoError(t, err) - require.Len(t, events.Events, 1) - require.NotEmpty(t, events.Revision) - require.Equal(t, "aa", events.Events[0].Path) - require.Equal(t, api.RulesetPutEvent, events.Events[0].Type) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() - wg.Wait() + var events api.RulesetEvents + var rev int64 + var watchCount int + for len(events.Events) != len(test.expected) && watchCount < 4 { + evs, err := s.Watch(ctx, api.WatchOptions{Paths: test.paths, Revision: rev}) + if err != nil { + if err != nil { + if err == context.DeadlineExceeded { + t.Errorf("timed out waiting for expected events") + } else { + t.Errorf("unexpected error from watcher: %v", err) + } + break + } + break + } + assert.True(t, len(evs.Events) > 0) + assert.NotEmpty(t, evs.Revision) + rev = evs.Revision + events.Events = append(events.Events, evs.Events...) + watchCount++ + } - events, err = s.Watch(ctx, "a", events.Revision) - require.NoError(t, err) - require.Len(t, events.Events, 2) - require.NotEmpty(t, events.Revision) - require.Equal(t, api.RulesetPutEvent, events.Events[0].Type) - require.Equal(t, "ab", events.Events[0].Path) - require.Equal(t, api.RulesetPutEvent, events.Events[1].Type) - require.Equal(t, "a/1", events.Events[1].Path) + wg.Wait() + + var foundCount int + for _, ev := range events.Events { + for _, p := range test.expected { + if ev.Path == p { + foundCount++ + break + } + } + } + require.Equal(t, len(test.expected), foundCount) + }) + + } t.Run("timeout", func(t *testing.T) { + s, cleanup := newEtcdRulesetService(t) + defer cleanup() + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() - events, err := s.Watch(ctx, "", "") + events, err := s.Watch(ctx, api.WatchOptions{}) require.Equal(t, context.DeadlineExceeded, err) require.True(t, events.Timeout) }) diff --git a/api/service.go b/api/service.go index fa2047b..dbcfb98 100644 --- a/api/service.go +++ b/api/service.go @@ -29,8 +29,9 @@ type RulesetService interface { // List returns the list of all rulesets paths. // The listing is paginated and can be customised using the ListOptions type. List(ctx context.Context, opt ListOptions) (*Rulesets, error) - // Watch a prefix for changes and return a list of events. - Watch(ctx context.Context, prefix string, revision string) (*RulesetEvents, error) + // Watch a list of paths for changes and return a list of events. + // The watcher can be customized using the WatchOption type. + Watch(ctx context.Context, opt WatchOptions) (*RulesetEvents, error) // Eval evaluates a ruleset given a path and a set of parameters. It implements the regula.Evaluator interface. Eval(ctx context.Context, path, version string, params rule.Params) (*regula.EvalResult, error) } @@ -55,10 +56,21 @@ func (l *ListOptions) GetLimit() int { return l.Limit } +// WatchOptions gives indications on what rulesets to watch. +type WatchOptions struct { + // List of paths to watch for changes. + // If the slice is empty, watch all paths. + Paths []string + // Indicates from which revision start watching. + // Any event happened after that revision is returned. + // If the revision is zero or negative, watch from the latest revision. + Revision int64 +} + // Rulesets holds a list of rulesets. type Rulesets struct { Paths []string `json:"paths"` - Revision string `json:"revision"` // revision when the request was applied + Revision int64 `json:"revision"` // revision when the request was applied Cursor string `json:"cursor,omitempty"` // cursor of the next page, if any } @@ -77,7 +89,7 @@ type RulesetEvent struct { // RulesetEvents holds a list of events occured on a group of rulesets. type RulesetEvents struct { - Events []RulesetEvent - Revision string - Timeout bool // indicates if the watch did timeout + Events []RulesetEvent `json:"events"` + Revision int64 `json:"revision"` + Timeout bool `json:"timeout"` // indicates if the watch did timeout } diff --git a/http/server/api.go b/http/server/api.go index 9e07f6e..3c69f27 100644 --- a/http/server/api.go +++ b/http/server/api.go @@ -6,7 +6,6 @@ import ( "fmt" "net/http" "strconv" - "strings" "time" "github.com/heetch/regula" @@ -18,38 +17,47 @@ import ( ) type rulesetAPI struct { - rulesets api.RulesetService - timeout time.Duration + rulesets api.RulesetService + // Timeout specific for the watch requests. watchTimeout time.Duration + // Context used by the Watch endpoint. Used to cancel long running + // watch requests when the server is shutting down. + watchCancelCtx context.Context } func (s *rulesetAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { - path := strings.TrimPrefix(r.URL.Path, "/rulesets/") - - if _, ok := r.URL.Query()["watch"]; ok && r.Method == "GET" { - ctx, cancel := context.WithTimeout(r.Context(), s.watchTimeout) - defer cancel() - s.watch(w, r.WithContext(ctx), path) - return - } - - ctx, cancel := context.WithTimeout(r.Context(), s.timeout) - defer cancel() - r = r.WithContext(ctx) + path := r.URL.Path + r.ParseForm() switch r.Method { case "GET": - if _, ok := r.URL.Query()["list"]; ok { + if len(r.Form["list"]) > 0 { + if len(path) != 0 { + w.WriteHeader(http.StatusNotFound) + return + } s.list(w, r) return } - if _, ok := r.URL.Query()["eval"]; ok { + if len(r.Form["eval"]) > 0 { s.eval(w, r, path) return } s.get(w, r, path) return case "POST": + if len(r.Form["watch"]) > 0 { + if len(path) != 0 { + w.WriteHeader(http.StatusNotFound) + return + } + + ctx, cancel := context.WithTimeout(s.watchCancelCtx, s.watchTimeout) + defer cancel() + s.watch(w, r.WithContext(ctx)) + return + } + s.create(w, r, path) return case "PUT": @@ -143,7 +151,7 @@ func (s *rulesetAPI) list(w http.ResponseWriter, r *http.Request) { return } - reghttp.EncodeJSON(w, r, (*api.Rulesets)(rulesets), http.StatusOK) + reghttp.EncodeJSON(w, r, rulesets, http.StatusOK) } func (s *rulesetAPI) eval(w http.ResponseWriter, r *http.Request, path string) { @@ -176,16 +184,33 @@ func (s *rulesetAPI) eval(w http.ResponseWriter, r *http.Request, path string) { reghttp.EncodeJSON(w, r, res, http.StatusOK) } -// watch watches a prefix for change and returns anything newer. -func (s *rulesetAPI) watch(w http.ResponseWriter, r *http.Request, prefix string) { - events, err := s.rulesets.Watch(r.Context(), prefix, r.URL.Query().Get("revision")) +// watch is a long polling endpoint that watches a list of paths for change and returns a list of events containing all the changes +// that happened since the start of the watch. +// If the revision query param is specified, it returns anything that happened after that revision. +// If no paths are specificied, it watches any path. +// The request context can be used to limit the watch period or to cancel any running one. +func (s *rulesetAPI) watch(w http.ResponseWriter, r *http.Request) { + var wo api.WatchOptions + + if r.ContentLength > 0 { + // There's a non-empty body, which means that the + // client has specified a set of paths to watch. + err := json.NewDecoder(r.Body).Decode(&wo) + if err != nil { + writeError(w, r, err, http.StatusBadRequest) + return + } + } + + events, err := s.rulesets.Watch(r.Context(), wo) if err != nil { switch err { - case context.Canceled, context.DeadlineExceeded: - // we do nothing - case api.ErrRulesetNotFound: - w.WriteHeader(http.StatusNotFound) - return + case context.Canceled: + // server is probably shutting down + // we do nothing and return a 200 to the client + case context.DeadlineExceeded: + // the watch request reached the deadline + // we do nothing and return a 200 to the client default: writeError(w, r, err, http.StatusInternalServerError) return diff --git a/http/server/api_test.go b/http/server/api_test.go index edc13fe..979e6c5 100644 --- a/http/server/api_test.go +++ b/http/server/api_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "net/url" "strconv" + "strings" "testing" "time" @@ -106,7 +107,7 @@ func TestServerList(t *testing.T) { rss := api.Rulesets{ Paths: []string{"aa", "bb"}, - Revision: "somerev", + Revision: 10, Cursor: "somecursor", } @@ -119,11 +120,12 @@ func TestServerList(t *testing.T) { err error }{ {"OK", "/rulesets/?list", http.StatusOK, &rss, api.ListOptions{}, nil}, - {"WithLimitAndCursor", "/rulesets/a?list&limit=10&cursor=abc123", http.StatusOK, &rss, api.ListOptions{Limit: 10, Cursor: "abc123"}, nil}, + {"WithLimitAndCursor", "/rulesets/?list&limit=10&cursor=abc123", http.StatusOK, &rss, api.ListOptions{Limit: 10, Cursor: "abc123"}, nil}, {"NoResult", "/rulesets/?list", http.StatusOK, new(api.Rulesets), api.ListOptions{}, nil}, - {"InvalidCursor", "/rulesets/someprefix?list&cursor=abc123", http.StatusBadRequest, new(api.Rulesets), api.ListOptions{Cursor: "abc123"}, api.ErrInvalidCursor}, - {"UnexpectedError", "/rulesets/someprefix?list", http.StatusInternalServerError, new(api.Rulesets), api.ListOptions{}, errors.New("unexpected error")}, - {"InvalidLimit", "/rulesets/someprefix?list&limit=badlimit", http.StatusBadRequest, nil, api.ListOptions{}, nil}, + {"InvalidCursor", "/rulesets/?list&cursor=abc123", http.StatusBadRequest, new(api.Rulesets), api.ListOptions{Cursor: "abc123"}, api.ErrInvalidCursor}, + {"UnexpectedError", "/rulesets/?list", http.StatusInternalServerError, new(api.Rulesets), api.ListOptions{}, errors.New("unexpected error")}, + {"InvalidLimit", "/rulesets/?list&limit=badlimit", http.StatusBadRequest, nil, api.ListOptions{}, nil}, + {"WithPath", "/rulesets/some/path?list&limit=badlimit", http.StatusNotFound, nil, api.ListOptions{}, nil}, } for _, test := range tests { @@ -249,70 +251,60 @@ func TestServerWatch(t *testing.T) { {Type: api.RulesetPutEvent, Path: "b", Rules: r2}, {Type: api.RulesetPutEvent, Path: "a", Rules: r2}, }, - Revision: "rev", + Revision: 10, } tests := []struct { name string path string + body string status int - es *api.RulesetEvents - err error + fn func(context.Context, api.WatchOptions) (*api.RulesetEvents, error) }{ - {"Root", "/rulesets/?watch", http.StatusOK, &l, nil}, - {"WithPrefix", "/rulesets/a?watch", http.StatusOK, &l, nil}, - {"NotFound", "/rulesets/a?watch", http.StatusNotFound, &l, api.ErrRulesetNotFound}, - {"Timeout", "/rulesets/?watch", http.StatusOK, nil, context.DeadlineExceeded}, - {"ContextCanceled", "/rulesets/?watch", http.StatusOK, nil, context.Canceled}, + {"No paths", "/rulesets/?watch", "", http.StatusOK, func(context.Context, api.WatchOptions) (*api.RulesetEvents, error) { return &l, nil }}, + {"With paths", "/rulesets/?watch", `{"paths": ["a", "b", "c"]}`, http.StatusOK, func(ctx context.Context, opt api.WatchOptions) (*api.RulesetEvents, error) { + require.EqualValues(t, 0, opt.Revision) + require.Equal(t, []string{"a", "b", "c"}, opt.Paths) + return &l, nil + }}, + {"Bad JSON", "/rulesets/?watch", `["a`, http.StatusBadRequest, func(ctx context.Context, opt api.WatchOptions) (*api.RulesetEvents, error) { return nil, nil }}, + {"Timeout", "/rulesets/?watch", "", http.StatusOK, func(context.Context, api.WatchOptions) (*api.RulesetEvents, error) { + return nil, context.DeadlineExceeded + }}, + {"ContextCanceled", "/rulesets/?watch", "", http.StatusOK, func(context.Context, api.WatchOptions) (*api.RulesetEvents, error) { return nil, context.Canceled }}, + {"WithRevision", "/rulesets/?watch", `{"revision": 100}`, http.StatusOK, func(ctx context.Context, opt api.WatchOptions) (*api.RulesetEvents, error) { + require.EqualValues(t, 100, opt.Revision) + return &l, nil + }}, + {"WithBadRevision", "/rulesets/?watch", `{"revision": "somerev"}`, http.StatusBadRequest, func(ctx context.Context, opt api.WatchOptions) (*api.RulesetEvents, error) { return &l, nil }}, } for _, test := range tests { - s.WatchFn = func(context.Context, string, string) (*api.RulesetEvents, error) { - return test.es, test.err - } - defer func() { s.WatchFn = nil }() + t.Run(test.name, func(t *testing.T) { + s.WatchFn = test.fn + defer func() { s.WatchFn = nil }() - w := httptest.NewRecorder() - r := httptest.NewRequest("GET", test.path, nil) - h.ServeHTTP(w, r) + w := httptest.NewRecorder() + r := httptest.NewRequest("POST", test.path, strings.NewReader(test.body)) + h.ServeHTTP(w, r) - require.Equal(t, test.status, w.Code) + require.Equal(t, test.status, w.Code) - if test.status == http.StatusOK { + if test.status != http.StatusOK { + return + } var res api.RulesetEvents + require.True(t, w.Body.Len() > 0) err := json.NewDecoder(w.Body).Decode(&res) require.NoError(t, err) - if test.es != nil { - require.Equal(t, len(test.es.Events), len(res.Events)) + if len(res.Events) > 0 { + require.Equal(t, len(l.Events), len(res.Events)) for i := range l.Events { require.Equal(t, l.Events[i], res.Events[i]) } } - } + }) } - - t.Run("WithRevision", func(t *testing.T) { - s.WatchFn = func(ctx context.Context, prefix string, revision string) (*api.RulesetEvents, error) { - require.Equal(t, "a", prefix) - require.Equal(t, "somerev", revision) - return &l, nil - } - defer func() { s.WatchFn = nil }() - - w := httptest.NewRecorder() - r := httptest.NewRequest("GET", "/rulesets/a?watch&revision=somerev", nil) - h.ServeHTTP(w, r) - - require.Equal(t, http.StatusOK, w.Code) - - var res api.RulesetEvents - err := json.NewDecoder(w.Body).Decode(&res) - require.NoError(t, err) - require.Equal(t, len(l.Events), len(res.Events)) - for i := range l.Events { - require.Equal(t, l.Events[i], res.Events[i]) - } - }) } func TestServerPut(t *testing.T) { diff --git a/http/server/handler.go b/http/server/handler.go index 2cd42a2..127b23b 100644 --- a/http/server/handler.go +++ b/http/server/handler.go @@ -37,21 +37,18 @@ func NewHandler(rsService api.RulesetService, cfg Config) http.Handler { } rulesetsHandler := rulesetAPI{ - rulesets: rsService, - timeout: cfg.Timeout, - watchTimeout: cfg.WatchTimeout, + rulesets: rsService, + watchTimeout: cfg.WatchTimeout, + watchCancelCtx: cfg.WatchCancelCtx, } // router mux := http.NewServeMux() - mux.HandleFunc("/rulesets/", func(w http.ResponseWriter, r *http.Request) { - if _, ok := r.URL.Query()["watch"]; ok && r.Method == "GET" { - rulesetsHandler.ServeHTTP(w, r.WithContext(cfg.WatchCancelCtx)) - return - } - - rulesetsHandler.ServeHTTP(w, r) - }) + mux.Handle("/rulesets/", http.StripPrefix("/rulesets/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), cfg.Timeout) + defer cancel() + rulesetsHandler.ServeHTTP(w, r.WithContext(ctx)) + }))) return mux } diff --git a/mock/store.go b/mock/store.go index 476cfe0..dfccc35 100644 --- a/mock/store.go +++ b/mock/store.go @@ -20,7 +20,7 @@ type RulesetService struct { ListCount int ListFn func(context.Context, api.ListOptions) (*api.Rulesets, error) WatchCount int - WatchFn func(context.Context, string, string) (*api.RulesetEvents, error) + WatchFn func(context.Context, api.WatchOptions) (*api.RulesetEvents, error) PutCount int PutFn func(context.Context, string, []*rule.Rule) (string, error) EvalCount int @@ -61,11 +61,11 @@ func (s *RulesetService) List(ctx context.Context, opt api.ListOptions) (*api.Ru } // Watch runs WatchFn if provided and increments WatchCount when invoked. -func (s *RulesetService) Watch(ctx context.Context, prefix, revision string) (*api.RulesetEvents, error) { +func (s *RulesetService) Watch(ctx context.Context, opt api.WatchOptions) (*api.RulesetEvents, error) { s.WatchCount++ if s.WatchFn != nil { - return s.WatchFn(ctx, prefix, revision) + return s.WatchFn(ctx, opt) } return nil, nil