Skip to content

Commit

Permalink
fix oidc test, add tests for migration
Browse files Browse the repository at this point in the history
Signed-off-by: Kristoffer Dalby <[email protected]>
  • Loading branch information
kradalby committed Oct 21, 2024
1 parent b5e1891 commit b65d689
Show file tree
Hide file tree
Showing 7 changed files with 475 additions and 49 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-integration.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ jobs:
- TestPolicyUpdateWhileRunningWithCLIInDatabase
- TestOIDCAuthenticationPingAll
- TestOIDCExpireNodesBasedOnTokenExpiry
- TestOIDC024UserCreation
- TestAuthWebFlowAuthenticationPingAll
- TestAuthWebFlowLogoutAndRelogin
- TestUserCommand
Expand Down
37 changes: 34 additions & 3 deletions cmd/headscale/cli/mockoidc.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
package cli

import (
"encoding/json"
"fmt"
"net"
"net/http"
"os"
"strconv"
"time"
Expand Down Expand Up @@ -64,14 +66,27 @@ func mockOIDC() error {
accessTTL = newTTL
}

userStr := os.Getenv("MOCKOIDC_USERS")
if userStr == "" {
return fmt.Errorf("MOCKOIDC_USERS not defined")
}

var users []mockoidc.MockUser
err := json.Unmarshal([]byte(userStr), &users)
if err != nil {
return fmt.Errorf("unmarshalling users: %w", err)
}

log.Info().Interface("users", users).Msg("loading users from JSON")

log.Info().Msgf("Access token TTL: %s", accessTTL)

port, err := strconv.Atoi(portStr)
if err != nil {
return err
}

mock, err := getMockOIDC(clientID, clientSecret)
mock, err := getMockOIDC(clientID, clientSecret, users)
if err != nil {
return err
}
Expand All @@ -93,12 +108,18 @@ func mockOIDC() error {
return nil
}

func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) {
func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) {
keypair, err := mockoidc.NewKeypair(nil)
if err != nil {
return nil, err
}

userQueue := mockoidc.UserQueue{}

for _, user := range users {
userQueue.Push(&user)
}

mock := mockoidc.MockOIDC{
ClientID: clientID,
ClientSecret: clientSecret,
Expand All @@ -107,9 +128,19 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro
CodeChallengeMethodsSupported: []string{"plain", "S256"},
Keypair: keypair,
SessionStore: mockoidc.NewSessionStore(),
UserQueue: &mockoidc.UserQueue{},
UserQueue: &userQueue,
ErrorQueue: &mockoidc.ErrorQueue{},
}

mock.AddMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Info().Msgf("Request: %+v", r)
h.ServeHTTP(w, r)
if r.Response != nil {
log.Info().Msgf("Response: %+v", r.Response)
}
})
})

return &mock, nil
}
9 changes: 7 additions & 2 deletions hscontrol/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
) (*types.User, error) {
var user *types.User
var err error
user, err = a.db.GetUserByOIDCIdentifier(claims.Sub)
user, err = a.db.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
}
Expand All @@ -448,10 +448,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
// TODO(kradalby): Remove when strip_email_domain and migration is removed
// after #2170 is cleaned up.
if a.cfg.MapLegacyUsers && user == nil {
log.Trace().Str("username", claims.Username).Str("sub", claims.Sub).Msg("user not found by OIDC identifier, looking up by username")
if oldUsername, err := getUserName(claims, a.cfg.StripEmaildomain); err == nil {
log.Trace().Str("old_username", oldUsername).Str("sub", claims.Sub).Msg("found username")
user, err = a.db.GetUserByName(oldUsername)
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, fmt.Errorf("creating or updating user: %w", err)
return nil, fmt.Errorf("getting user: %w", err)
}

// If the user exists, but it already has a provider identifier (OIDC sub), create a new user.
Expand Down Expand Up @@ -525,6 +527,9 @@ func getUserName(
claims *types.OIDCClaims,
stripEmaildomain bool,
) (string, error) {
if !claims.EmailVerified {
return "", fmt.Errorf("email not verified")
}
userName, err := util.NormalizeToFQDNRules(
claims.Email,
stripEmaildomain,
Expand Down
5 changes: 4 additions & 1 deletion hscontrol/types/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -905,7 +905,10 @@ func LoadServerConfig() (*Config, error) {
}
}(),
UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"),
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
// TODO(kradalby): Remove when strip_email_domain is removed
// after #2170 is cleaned up
StripEmaildomain: viper.GetBool("oidc.strip_email_domain"),
MapLegacyUsers: viper.GetBool("oidc.map_legacy_users"),
},

LogTail: logTailConfig,
Expand Down
16 changes: 10 additions & 6 deletions hscontrol/types/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package types
import (
"cmp"
"strconv"
"strings"

v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/juanfont/headscale/hscontrol/util"
Expand Down Expand Up @@ -39,7 +38,7 @@ type User struct {
// Unique identifier of the user from OIDC,
// comes from `sub` claim in the OIDC token
// and is used to lookup the user.
ProviderIdentifier string `gorm:"index,uniqueIndex:idx_name_provider_identifier"`
ProviderIdentifier string `gorm:"unique,index,uniqueIndex:idx_name_provider_identifier"`

// Provider is the origin of the user account,
// same as RegistrationMethod, without authkey.
Expand All @@ -58,9 +57,10 @@ type User struct {
// If the username does not contain an '@' it will be added to the end.
func (u *User) Username() string {
username := cmp.Or(u.Email, u.Name, u.ProviderIdentifier, strconv.FormatUint(uint64(u.ID), 10))
if !strings.Contains(username, "@") {
username = username + "@"
}
// TODO(kradalby): Wire up all of this for the future
// if !strings.Contains(username, "@") {
// username = username + "@"
// }

return username
}
Expand Down Expand Up @@ -138,10 +138,14 @@ type OIDCClaims struct {
Username string `json:"preferred_username,omitempty"`
}

func (c *OIDCClaims) Identifier() string {
return c.Iss + "/" + c.Sub
}

// FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) {
u.ProviderIdentifier = claims.Iss + "/" + claims.Sub
u.ProviderIdentifier = claims.Identifier()
u.DisplayName = claims.Name
if claims.EmailVerified {
u.Email = claims.Email
Expand Down
Loading

0 comments on commit b65d689

Please sign in to comment.