diff --git a/node.go b/node.go index d1713dd..fd99ebc 100644 --- a/node.go +++ b/node.go @@ -143,9 +143,7 @@ func (n *node) _findRoute(meth, path string) (*node, *routeHandler, int) { if handler != nil { return node, handler, wildcardLen } - if node != nil { - found = node - } + found = node } } } @@ -154,9 +152,12 @@ func (n *node) _findRoute(meth, path string) (*node, *routeHandler, int) { if n.colon != nil { if i := strings.IndexByte(path, '/'); i > 0 { node, handler, wildcardLen := n.colon._findRoute(meth, path[i:]) - if handler != nil { + if node != nil && handler != nil { return node, handler, wildcardLen } + if found == nil { + found = node + } } else if n.colon.handlerMap != nil { if handler := n.colon.handlerMap.Get(meth); handler != nil { return n.colon, handler, 0 diff --git a/router_test.go b/router_test.go index 362bb9e..dc1f8d6 100644 --- a/router_test.go +++ b/router_test.go @@ -108,12 +108,15 @@ func TestNotFound(t *testing.T) { require.Equal(t, http.StatusNotFound, w.Code) require.Equal(t, 0, calledNotFound) - // Now try with a custome handler. + // Now try with a custom handler. router = New(WithNotFoundHandler(notFoundHandler)) router.GET("/user/abc", simpleHandler) + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/abc/", nil) + router.ServeHTTP(w, req) - require.Equal(t, http.StatusNotFound, w.Code) + require.Equal(t, http.StatusOK, w.Code) require.Equal(t, 1, calledNotFound) } @@ -127,6 +130,7 @@ func TestMethodNotAllowed(t *testing.T) { router := New() router.POST("/abc", simpleHandler) + router.GET("/abc/:id/def/:sub_id", simpleHandler) w := httptest.NewRecorder() req, _ := http.NewRequest("GET", "/abc", nil) @@ -135,12 +139,22 @@ func TestMethodNotAllowed(t *testing.T) { require.Equal(t, http.StatusMethodNotAllowed, w.Code) require.Equal(t, 0, calledMethodNotAllowed) - // Now try with a custome handler. + w = httptest.NewRecorder() + req, _ = http.NewRequest("POST", "/abc/1/def/2", nil) + + router.ServeHTTP(w, req) + require.Equal(t, http.StatusMethodNotAllowed, w.Code) + require.Equal(t, 0, calledMethodNotAllowed) + + // Now try with a custom handler. router = New(WithMethodNotAllowedHandler(methodNotAllowedHandler)) router.POST("/abc", simpleHandler) + w = httptest.NewRecorder() + req, _ = http.NewRequest("GET", "/abc", nil) + router.ServeHTTP(w, req) - require.Equal(t, http.StatusMethodNotAllowed, w.Code) + require.Equal(t, http.StatusOK, w.Code) require.Equal(t, 1, calledMethodNotAllowed) }