Skip to content

Commit

Permalink
Merge pull request #128 from k1LoW/prepend
Browse files Browse the repository at this point in the history
Add Prepend() for prepending mathcher
  • Loading branch information
k1LoW authored Nov 24, 2024
2 parents 3dae733 + 2744325 commit 4e33db7
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 3 deletions.
24 changes: 21 additions & 3 deletions grpcstub.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ type Server struct {
healthCheck bool
disableReflection bool
status serverStatus
prependOnce bool
t TB
mu sync.RWMutex
}
Expand Down Expand Up @@ -326,7 +327,7 @@ func (s *Server) Match(fn func(req *Request) bool) *matcher {
}
s.mu.Lock()
defer s.mu.Unlock()
s.matchers = append(s.matchers, m)
s.addMatcher(m)
return m
}

Expand All @@ -347,7 +348,7 @@ func (s *Server) Service(service string) *matcher {
matchFuncs: []matchFunc{fn},
t: s.t,
}
s.matchers = append(s.matchers, m)
s.addMatcher(m)
return m
}

Expand Down Expand Up @@ -379,7 +380,7 @@ func (s *Server) Method(method string) *matcher {
matchFuncs: []matchFunc{fn},
t: s.t,
}
s.matchers = append(s.matchers, m)
s.addMatcher(m)
return m
}

Expand Down Expand Up @@ -517,6 +518,14 @@ func (s *Server) ClearMatchers() {
s.matchers = nil
}

// Prepend prepend matcher.
func (s *Server) Prepend() *Server {
s.mu.Lock()
defer s.mu.Unlock()
s.prependOnce = true
return s
}

// ClearRequests clear requests.
func (s *Server) ClearRequests() {
s.requests = nil
Expand All @@ -530,6 +539,15 @@ func (m *matcher) Requests() []*Request {
return m.requests
}

func (s *Server) addMatcher(m *matcher) {
if s.prependOnce {
s.matchers = append([]*matcher{m}, s.matchers...)
s.prependOnce = false
return
}
s.matchers = append(s.matchers, m)
}

func (s *Server) registerServer() {
for _, fd := range s.fds {
for i := 0; i < fd.Services().Len(); i++ {
Expand Down
50 changes: 50 additions & 0 deletions grpcstub_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -804,3 +804,53 @@ func TestUnmarshalProtoMessage(t *testing.T) {
t.Errorf("got %v\nwant %v", got, want)
}
}

func TestPrepend(t *testing.T) {
t.Run("Default", func(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Service("routeguide.RouteGuide").Response(map[string]any{"name": "hello"})
ts.Service("routeguide.RouteGuide").Response(map[string]any{"name": "world"})
ts.Service("routeguide.RouteGuide").Response(map[string]any{"name": "!!!"})

client := routeguide.NewRouteGuideClient(ts.Conn())
res, err := client.GetFeature(ctx, &routeguide.Point{
Latitude: 10,
Longitude: 13,
})
if err != nil {
t.Fatal(err)
}
got := res.Name
if want := "hello"; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
})

t.Run("Prepend", func(t *testing.T) {
ctx := context.Background()
ts := NewServer(t, "testdata/route_guide.proto")
t.Cleanup(func() {
ts.Close()
})
ts.Service("routeguide.RouteGuide").Response(map[string]any{"name": "hello"})
ts.Prepend().Service("routeguide.RouteGuide").Response(map[string]any{"name": "world"})
ts.Service("routeguide.RouteGuide").Response(map[string]any{"name": "!!!"})

client := routeguide.NewRouteGuideClient(ts.Conn())
res, err := client.GetFeature(ctx, &routeguide.Point{
Latitude: 10,
Longitude: 13,
})
if err != nil {
t.Fatal(err)
}
got := res.Name
if want := "world"; got != want {
t.Errorf("got %v\nwant %v", got, want)
}
})
}

0 comments on commit 4e33db7

Please sign in to comment.