From e5b2894201635c27c1c19ea7cd9c5e60fdaa4d8b Mon Sep 17 00:00:00 2001 From: Javier Marcos <1271349+javuto@users.noreply.github.com> Date: Mon, 16 Sep 2024 10:05:02 +0200 Subject: [PATCH] Using only UUID to identify environment in osctrl-tls --- environments/environments.go | 27 +++++++++++++++++++++++++++ tls/handlers/post.go | 28 ++++++++++++++-------------- 2 files changed, 41 insertions(+), 14 deletions(-) diff --git a/environments/environments.go b/environments/environments.go index f1a2b039..27130baa 100644 --- a/environments/environments.go +++ b/environments/environments.go @@ -121,6 +121,33 @@ func (environment *Environment) Get(identifier string) (TLSEnvironment, error) { return env, nil } +// Get TLS Environment by UUID +func (environment *Environment) GetByUUID(uuid string) (TLSEnvironment, error) { + var env TLSEnvironment + if err := environment.DB.Where("uuid = ?", uuid).First(&env).Error; err != nil { + return env, err + } + return env, nil +} + +// Get TLS Environment by Name +func (environment *Environment) GetByName(name string) (TLSEnvironment, error) { + var env TLSEnvironment + if err := environment.DB.Where("name = ?", name).First(&env).Error; err != nil { + return env, err + } + return env, nil +} + +// Get TLS Environment by ID +func (environment *Environment) GetByID(id uint) (TLSEnvironment, error) { + var env TLSEnvironment + if err := environment.DB.Where("ID = ?", id).First(&env).Error; err != nil { + return env, err + } + return env, nil +} + // Empty generates an empty TLSEnvironment with default values func (environment *Environment) Empty(name, hostname string) TLSEnvironment { return TLSEnvironment{ diff --git a/tls/handlers/post.go b/tls/handlers/post.go index 9b4b8519..8508e2a2 100644 --- a/tls/handlers/post.go +++ b/tls/handlers/post.go @@ -31,7 +31,7 @@ func (h *HandlersTLS) EnrollHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricEnrollErr) log.Printf("error getting environment %v", err) @@ -117,7 +117,7 @@ func (h *HandlersTLS) ConfigHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricConfigErr) log.Printf("error getting environment %v", err) @@ -183,7 +183,7 @@ func (h *HandlersTLS) LogHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricLogErr) log.Printf("error getting environment %v", err) @@ -261,7 +261,7 @@ func (h *HandlersTLS) QueryReadHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricReadErr) log.Printf("error getting environment %v", err) @@ -340,7 +340,7 @@ func (h *HandlersTLS) QueryWriteHandler(w http.ResponseWriter, r *http.Request) return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricWriteErr) log.Printf("error getting environment %v", err) @@ -419,7 +419,7 @@ func (h *HandlersTLS) QuickEnrollHandler(w http.ResponseWriter, r *http.Request) return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricOnelinerErr) log.Printf("error getting environment - %v", err) @@ -496,7 +496,7 @@ func (h *HandlersTLS) QuickRemoveHandler(w http.ResponseWriter, r *http.Request) return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricOnelinerErr) log.Printf("error getting environment - %v", err) @@ -575,7 +575,7 @@ func (h *HandlersTLS) CarveInitHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricInitErr) log.Printf("error getting environment %v", err) @@ -646,7 +646,7 @@ func (h *HandlersTLS) CarveBlockHandler(w http.ResponseWriter, r *http.Request) return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricBlockErr) log.Printf("error getting environment %v", err) @@ -706,7 +706,7 @@ func (h *HandlersTLS) FlagsHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricFlagsErr) log.Printf("error getting environment %v", err) @@ -761,7 +761,7 @@ func (h *HandlersTLS) CertHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricCertErr) log.Printf("error getting environment %v", err) @@ -810,7 +810,7 @@ func (h *HandlersTLS) VerifyHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricVerifyErr) log.Printf("error getting environment %v", err) @@ -869,7 +869,7 @@ func (h *HandlersTLS) ScriptHandler(w http.ResponseWriter, r *http.Request) { return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricScriptErr) log.Printf("error getting environment %v", err) @@ -949,7 +949,7 @@ func (h *HandlersTLS) EnrollPackageHandler(w http.ResponseWriter, r *http.Reques return } // Get environment - env, err := h.Envs.Get(envVar) + env, err := h.Envs.GetByUUID(envVar) if err != nil { h.Inc(metricPackageErr) log.Printf("error getting environment - %v", err)