Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

casbin: add EnforceHandler to allow custom callback to handle enforcing. #66

Merged
merged 3 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions .github/workflows/echo-contrib.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@ on:

jobs:
test:
env:
latest: '1.17'
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
go: [1.14, 1.15, 1.16]
go: ['1.14', '1.15', '1.16', '1.17']
name: ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
Expand All @@ -39,26 +41,27 @@ jobs:
with:
ref: ${{ github.ref }}

- name: Install Dependencies
run: go get -u honnef.co/go/tools/cmd/staticcheck@latest
- name: Run static checks
if: matrix.go == env.latest && matrix.os == 'ubuntu-latest'
run: |
go get -u honnef.co/go/tools/cmd/staticcheck@latest
staticcheck -tests=false ./...

- name: Run Tests
run: |
staticcheck -tests=false ./...
go test -race --coverprofile=coverage.coverprofile --covermode=atomic ./...

- name: Upload coverage to Codecov
if: success() && matrix.go == 1.16 && matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v1
if: success() && matrix.go == env.latest && matrix.os == 'ubuntu-latest'
uses: codecov/codecov-action@v2
with:
token:
fail_ci_if_error: false
benchmark:
needs: test
strategy:
matrix:
os: [ubuntu-latest]
go: [1.16]
go: [1.17]
name: Benchmark comparison ${{ matrix.os }} @ Go ${{ matrix.go }}
runs-on: ${{ matrix.os }}
steps:
Expand Down
70 changes: 40 additions & 30 deletions casbin/casbin.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,11 @@ Advanced example:
package casbin

import (
"net/http"

"errors"
"github.com/casbin/casbin/v2"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"net/http"
)

type (
Expand All @@ -59,11 +59,18 @@ type (
Skipper middleware.Skipper

// Enforcer CasbinAuth main rule.
// Required.
// One of Enforcer or EnforceHandler fields is required.
Enforcer *casbin.Enforcer

// EnforceHandler is custom callback to handle enforcing.
// One of Enforcer or EnforceHandler fields is required.
EnforceHandler func(c echo.Context, user string) (bool, error)

// Method to get the username - defaults to using basic auth
UserGetter func(c echo.Context) (string, error)

// Method to handle errors
ErrorHandler func(c echo.Context, internal error, proposedStatus int) error
}
)

Expand All @@ -75,6 +82,11 @@ var (
username, _, _ := c.Request().BasicAuth()
return username, nil
},
ErrorHandler: func(c echo.Context, internal error, proposedStatus int) error {
err := echo.NewHTTPError(proposedStatus, internal.Error())
err.Internal = internal
return err
},
}
)

Expand All @@ -91,44 +103,42 @@ func Middleware(ce *casbin.Enforcer) echo.MiddlewareFunc {
// MiddlewareWithConfig returns a CasbinAuth middleware with config.
// See `Middleware()`.
func MiddlewareWithConfig(config Config) echo.MiddlewareFunc {
// Defaults
if config.Enforcer == nil && config.EnforceHandler == nil {
panic("one of casbin middleware Enforcer or EnforceHandler fields must be set")
}
if config.Skipper == nil {
config.Skipper = DefaultConfig.Skipper
}
if config.UserGetter == nil {
config.UserGetter = DefaultConfig.UserGetter
}
if config.ErrorHandler == nil {
config.ErrorHandler = DefaultConfig.ErrorHandler
}
if config.EnforceHandler == nil {
config.EnforceHandler = func(c echo.Context, user string) (bool, error) {
return config.Enforcer.Enforce(user, c.Request().URL.Path, c.Request().Method)
}
}

return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
if config.Skipper(c) {
return next(c)
}

if pass, err := config.CheckPermission(c); err == nil && pass {
return next(c)
} else if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, err.Error())
user, err := config.UserGetter(c)
if err != nil {
return config.ErrorHandler(c, err, http.StatusForbidden)
}

return echo.ErrForbidden
pass, err := config.EnforceHandler(c, user)
if err != nil {
return config.ErrorHandler(c, err, http.StatusInternalServerError)
}
if !pass {
return config.ErrorHandler(c, errors.New("enforce did not pass"), http.StatusForbidden)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make the status code configurable? For example, one might use 402 instead of 403.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this can be done inside/with errorhandler.

}
return next(c)
}
}
}

// GetUserName gets the user name from the request.
// It calls the UserGetter field of the Config struct that allows the caller to customize user identification.
func (a *Config) GetUserName(c echo.Context) (string, error) {
username, err := a.UserGetter(c)
return username, err
}

// CheckPermission checks the user/method/path combination from the request.
// Returns true (permission granted) or false (permission forbidden)
func (a *Config) CheckPermission(c echo.Context) (bool, error) {
user, err := a.GetUserName(c)
if err != nil {
// Fail safe and do not propagate
return false, nil
}
method := c.Request().Method
path := c.Request().URL.Path
return a.Enforcer.Enforce(user, path, method)
}
26 changes: 26 additions & 0 deletions casbin/casbin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@ package casbin

import (
"errors"
"github.com/stretchr/testify/assert"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/casbin/casbin/v2"
Expand Down Expand Up @@ -131,3 +133,27 @@ func TestUserGetterError(t *testing.T) {
})
testRequest(t, h, "cathy", "/dataset1/item", "GET", 403)
}

func TestCustomEnforceHandler(t *testing.T) {
ce, err := casbin.NewEnforcer("auth_model.conf", "auth_policy.csv")
assert.NoError(t, err)

_, err = ce.AddPolicy("bob", "/user/bob", "PATCH_SELF")
assert.NoError(t, err)

cnf := Config{
EnforceHandler: func(c echo.Context, user string) (bool, error) {
method := c.Request().Method
if strings.HasPrefix(c.Request().URL.Path, "/user/bob") {
method += "_SELF"
}
return ce.Enforce(user, c.Request().URL.Path, method)
},
}
h := MiddlewareWithConfig(cnf)(func(c echo.Context) error {
return c.String(http.StatusOK, "test")
})
testRequest(t, h, "bob", "/dataset2/resource1", "GET", http.StatusOK)
testRequest(t, h, "bob", "/user/alice", "PATCH", http.StatusForbidden)
testRequest(t, h, "bob", "/user/bob", "PATCH", http.StatusOK)
}