Skip to content

Commit

Permalink
fix key auth issue
Browse files Browse the repository at this point in the history
  • Loading branch information
spikelu2016 committed Mar 19, 2024
1 parent 12c1f40 commit 30d88a0
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cmd/bricksllm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ func main() {
v := validator.NewValidator(costLimitCache, rateLimitCache, costStorage)
rec := recorder.NewRecorder(costStorage, costLimitCache, ce, store)
rlm := manager.NewRateLimitManager(rateLimitCache)
a := auth.NewAuthenticator(psm, memStore, rm)
a := auth.NewAuthenticator(psm, memStore, rm, store)

c := cache.NewCache(apiCache)

Expand Down
31 changes: 29 additions & 2 deletions internal/authenticator/authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,23 @@ type keyMemStorage interface {
GetKey(hash string) *key.ResponseKey
}

type keyStorage interface {
GetKeyByHash(hash string) (*key.ResponseKey, error)
}

type Authenticator struct {
psm providerSettingsManager
kms keyMemStorage
rm routesManager
ks keyStorage
}

func NewAuthenticator(psm providerSettingsManager, kms keyMemStorage, rm routesManager) *Authenticator {
func NewAuthenticator(psm providerSettingsManager, kms keyMemStorage, rm routesManager, ks keyStorage) *Authenticator {
return &Authenticator{
psm: psm,
kms: kms,
rm: rm,
ks: ks,
}
}

Expand Down Expand Up @@ -151,6 +157,11 @@ func canAccessPath(provider string, path string) bool {
return true
}

type notFoundError interface {
Error() string
NotFound()
}

func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.ResponseKey, []*provider.Setting, error) {
raw, err := getApiKey(req)
if err != nil {
Expand All @@ -160,7 +171,23 @@ func (a *Authenticator) AuthenticateHttpRequest(req *http.Request) (*key.Respons
hash := encrypter.Encrypt(raw)

key := a.kms.GetKey(hash)
if key == nil || key.Revoked {
if key == nil {
key, err = a.ks.GetKeyByHash(hash)
if err != nil {
_, ok := err.(notFoundError)
if ok {
return nil, nil, internal_errors.NewAuthError("key not found")
}

return nil, nil, err
}
}

if key == nil {
return nil, nil, internal_errors.NewAuthError("key not found")
}

if key.Revoked {
return nil, nil, internal_errors.NewAuthError("not authorized")
}

Expand Down
54 changes: 54 additions & 0 deletions internal/storage/postgresql/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,60 @@ func (s *Store) GetKeys(tags, keyIds []string, provider string) ([]*key.Response
return keys, nil
}

func (s *Store) GetKeyByHash(hash string) (*key.ResponseKey, error) {
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.rt)
defer cancel()

var k key.ResponseKey
var settingId sql.NullString
var data []byte

err := s.db.QueryRowContext(ctxTimeout, "SELECT * FROM keys WHERE key = $1", hash).Scan(
&k.Name,
&k.CreatedAt,
&k.UpdatedAt,
pq.Array(&k.Tags),
&k.Revoked,
&k.KeyId,
&k.Key,
&k.RevokedReason,
&k.CostLimitInUsd,
&k.CostLimitInUsdOverTime,
&k.CostLimitInUsdUnit,
&k.RateLimitOverTime,
&k.RateLimitUnit,
&k.Ttl,
&settingId,
&data,
pq.Array(&k.SettingIds),
&k.ShouldLogRequest,
&k.ShouldLogResponse,
&k.RotationEnabled,
&k.PolicyId,
)

if err != nil {
if err == sql.ErrNoRows {
return nil, internal_errors.NewNotFoundError("key is not found using hash")
}

return nil, err
}

k.SettingId = settingId.String

if len(data) != 0 {
pathConfigs := []key.PathConfig{}
if err := json.Unmarshal(data, &pathConfigs); err != nil {
return nil, err
}

k.AllowedPaths = pathConfigs
}

return &k, nil
}

func (s *Store) GetKey(keyId string) (*key.ResponseKey, error) {
ctxTimeout, cancel := context.WithTimeout(context.Background(), s.rt)
defer cancel()
Expand Down

0 comments on commit 30d88a0

Please sign in to comment.