Skip to content

Commit

Permalink
test(login): fix login integration test (#1587)
Browse files Browse the repository at this point in the history
  • Loading branch information
baurine authored Sep 11, 2023
1 parent ded9a32 commit f1e012a
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 16 deletions.
16 changes: 13 additions & 3 deletions pkg/apiserver/user/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ type AuthService struct {
middleware *jwt.GinJWTMiddleware
authenticators map[utils.AuthType]Authenticator

rsaPublicKey *rsa.PublicKey
RsaPublicKey *rsa.PublicKey
RsaPrivateKey *rsa.PrivateKey
}

Expand Down Expand Up @@ -104,7 +104,7 @@ func NewAuthService(featureFlags *featureflag.Registry) *AuthService {
middleware: nil,
authenticators: map[utils.AuthType]Authenticator{},
RsaPrivateKey: privateKey,
rsaPublicKey: publicKey,
RsaPublicKey: publicKey,
}

middleware, err := jwt.New(&jwt.GinJWTMiddleware{
Expand All @@ -122,6 +122,16 @@ func NewAuthService(featureFlags *featureflag.Registry) *AuthService {
if err != nil {
return nil, errorx.Decorate(err, "authenticate failed")
}
// TODO: uncomment it after thinking clearly
// if form.Type == 0 {
// // generate new rsa key pair for each sql auth login
// privateKey, publicKey, err := GenerateKey()
// // if generate successfully, replace the old key pair
// if err == nil {
// service.RsaPrivateKey = privateKey
// service.RsaPublicKey = publicKey
// }
// }
return u, nil
},
PayloadFunc: func(data interface{}) jwt.MapClaims {
Expand Down Expand Up @@ -312,7 +322,7 @@ func (s *AuthService) GetLoginInfoHandler(c *gin.Context) {
sort.Ints(supportedAuth)
// both work
// publicKeyStr, err := ExportPublicKeyAsString(s.rsaPublicKey)
publicKeyStr, err := DumpPublicKeyBase64(s.rsaPublicKey)
publicKeyStr, err := DumpPublicKeyBase64(s.RsaPublicKey)
if err != nil {
rest.Error(c, err)
return
Expand Down
12 changes: 12 additions & 0 deletions pkg/apiserver/user/rsa_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,18 @@ func DumpPrivateKeyBase64(privatekey *rsa.PrivateKey) (string, error) {
return keyBase64, nil
}

// Encrypt by public key.
func Encrypt(plainText string, publicKey *rsa.PublicKey) (string, error) {
encryptedText, err := rsa.EncryptPKCS1v15(rand.Reader, publicKey, []byte(plainText))
if err != nil {
return "", err
}

// the encryptedText is encoded by base64 in the frontend by jsEncrypt
encodedText := base64.StdEncoding.EncodeToString(encryptedText)
return encodedText, nil
}

// Decrypt by private key.
func Decrypt(cipherText string, privateKey *rsa.PrivateKey) (string, error) {
// the cipherText is encoded by base64 in the frontend by jsEncrypt
Expand Down
10 changes: 4 additions & 6 deletions pkg/apiserver/user/sqlauth/sqlauth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
package sqlauth

import (
"crypto/rsa"

"github.com/joomcode/errorx"
"go.uber.org/fx"

Expand All @@ -17,8 +15,8 @@ const typeID utils.AuthType = 0

type Authenticator struct {
user.BaseAuthenticator
tidbClient *tidb.Client
rsaPrivateKey *rsa.PrivateKey
tidbClient *tidb.Client
authService *user.AuthService
}

func NewAuthenticator(tidbClient *tidb.Client) *Authenticator {
Expand All @@ -29,7 +27,7 @@ func NewAuthenticator(tidbClient *tidb.Client) *Authenticator {

func registerAuthenticator(a *Authenticator, authService *user.AuthService) {
authService.RegisterAuthenticator(typeID, a)
a.rsaPrivateKey = authService.RsaPrivateKey
a.authService = authService
}

var Module = fx.Options(
Expand All @@ -38,7 +36,7 @@ var Module = fx.Options(
)

func (a *Authenticator) Authenticate(f user.AuthenticateForm) (*utils.SessionUser, error) {
plainPwd, err := user.Decrypt(f.Password, a.rsaPrivateKey)
plainPwd, err := user.Decrypt(f.Password, a.authService.RsaPrivateKey)
if err != nil {
return nil, user.ErrSignInOther.WrapWithNoMessage(err)
}
Expand Down
3 changes: 2 additions & 1 deletion tests/integration/info/info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,8 @@ func (s *testInfoSuite) getTokenBySQLRoot() string {
param := make(map[string]interface{})
param["type"] = 0
param["username"] = "root"
param["password"] = ""
pwd, _ := user.Encrypt("", s.authService.RsaPublicKey)
param["password"] = pwd

jsonByte, _ := json.Marshal(param)
req, _ := http.NewRequest(http.MethodPost, "/user/login", bytes.NewReader(jsonByte))
Expand Down
45 changes: 39 additions & 6 deletions tests/integration/user/user_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ func (s *testUserSuite) TestLoginWithNotExistUser() {
param := make(map[string]interface{})
param["type"] = 0
param["username"] = "not_exist"
param["password"] = "aaa"
pwd, _ := user.Encrypt("aaa", s.authService.RsaPublicKey)
param["password"] = pwd

jsonByte, _ := json.Marshal(param)
req, _ := http.NewRequest(http.MethodPost, "/user/login", bytes.NewReader(jsonByte))
Expand All @@ -109,7 +110,8 @@ func (s *testUserSuite) TestLoginWithWrongPassword() {
param := make(map[string]interface{})
param["type"] = 0
param["username"] = "dashboardAdmin"
param["password"] = "123456789"
pwd, _ := user.Encrypt("123456789", s.authService.RsaPublicKey)
param["password"] = pwd

jsonByte, _ := json.Marshal(param)
req, _ := http.NewRequest(http.MethodPost, "/user/login", bytes.NewReader(jsonByte))
Expand All @@ -125,7 +127,8 @@ func (s *testUserSuite) TestLoginWithInsufficientPrivs() {
param := make(map[string]interface{})
param["type"] = 0
param["username"] = "dashboardAdmin-2"
param["password"] = "12345678"
pwd, _ := user.Encrypt("12345678", s.authService.RsaPublicKey)
param["password"] = pwd

jsonByte, _ := json.Marshal(param)
req, _ := http.NewRequest(http.MethodPost, "/user/login", bytes.NewReader(jsonByte))
Expand All @@ -142,7 +145,8 @@ func (s *testUserSuite) TestLoginWithSufficientPrivs() {
param := make(map[string]interface{})
param["type"] = 0
param["username"] = "dashboardAdmin"
param["password"] = "12345678"
pwd, _ := user.Encrypt("12345678", s.authService.RsaPublicKey)
param["password"] = pwd

jsonByte, _ := json.Marshal(param)
req, _ := http.NewRequest(http.MethodPost, "/user/login", bytes.NewReader(jsonByte))
Expand Down Expand Up @@ -177,7 +181,8 @@ func (s *testUserSuite) TestLoginWithWrongPasswordForRoot() {
param := make(map[string]interface{})
param["type"] = 0
param["username"] = "root"
param["password"] = "aaa"
pwd, _ := user.Encrypt("aaa", s.authService.RsaPublicKey)
param["password"] = pwd

jsonByte, _ := json.Marshal(param)
req, _ := http.NewRequest(http.MethodPost, "/user/login", bytes.NewReader(jsonByte))
Expand All @@ -193,7 +198,8 @@ func (s *testUserSuite) TestLoginWithCorrectPasswordForRoot() {
param := make(map[string]interface{})
param["type"] = 0
param["username"] = "root"
param["password"] = ""
pwd, _ := user.Encrypt("", s.authService.RsaPublicKey)
param["password"] = pwd

jsonByte, _ := json.Marshal(param)
req, _ := http.NewRequest(http.MethodPost, "/user/login", bytes.NewReader(jsonByte))
Expand All @@ -210,6 +216,33 @@ func (s *testUserSuite) TestLoginWithCorrectPasswordForRoot() {
s.Require().Nil(err)
}

// TODO: uncomment it after thinking clearly
// func (s *testUserSuite) TestLoginWithSamePayloadTwice() {
// param := make(map[string]interface{})
// param["type"] = 0
// param["username"] = "root"
// pwd, _ := user.Encrypt("", s.authService.RsaPublicKey)
// param["password"] = pwd

// // success at the first time
// jsonByte, _ := json.Marshal(param)
// req, _ := http.NewRequest(http.MethodPost, "/user/login", bytes.NewReader(jsonByte))
// c, w := util.TestReqWithHandlers(req, s.authService.LoginHandler)

// s.Require().Len(c.Errors, 0)
// s.Require().Equal(200, c.Writer.Status())
// s.Require().Equal(200, w.Code)

// // fail at the second time
// req, _ = http.NewRequest(http.MethodPost, "/user/login", bytes.NewReader(jsonByte))
// c, w = util.TestReqWithHandlers(req, s.authService.LoginHandler)

// s.Require().Contains(c.Errors.Last().Err.Error(), "authenticate failed")
// s.Require().Contains(c.Errors.Last().Err.Error(), "crypto/rsa: decryption error")
// s.Require().Equal(401, c.Writer.Status())
// s.Require().Equal(401, w.Code)
// }

func (s *testUserSuite) TestLoginInfo() {
req, _ := http.NewRequest(http.MethodGet, "/user/login_info", nil)
c, w := util.TestReqWithHandlers(req, s.authService.GetLoginInfoHandler)
Expand Down

0 comments on commit f1e012a

Please sign in to comment.