diff --git a/main.go b/main.go index f11c3a0..d28365e 100644 --- a/main.go +++ b/main.go @@ -5,8 +5,7 @@ import ( "os" "net/http" "crypto/tls" -// "pleasant" - "github.com/bva/vault-pps-plugin/pleasant" + "./pleasant" "github.com/hashicorp/vault/helper/pluginutil" "github.com/hashicorp/vault/logical" @@ -29,6 +28,7 @@ func main() { BackendFactoryFunc: factoryFunc, TLSProviderFunc: tlsProviderFunc, }) + if err != nil { log.Println(err) os.Exit(1) diff --git a/pleasant/backend.go b/pleasant/backend.go index d53f031..0462b68 100644 --- a/pleasant/backend.go +++ b/pleasant/backend.go @@ -7,31 +7,24 @@ import ( "github.com/hashicorp/vault/logical/framework" ) -// New returns a new backend as an interface. This func -// is only necessary for builtin backend plugins. -func New() (interface{}, error) { - return Backend(), nil -} - -// Factory returns a new backend as logical.Backend. -func Factory(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { - b := Backend() - if err := b.Setup(ctx, conf); err != nil { - return nil, err - } - return b, nil -} +var( + backend_singleton *backend +) // FactoryType is a wrapper func that allows the Factory func to specify // the backend type for the mock backend plugin instance. func FactoryType(backendType logical.BackendType) logical.Factory { return func(ctx context.Context, conf *logical.BackendConfig) (logical.Backend, error) { - b := Backend() - b.BackendType = backendType - if err := b.Setup(ctx, conf); err != nil { - return nil, err + if backend_singleton == nil { + backend_singleton = Backend() + + backend_singleton.BackendType = backendType + if err := backend_singleton.Setup(ctx, conf); err != nil { + return nil, err + } } - return b, nil + + return *backend_singleton, nil } } diff --git a/pleasant/path_kv.go b/pleasant/path_kv.go index 10e3eab..2bd3ad0 100644 --- a/pleasant/path_kv.go +++ b/pleasant/path_kv.go @@ -158,21 +158,19 @@ func (b *backend) credentialUpdate(credential *Credential, data *framework.Field for field, value := range data.Raw { if strings.HasPrefix(field, "Custom:") { custom_field := strings.TrimPrefix(field, "Custom:") - custom_field_value := value.(string) - if len(custom_field_value) > 0 { - custom_fields[custom_field] = custom_field_value - } else { + if value == nil || len(value.(string)) == 0 { delete(custom_fields, custom_field) + } else { + custom_fields[custom_field] = value.(string) } } else if strings.HasPrefix(field, "Attachment:") { attachment_field := strings.TrimPrefix(field, "Attachment:") - attachment_field_value := value.(string) - if len(attachment_field_value) > 0 { - attachments[attachment_field] = attachment_field_value - } else { + if value == nil || len(value.(string)) == 0 { delete(attachments, attachment_field) + } else { + attachments[attachment_field] = value.(string) } } } diff --git a/pleasant/pleasant.go b/pleasant/pleasant.go index 6290f20..ba43849 100644 --- a/pleasant/pleasant.go +++ b/pleasant/pleasant.go @@ -6,6 +6,7 @@ import ( "strings" "time" "fmt" + "sync" "sync/atomic" ) @@ -84,6 +85,8 @@ type Pleasant struct { backend logical.Backend + mutex *sync.Mutex + reauth_quit chan bool recredential_quit chan bool } @@ -91,6 +94,8 @@ type Pleasant struct { func NewPleasant(backend logical.Backend) *Pleasant { p := new(Pleasant) + p.mutex = &sync.Mutex{} + p.reauth_quit = make(chan bool) p.recredential_quit = make(chan bool) @@ -107,20 +112,13 @@ func (p *Pleasant) Login(url, username, password string) *Pleasant { "Accept-Encoding": "gzip", }) - resp, _ := p.resty.R().SetFormData(map[string]string{ - "grant_type": "password", - "username": username, - "password": password, - }).SetResult(AuthSuccess{}).Post("/OAuth2/Token") - - auth := resp.Result().(*AuthSuccess) - p.auth.Store(auth) + p.auth.Store(p.RequestAuth(username, password)) go func() { auth := p.auth.Load().(*AuthSuccess) ticker := time.NewTicker(time.Second) - duration := time.Duration(auth.ExpiresIn - 30) * time.Second + duration := time.Duration(int(auth.ExpiresIn/2)) * time.Second for range ticker.C { duration -= time.Second @@ -133,16 +131,35 @@ func (p *Pleasant) Login(url, username, password string) *Pleasant { default: if duration <= 0 { p.backend.Logger().Debug("Timer cached AuthToken") - resp, _ := p.resty.R().SetFormData(map[string]string{ - "grant_type": "password", - "username": username, - "password": password, - }).SetResult(AuthSuccess{}).Post("/OAuth2/Token") + p.auth.Store(p.RequestAuth(username, password)) + auth = p.auth.Load().(*AuthSuccess) + duration = time.Duration(int(auth.ExpiresIn/2)) * time.Second + } + } + } + }() - auth := resp.Result().(*AuthSuccess) - p.auth.Store(auth) + go func() { + p.backend.Logger().Debug("Starting cached RootCredentialGroup goroutine") + + ticker := time.NewTicker(time.Second) + duration := time.Duration(5*60) * time.Second - duration = time.Duration(auth.ExpiresIn - 1) * time.Second + for range ticker.C { + duration -= time.Second + + select { + case <- p.recredential_quit: + p.backend.Logger().Debug("Timer cached RootCredentialGroup cancelled") + return + + default: + if duration <= 0 { + p.backend.Logger().Debug("Timer cached RootCredentialGroup") + p.mutex.Lock() + p.Invalidate() + p.mutex.Unlock() + duration = time.Duration(5*60) * time.Second } } } @@ -158,57 +175,46 @@ func (p *Pleasant) Logout() { p.auth.Store(nil) } -func (p *Pleasant) Invalidate() { - p.RequestRootCredentialGroup(true) -} - func (p *Pleasant) GetAccessToken() string { return p.auth.Load().(*AuthSuccess).AccessToken } -func (p *Pleasant) RequestRootCredentialGroup(invalidate bool) *CredentialGroup { - p.backend.Logger().Debug("RequestRootCredentialGroup") +func (p *Pleasant) RequestAuth(username string, password string) *AuthSuccess { + p.mutex.Lock() - root := p.root.Load() + resp, _ := p.resty.R().SetFormData(map[string]string{ + "grant_type": "password", + "username": username, + "password": password, + }).SetResult(AuthSuccess{}).Post("/OAuth2/Token") - if(root != nil && !invalidate) { - p.backend.Logger().Debug("Return cached RootCredentialGroup") - return root.(*CredentialGroup) - } + p.mutex.Unlock() + return resp.Result().(*AuthSuccess) +} + +func (p *Pleasant) Invalidate() *CredentialGroup { new_root := p.RequestCredentialGroup("") + if new_root != nil { - root = new_root + p.root.Store(new_root) } - p.root.Store(root) - - if !invalidate { - go func() { - p.backend.Logger().Debug("Starting cached RootCredentialGroup goroutine") - - ticker := time.NewTicker(time.Second) - duration := time.Duration(5*60) * time.Second + return new_root +} - for range ticker.C { - duration -= time.Second +func (p *Pleasant) RequestRootCredentialGroup(invalidate bool) *CredentialGroup { + p.backend.Logger().Debug("RequestRootCredentialGroup") - select { - case <- p.recredential_quit: - p.backend.Logger().Debug("Timer cached RootCredentialGroup cancelled") - return + root := p.root.Load() - default: - if duration <= 0 { - p.backend.Logger().Debug("Timer cached RootCredentialGroup") - p.root.Store(p.RequestCredentialGroup("")) - duration = time.Duration(5*60) * time.Second - } - } - } - }() + if(root != nil && !invalidate) { + p.backend.Logger().Debug("Return cached RootCredentialGroup") + return root.(*CredentialGroup) } + root = p.Invalidate() + return root.(*CredentialGroup) } @@ -233,7 +239,9 @@ func (p *Pleasant) Read(path string) (*CredentialGroup, *Credential) { b.Logger().Debug("PathSplitted", fmt.Sprintf("%v", path_splitted)) + p.mutex.Lock() node := p.RequestRootCredentialGroup(false) + p.mutex.Unlock() if(len(path_splitted) == 0) { return node, nil @@ -268,10 +276,7 @@ func (p *Pleasant) Read(path string) (*CredentialGroup, *Credential) { if credential.Name == last_leaf || credential.Name+"["+credential.Id+"]" == last_leaf { b.Logger().Debug("LastLeaf is Credential", credential.Name) - -// updated_credential := p.RequestCredential(credential.Id) - updated_credential := &credential - return node, updated_credential + return node, &credential } } @@ -280,10 +285,7 @@ func (p *Pleasant) Read(path string) (*CredentialGroup, *Credential) { if credential.Username == last_leaf || credential.Username+"["+credential.Id+"]" == last_leaf { b.Logger().Debug("LastLeaf is Credential", credential.Name) - -// updated_credential := p.RequestCredential(credential.Id) - updated_credential := &credential - return node, updated_credential + return node, &credential } } @@ -292,7 +294,6 @@ func (p *Pleasant) Read(path string) (*CredentialGroup, *Credential) { if (group.Name == last_leaf || group.Name + "[" + group.Id + "]" == last_leaf) { b.Logger().Debug("LastLeaf is CredentialGroup", group.Name) - // return p.RequestCredentialGroup(group.Id), nil, extra_path return &group, nil } } @@ -305,63 +306,85 @@ func (p *Pleasant) RequestCredentialGroup(id string) *CredentialGroup { request := p.resty.R().SetHeader("Authorization", p.GetAccessToken()).SetResult(CredentialGroup{}) resp, _ := request.Get(strings.Join([]string {"/api/v4/rest/credentialgroup/", id}, "/")) + var result *CredentialGroup = nil + if resp.StatusCode() == 200 && resp.Result() != nil { - return resp.Result().(*CredentialGroup) + result = resp.Result().(*CredentialGroup) } - return nil + return result } func (p *Pleasant) RequestCredential(id string) *Credential { + p.mutex.Lock() + request := p.resty.R().SetHeader("Authorization", p.GetAccessToken()).SetResult(Credential{}) resp, _ := request.Get(strings.Join([]string {"/api/v4/rest/credential/", id}, "/")) + var result *Credential = nil + if resp.StatusCode() == 200 && resp.Result() != nil { - return resp.Result().(*Credential) + result = resp.Result().(*Credential) } - return nil + p.mutex.Unlock() + + return result } func (p *Pleasant) UpdateCredential(credential *Credential) { + p.mutex.Lock() request := p.resty.R().SetHeader("Authorization", p.GetAccessToken()).SetBody(credential) request.Put(strings.Join([]string {"/api/v4/rest/credential/", credential.Id}, "/")) p.Invalidate() + p.mutex.Unlock() } func (p *Pleasant) UpdateCredentialGroup(group *CredentialGroup) { + p.mutex.Lock() request := p.resty.R().SetHeader("Authorization", p.GetAccessToken()).SetBody(group) request.Put(strings.Join([]string {"/api/v4/rest/credentialgroup/", group.Id}, "/")) p.Invalidate() + p.mutex.Unlock() } func (p *Pleasant) CreateCredential(credential *Credential) { + p.mutex.Lock() request := p.resty.R().SetHeader("Authorization", p.GetAccessToken()).SetBody(credential) request.Post("/api/v4/rest/credential") p.Invalidate() + p.mutex.Unlock() } func (p *Pleasant) CreateCredentialGroup(group *CredentialGroup) { + p.mutex.Lock() request := p.resty.R().SetHeader("Authorization", p.GetAccessToken()).SetBody(group) request.Post("/api/v4/rest/credentialgroup") p.Invalidate() + p.mutex.Unlock() } func (p *Pleasant) DeleteCredential(credential *Credential) { + p.mutex.Lock() request := p.resty.R().SetHeader("Authorization", p.GetAccessToken()) request.Delete(strings.Join([]string {"/api/v4/rest/credential/", credential.Id}, "/")) p.Invalidate() + p.mutex.Unlock() } func (p *Pleasant) DeleteCredentialGroup(group *CredentialGroup) { + p.mutex.Lock() request := p.resty.R().SetHeader("Authorization", p.GetAccessToken()) request.Delete(strings.Join([]string {"/api/v4/rest/credentialgroup/", group.Id}, "/")) p.Invalidate() + p.mutex.Unlock() } func (p *Pleasant) RequestCredentialPassword(id string) string { + p.mutex.Lock() request := p.resty.R().SetHeader("Authorization", p.GetAccessToken()) resp, _ := request.Get(strings.Join([]string {"/api/v4/rest/credential/", id, "password"}, "/")) + p.mutex.Unlock() if(len(resp.String()) > 0) { return (resp.String())[1 : len(resp.String())-1]