diff --git a/Makefile b/Makefile index 8d5ef1ba..64e64b14 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,10 @@ ADMIN_DIR = cmd/admin ADMIN_NAME = osctrl-admin ADMIN_CODE = ${ADMIN_DIR:=/*.go} +API_DIR = cmd/api +API_NAME = osctrl-api +API_CODE = ${API_DIR:=/*.go} + CLI_DIR = cmd/cli CLI_NAME = osctrl-cli CLI_CODE = ${CLI_DIR:=/*.go} @@ -30,6 +34,7 @@ build: make plugins make tls make admin + make api make cli # Build TLS endpoint @@ -40,6 +45,10 @@ tls: admin: go build -o $(OUTPUT)/$(ADMIN_NAME) $(ADMIN_CODE) +# Build API +api: + go build -o $(OUTPUT)/$(API_NAME) $(API_CODE) + # Build the CLI cli: go build -o $(OUTPUT)/$(CLI_NAME) $(CLI_CODE) @@ -55,6 +64,7 @@ plugins: clean: rm -rf $(OUTPUT)/$(TLS_NAME) rm -rf $(OUTPUT)/$(ADMIN_NAME) + rm -rf $(OUTPUT)/$(API_NAME) rm -rf $(OUTPUT)/$(CLI_NAME) rm -rf $(PLUGINS_DIR)/*.so @@ -70,6 +80,7 @@ install: make build make install_tls make install_admin + make install_api make install_cli # Install TLS server and restart service @@ -86,6 +97,13 @@ install_admin: sudo cp $(OUTPUT)/$(ADMIN_NAME) $(DEST) sudo systemctl start $(ADMIN_NAME) +# Install API server and restart service +# optional DEST=destination_path +install_api: + sudo systemctl stop $(API_NAME) + sudo cp $(OUTPUT)/$(API_NAME) $(DEST) + sudo systemctl start $(API_NAME) + # Install CLI # optional DEST=destination_path install_cli: @@ -99,6 +117,10 @@ logs_tls: logs_admin: sudo journalctl -f -t $(ADMIN_NAME) +# Display systemd logs for API server +logs_api: + sudo journalctl -f -t $(API_NAME) + # Build docker containers and run them (also generates new certificates) docker_all: ./docker/dockerize.sh -u -b -f @@ -131,6 +153,9 @@ gofmt-tls: gofmt-admin: gofmt $(GOFMT_ARGS) ./$(ADMIN_CODE) +gofmt-api: + gofmt $(GOFMT_ARGS) ./$(API_CODE) + gofmt-cli: gofmt $(GOFMT_ARGS) ./$(CLI_CODE) @@ -148,8 +173,12 @@ test: cd $(TLS_DIR) && go test . -v # Install dependencies for Admin cd $(ADMIN_DIR) && go test -i . -v - # Run TLS tests + # Run Admin tests cd $(ADMIN_DIR) && go test . -v + # Install dependencies for API + cd $(API_DIR) && go test -i . -v + # Run API tests + cd $(API_DIR) && go test . -v # Install dependencies for CLI cd $(CLI_DIR) && go test -i . -v # Run CLI tests diff --git a/Vagrantfile b/Vagrantfile index 750a4e26..7c9ced9a 100644 --- a/Vagrantfile +++ b/Vagrantfile @@ -3,9 +3,11 @@ VAGRANTFILE_API_VERSION = "2" +IP_ADDRESS = "10.10.10.6" + Vagrant.configure(VAGRANTFILE_API_VERSION) do |config| config.vm.box = "ubuntu/bionic64" - config.vm.network "private_network", ip: "10.10.10.6" + config.vm.network "private_network", ip: IP_ADDRESS # If we want to enroll nodes in the same network #config.vm.network "forwarded_port", guest: 443, host: 443 config.vm.hostname = "osctrl-Dev" @@ -13,8 +15,8 @@ Vagrant.configure(VAGRANTFILE_API_VERSION) do |config| config.vm.provision "shell" do |s| s.path = "deploy/provision.sh" s.args = [ - "--nginx", "--postgres", "-E", "--metrics", "--tls-hostname", - "10.10.10.6", "--admin-hostname", "10.10.10.6", "--password", "admin" + "--nginx", "--postgres", "-E", "--metrics", "--all-hostname", + IP_ADDRESS, "--password", "admin" ] privileged = false end diff --git a/cmd/admin/handlers-get.go b/cmd/admin/handlers-get.go index c921f1b0..a77a0b7f 100644 --- a/cmd/admin/handlers-get.go +++ b/cmd/admin/handlers-get.go @@ -1068,7 +1068,8 @@ func usersGETHandler(w http.ResponseWriter, r *http.Request) { } // Custom functions to handle formatting funcMap := template.FuncMap{ - "pastTimeAgo": pastTimeAgo, + "pastTimeAgo": pastTimeAgo, + "inFutureTime": inFutureTime, } // Prepare template t, err := template.New("users.html").Funcs(funcMap).ParseFiles( diff --git a/cmd/admin/handlers-post.go b/cmd/admin/handlers-post.go index 6a411124..0c7148a0 100644 --- a/cmd/admin/handlers-post.go +++ b/cmd/admin/handlers-post.go @@ -153,19 +153,7 @@ func queryRunPOSTHandler(w http.ResponseWriter, r *http.Request) { goto send_response } // Prepare and create new query - queryName := "query_" + generateQueryName() - newQuery := queries.DistributedQuery{ - Query: q.Query, - Name: queryName, - Creator: ctx[ctxUser], - Expected: 0, - Executions: 0, - Active: true, - Completed: false, - Deleted: false, - Repeat: 0, - Type: queries.StandardQueryType, - } + newQuery := newQueryReady(ctx[ctxUser], q.Query) if err := queriesmgr.Create(newQuery); err != nil { responseMessage = "error creating query" responseCode = http.StatusInternalServerError @@ -178,7 +166,7 @@ func queryRunPOSTHandler(w http.ResponseWriter, r *http.Request) { if len(q.Environments) > 0 { for _, e := range q.Environments { if (e != "") && envs.Exists(e) { - if err := queriesmgr.CreateTarget(queryName, queries.QueryTargetEnvironment, e); err != nil { + if err := queriesmgr.CreateTarget(newQuery.Name, queries.QueryTargetEnvironment, e); err != nil { responseMessage = "error creating query environment target" responseCode = http.StatusInternalServerError log.Printf("%s %v", responseMessage, err) @@ -201,7 +189,7 @@ func queryRunPOSTHandler(w http.ResponseWriter, r *http.Request) { if len(q.Platforms) > 0 { for _, p := range q.Platforms { if (p != "") && checkValidPlatform(p) { - if err := queriesmgr.CreateTarget(queryName, queries.QueryTargetPlatform, p); err != nil { + if err := queriesmgr.CreateTarget(newQuery.Name, queries.QueryTargetPlatform, p); err != nil { responseMessage = "error creating query platform target" responseCode = http.StatusInternalServerError log.Printf("%s %v", responseMessage, err) @@ -224,7 +212,7 @@ func queryRunPOSTHandler(w http.ResponseWriter, r *http.Request) { if len(q.UUIDs) > 0 { for _, u := range q.UUIDs { if (u != "") && nodesmgr.CheckByUUID(u) { - if err := queriesmgr.CreateTarget(queryName, queries.QueryTargetUUID, u); err != nil { + if err := queriesmgr.CreateTarget(newQuery.Name, queries.QueryTargetUUID, u); err != nil { responseMessage = "error creating query UUID target" responseCode = http.StatusInternalServerError log.Printf("%s %v", responseMessage, err) @@ -238,7 +226,7 @@ func queryRunPOSTHandler(w http.ResponseWriter, r *http.Request) { if len(q.Hosts) > 0 { for _, h := range q.Hosts { if (h != "") && nodesmgr.CheckByHost(h) { - if err := queriesmgr.CreateTarget(queryName, queries.QueryTargetLocalname, h); err != nil { + if err := queriesmgr.CreateTarget(newQuery.Name, queries.QueryTargetLocalname, h); err != nil { responseMessage = "error creating query hostname target" responseCode = http.StatusInternalServerError log.Printf("%s %v", responseMessage, err) @@ -251,7 +239,7 @@ func queryRunPOSTHandler(w http.ResponseWriter, r *http.Request) { // Remove duplicates from expected expectedClear := removeStringDuplicates(expected) // Update value for expected - if err := queriesmgr.SetExpected(queryName, len(expectedClear)); err != nil { + if err := queriesmgr.SetExpected(newQuery.Name, len(expectedClear)); err != nil { responseMessage = "error setting expected" responseCode = http.StatusInternalServerError log.Printf("%s %v", responseMessage, err) @@ -313,7 +301,7 @@ func carvesRunPOSTHandler(w http.ResponseWriter, r *http.Request) { } query := generateCarveQuery(c.Path, false) // Prepare and create new carve - carveName := "carve_" + generateQueryName() + carveName := generateCarveName() newQuery := queries.DistributedQuery{ Query: query, Name: carveName, @@ -323,7 +311,6 @@ func carvesRunPOSTHandler(w http.ResponseWriter, r *http.Request) { Active: true, Completed: false, Deleted: false, - Repeat: 0, Type: queries.CarveQueryType, Path: c.Path, } @@ -1244,6 +1231,25 @@ func usersPOSTHandler(w http.ResponseWriter, r *http.Request) { log.Printf("DebugService: %s %v", responseMessage, err) } } + if newUser.Admin { + token, exp, err := adminUsers.CreateToken(newUser.Username, jwtConfig.HoursToExpire, jwtConfig.JWTSecret) + if err != nil { + responseMessage = "error creating token" + responseCode = http.StatusInternalServerError + if settingsmgr.DebugService(settings.ServiceAdmin) { + log.Printf("DebugService: %s %v", responseMessage, err) + } + goto send_response + } + if err = adminUsers.UpdateToken(newUser.Username, token, exp); err != nil { + responseMessage = "error saving token" + responseCode = http.StatusInternalServerError + if settingsmgr.DebugService(settings.ServiceAdmin) { + log.Printf("DebugService: %s %v", responseMessage, err) + } + goto send_response + } + } responseMessage = "User added successfully" } } @@ -1279,9 +1285,27 @@ func usersPOSTHandler(w http.ResponseWriter, r *http.Request) { if settingsmgr.DebugService(settings.ServiceAdmin) { log.Printf("DebugService: %s %v", responseMessage, err) } - } else { - responseMessage = "Admin changed" } + if u.Admin { + token, exp, err := adminUsers.CreateToken(u.Username, jwtConfig.HoursToExpire, jwtConfig.JWTSecret) + if err != nil { + responseMessage = "error creating token" + responseCode = http.StatusInternalServerError + if settingsmgr.DebugService(settings.ServiceAdmin) { + log.Printf("DebugService: %s %v", responseMessage, err) + } + goto send_response + } + if err = adminUsers.UpdateToken(u.Username, token, exp); err != nil { + responseMessage = "error saving token" + responseCode = http.StatusInternalServerError + if settingsmgr.DebugService(settings.ServiceAdmin) { + log.Printf("DebugService: %s %v", responseMessage, err) + } + goto send_response + } + } + responseMessage = "Admin changed" } } } else { diff --git a/cmd/admin/handlers-tokens.go b/cmd/admin/handlers-tokens.go new file mode 100644 index 00000000..1b3f51e8 --- /dev/null +++ b/cmd/admin/handlers-tokens.go @@ -0,0 +1,136 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" + + "github.com/gorilla/mux" + "github.com/jmpsec/osctrl/pkg/settings" + "github.com/jmpsec/osctrl/pkg/utils" +) + +// TokenJSON to be used to populate a JSON token +type TokenJSON struct { + Token string `json:"token"` + Expires string `json:"expires"` + ExpiresTS string `json:"expires_ts"` +} + +// Handle GET requests for /tokens/{username} +func tokensGETHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricTokenReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAdmin), false) + // Get context data + ctx := r.Context().Value(contextKey("session")).(contextValue) + // Check permissions + if !checkAdminLevel(ctx[ctxLevel]) { + log.Printf("%s has insuficient permissions", ctx[ctxUser]) + incMetric(metricTokenErr) + return + } + vars := mux.Vars(r) + // Extract username + username, ok := vars["username"] + if !ok { + log.Println("error getting username") + incMetric(metricTokenErr) + return + } + returned := TokenJSON{} + if adminUsers.Exists(username) { + user, err := adminUsers.Get(username) + if err != nil { + log.Println("error getting user") + incMetric(metricTokenErr) + return + } + // Prepare data to be returned + returned = TokenJSON{ + Token: user.APIToken, + Expires: user.TokenExpire.String(), + ExpiresTS: user.TokenExpire.String(), + } + } + // Serialize JSON + returnedJSON, err := json.Marshal(returned) + if err != nil { + log.Printf("error serializing JSON %v", err) + incMetric(metricTokenErr) + return + } + // Header to serve JSON + w.Header().Set("Content-Type", JSONApplicationUTF8) + w.WriteHeader(http.StatusOK) + _, _ = w.Write(returnedJSON) + incMetric(metricTokenOK) +} + +// Handle POST request for /tokens/{username}/refresh +func tokensPOSTHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricTokenReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAdmin), false) + // Get context data + ctx := r.Context().Value(contextKey("session")).(contextValue) + // Check permissions + if !checkAdminLevel(ctx[ctxLevel]) { + adminErrorResponse(w, "insuficient permissions", http.StatusForbidden, nil) + incMetric(metricTokenErr) + return + } + vars := mux.Vars(r) + // Extract username + username, ok := vars["username"] + if !ok { + incMetric(metricTokenErr) + adminErrorResponse(w, "error getting username", http.StatusInternalServerError, nil) + return + } + // Parse request JSON body + if settingsmgr.DebugService(settings.ServiceAdmin) { + log.Println("DebugService: Decoding POST body") + } + var t TokenRequest + response := TokenResponse{} + if err := json.NewDecoder(r.Body).Decode(&t); err == nil { + // Check CSRF Token + if checkCSRFToken(ctx[ctxCSRF], t.CSRFToken) { + if adminUsers.Exists(username) { + user, err := adminUsers.Get(username) + if err != nil { + adminErrorResponse(w, "error getting user", http.StatusInternalServerError, err) + return + } + if user.Admin { + token, exp, err := adminUsers.CreateToken(user.Username, jwtConfig.HoursToExpire, jwtConfig.JWTSecret) + if err != nil { + adminErrorResponse(w, "error creating token", http.StatusInternalServerError, err) + return + } + if err = adminUsers.UpdateToken(user.Username, token, exp); err != nil { + adminErrorResponse(w, "error updating token", http.StatusInternalServerError, err) + return + } + response = TokenResponse{ + Token: user.APIToken, + ExpirationTS: exp.String(), + Expiration: exp.String(), + } + } + } else { + adminErrorResponse(w, "user not found", http.StatusNotFound, nil) + return + } + } else { + adminErrorResponse(w, "invalid CSRF token", http.StatusForbidden, nil) + return + } + } else { + incMetric(metricTokenErr) + adminErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, nil) + return + } + // Serialize and serve JSON + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, response) + incMetric(metricTokenOK) +} diff --git a/cmd/admin/handlers.go b/cmd/admin/handlers.go index 7f58710d..9d06e35d 100644 --- a/cmd/admin/handlers.go +++ b/cmd/admin/handlers.go @@ -14,6 +14,9 @@ const ( metricJSONReq = "admin-json-req" metricJSONErr = "admin-json-err" metricJSONOK = "admin-json-ok" + metricTokenReq = "admin-token-req" + metricTokenErr = "admin-token-err" + metricTokenOK = "admin-token-ok" metricHealthReq = "health-req" metricHealthOK = "health-ok" ) @@ -21,6 +24,9 @@ const ( // JSONApplication for Content-Type headers const JSONApplication string = "application/json" +// ContentType for header key +const contentType string = "Content-Type" + // JSONApplicationUTF8 for Content-Type headers, UTF charset const JSONApplicationUTF8 string = JSONApplication + "; charset=UTF-8" diff --git a/cmd/admin/jwt.go b/cmd/admin/jwt.go new file mode 100644 index 00000000..1600d721 --- /dev/null +++ b/cmd/admin/jwt.go @@ -0,0 +1,30 @@ +package main + +import ( + "log" + + "github.com/jmpsec/osctrl/pkg/settings" + "github.com/jmpsec/osctrl/pkg/types" + "github.com/spf13/viper" +) + +// Function to load the configuration file +func loadJWTConfiguration(file string) (types.JSONConfigurationJWT, error) { + var cfg types.JSONConfigurationJWT + log.Printf("Loading %s", file) + // Load file and read config + viper.SetConfigFile(file) + err := viper.ReadInConfig() + if err != nil { + return cfg, err + } + // JWT values + headersRaw := viper.Sub(settings.AuthJWT) + err = headersRaw.Unmarshal(&cfg) + if err != nil { + return cfg, err + } + + // No errors! + return cfg, nil +} diff --git a/cmd/admin/main.go b/cmd/admin/main.go index 1d0264ed..0d4cee9b 100644 --- a/cmd/admin/main.go +++ b/cmd/admin/main.go @@ -50,6 +50,8 @@ const ( dbConfigurationFile string = "config/db.json" // Default SAML configuration file samlConfigurationFile string = "config/saml.json" + // Default JWT configuration file + jwtConfigurationFile string = "config/jwt.json" // Default Headers configuration file headersConfigurationFile string = "config/headers.json" // osquery version to display tables @@ -90,6 +92,7 @@ var ( dbFlag *string samlFlag *string headersFlag *string + jwtFlag *string ) // SAML variables @@ -101,7 +104,12 @@ var ( // Headers variables var ( - headersConfig types.JSONConfigurationHeaders + headersConfig types.JSONConfigurationHeaders +) + +// JWT variables +var ( + jwtConfig types.JSONConfigurationJWT ) // Valid values for auth in configuration @@ -157,6 +165,7 @@ func init() { dbFlag = flag.String("D", dbConfigurationFile, "DB configuration JSON file to use.") samlFlag = flag.String("S", samlConfigurationFile, "SAML configuration JSON file to use.") headersFlag = flag.String("H", headersConfigurationFile, "Headers configuration JSON file to use.") + jwtFlag = flag.String("J", jwtConfigurationFile, "JWT configuration JSON file to use.") // Parse all flags flag.Parse() if *versionFlag { @@ -169,29 +178,29 @@ func init() { if err != nil { log.Fatalf("Error loading %s - %s", *configFlag, err) } - // Load osquery tables JSON osqueryTables, err = loadOsqueryTables(osqueryTablesFile) if err != nil { log.Fatalf("Error loading osquery tables %s", err) } - // Load configuration for SAML if enabled if adminConfig.Auth == settings.AuthSAML { samlConfig, err = loadSAML(*samlFlag) if err != nil { log.Fatalf("Error loading %s - %s", *samlFlag, err) } - return } - // Load configuration for Headers if enabled if adminConfig.Auth == settings.AuthHeaders { headersConfig, err = loadHeaders(*headersFlag) if err != nil { log.Fatalf("Error loading %s - %s", *headersFlag, err) } - return + } + // Load JWT configuration + jwtConfig, err = loadJWTConfiguration(*jwtFlag) + if err != nil { + log.Fatalf("Error loading %s - %s", *jwtFlag, err) } } @@ -354,6 +363,9 @@ func main() { // Admin: manage users routerAdmin.Handle("/users", handlerAuthCheck(http.HandlerFunc(usersGETHandler))).Methods("GET") routerAdmin.Handle("/users", handlerAuthCheck(http.HandlerFunc(usersPOSTHandler))).Methods("POST") + // Admin: manage tokens + routerAdmin.Handle("/tokens/{username}", handlerAuthCheck(http.HandlerFunc(tokensGETHandler))).Methods("GET") + routerAdmin.Handle("/tokens/{username}/refresh", handlerAuthCheck(http.HandlerFunc(tokensPOSTHandler))).Methods("POST") // logout routerAdmin.Handle("/logout", handlerAuthCheck(http.HandlerFunc(logoutHandler))).Methods("POST") diff --git a/cmd/admin/sessions.go b/cmd/admin/sessions.go index 2c9c9717..12e8882a 100644 --- a/cmd/admin/sessions.go +++ b/cmd/admin/sessions.go @@ -44,7 +44,7 @@ type UserSession struct { IPAddress string UserAgent string ExpiresAt time.Time - Cookie string `gorm:"index"` + Cookie string `gorm:"index"` Values sessionValues `gorm:"-"` } diff --git a/cmd/admin/static/js/functions.js b/cmd/admin/static/js/functions.js index e8e3f3a2..eb9c2de2 100644 --- a/cmd/admin/static/js/functions.js +++ b/cmd/admin/static/js/functions.js @@ -1,4 +1,4 @@ -function sendPostRequest(req_data, req_url, _redir, _modal) { +function sendPostRequest(req_data, req_url, _redir, _modal, _callback) { $.ajax({ url: req_url, dataType: 'json', @@ -6,7 +6,7 @@ function sendPostRequest(req_data, req_url, _redir, _modal) { contentType: 'application/json', data: JSON.stringify(req_data), processData: false, - success: function(data, textStatus, jQxhr){ + success: function (data, textStatus, jQxhr) { console.log('OK'); console.log(data); if (_modal) { @@ -16,8 +16,11 @@ function sendPostRequest(req_data, req_url, _redir, _modal) { if (_redir !== "") { window.location.replace(_redir); } + if (_callback) { + _callback(data); + } }, - error: function(jqXhr, textStatus, errorThrown){ + error: function (jqXhr, textStatus, errorThrown) { var _clientmsg = 'Client: ' + errorThrown; var _serverJSON = $.parseJSON(jqXhr.responseText); var _servermsg = 'Server: ' + _serverJSON.message; @@ -27,4 +30,4 @@ function sendPostRequest(req_data, req_url, _redir, _modal) { $("#errorModal").modal(); } }); -} \ No newline at end of file +} diff --git a/cmd/admin/static/js/login.js b/cmd/admin/static/js/login.js index 8dbdf09a..d04e6cad 100644 --- a/cmd/admin/static/js/login.js +++ b/cmd/admin/static/js/login.js @@ -1,7 +1,7 @@ function sendLogin() { var _user = $("#login_user").val(); var _password = $("#login_password").val(); - + var _url = '/login'; var data = { username: _user, @@ -12,7 +12,7 @@ function sendLogin() { function sendLogout() { var _csrf = $("#csrftoken").val(); - + var _url = '/logout'; var data = { csrftoken: _csrf diff --git a/cmd/admin/static/js/users.js b/cmd/admin/static/js/users.js index 585de855..e6c9f518 100644 --- a/cmd/admin/static/js/users.js +++ b/cmd/admin/static/js/users.js @@ -65,3 +65,29 @@ function deleteUser(_user) { }; sendPostRequest(data, _url, _url, false); } + +function showAPIToken(_token, _exp, _username) { + $("#user_api_token").val(_token); + $("#user_token_expiration").val(_exp); + $("#user_token_username").val(_username); + $("#apiTokenModal").modal(); +} + +function refreshUserToken() { + $("#refreshTokenButton").prop("disabled", true); + $("#refreshTokenButton").html(''); + var _csrftoken = $("#csrftoken").val(); + var _username = $("#user_token_username").val(); + + var data = { + csrftoken: _csrftoken, + username: _username, + }; + sendPostRequest(data, '/tokens/' + _username + '/refresh', '', false, function (data) { + console.log(data); + $("#user_api_token").val(data.token); + $("#user_token_expiration").val(data.exp_ts); + $("#refreshTokenButton").prop("disabled", false); + $("#refreshTokenButton").text('Refresh'); + }); +} diff --git a/cmd/admin/templates/carves-run.html b/cmd/admin/templates/carves-run.html index 3378b363..a58d1ff8 100644 --- a/cmd/admin/templates/carves-run.html +++ b/cmd/admin/templates/carves-run.html @@ -69,26 +69,6 @@ ex. ubuntu -
diff --git a/cmd/admin/templates/queries-run.html b/cmd/admin/templates/queries-run.html index 9b792de8..fc2f9ae0 100644 --- a/cmd/admin/templates/queries-run.html +++ b/cmd/admin/templates/queries-run.html @@ -115,26 +115,6 @@ ex. ubuntu
-
diff --git a/cmd/admin/templates/users.html b/cmd/admin/templates/users.html index 3b56eacc..0e9e868e 100644 --- a/cmd/admin/templates/users.html +++ b/cmd/admin/templates/users.html @@ -44,12 +44,13 @@ Username Email - Fullname + Fullname Last IP Last UserAgent Admin Last Session + @@ -67,6 +68,13 @@ {{ pastTimeAgo $e.LastAccess }} + + {{ if $e.Admin }} + + {{ end }} +
+ + + {{ template "page-modals" . }}
diff --git a/cmd/admin/types-requests.go b/cmd/admin/types-requests.go index c6ecea28..e08332e2 100644 --- a/cmd/admin/types-requests.go +++ b/cmd/admin/types-requests.go @@ -19,7 +19,6 @@ type DistributedQueryRequest struct { UUIDs []string `json:"uuid_list"` Hosts []string `json:"host_list"` Query string `json:"query"` - Repeat int `json:"repeat"` } // DistributedCarveRequest to receive carve requests @@ -30,7 +29,6 @@ type DistributedCarveRequest struct { UUIDs []string `json:"uuid_list"` Hosts []string `json:"host_list"` Path string `json:"path"` - Repeat int `json:"repeat"` } // DistributedQueryActionRequest to receive query requests @@ -117,3 +115,16 @@ type UsersRequest struct { type AdminResponse struct { Message string `json:"message"` } + +// TokenRequest to receive API token related requests +type TokenRequest struct { + CSRFToken string `json:"csrftoken"` + Username string `json:"username"` +} + +// TokenResponse to be returned to API token requests +type TokenResponse struct { + Token string `json:"token"` + Expiration string `json:"expiration"` + ExpirationTS string `json:"exp_ts"` +} diff --git a/cmd/admin/utils.go b/cmd/admin/utils.go index 1eb63245..05ab4995 100644 --- a/cmd/admin/utils.go +++ b/cmd/admin/utils.go @@ -13,12 +13,14 @@ import ( "fmt" "io/ioutil" "log" + "net/http" "os" "strconv" "strings" "time" jwt "github.com/dgrijalva/jwt-go" + "github.com/jmpsec/osctrl/pkg/queries" "github.com/jmpsec/osctrl/pkg/settings" "github.com/jmpsec/osctrl/pkg/types" ) @@ -55,8 +57,18 @@ func checkCSRFToken(ctxToken, receivedToken string) bool { return (strings.TrimSpace(ctxToken) == strings.TrimSpace(receivedToken)) } -// Helper to generate a random MD5 to be used as query name +// Helper to generate a random query name func generateQueryName() string { + return "query_" + randomForNames() +} + +// Helper to generate a random carve name +func generateCarveName() string { + return "carve_" + randomForNames() +} + +// Helper to generate a random MD5 to be used with queries/carves +func randomForNames() string { b := make([]byte, 32) _, _ = rand.Read(b) hasher := md5.New() @@ -64,6 +76,35 @@ func generateQueryName() string { return hex.EncodeToString(hasher.Sum(nil)) } +// Helper to determine if a query may be a carve +func newQueryReady(user, query string) queries.DistributedQuery { + if strings.Contains(query, "carve(") || strings.Contains(query, "carve=1") { + return queries.DistributedQuery{ + Query: query, + Name: generateCarveName(), + Creator: user, + Expected: 0, + Executions: 0, + Active: true, + Completed: false, + Deleted: false, + Type: queries.CarveQueryType, + Path: query, + } + } + return queries.DistributedQuery{ + Query: query, + Name: generateQueryName(), + Creator: user, + Expected: 0, + Executions: 0, + Active: true, + Completed: false, + Deleted: false, + Type: queries.StandardQueryType, + } +} + // Helper to generate the carve query func generateCarveQuery(file string, glob bool) string { if glob { @@ -72,6 +113,7 @@ func generateCarveQuery(file string, glob bool) string { return "SELECT * FROM carves WHERE carve=1 AND path = '" + file + "';" } +// Helper to verify if a platform is valid func checkValidPlatform(platform string) bool { platforms, err := nodesmgr.GetAllPlatforms() if err != nil { @@ -326,13 +368,34 @@ func parseJWTFromCookie(keypair tls.Certificate, cookie string) (JWTData, error) // Helper to prepare template metadata func templateMetadata(ctx contextValue, service, version string) TemplateMetadata { return TemplateMetadata{ - Username: ctx[ctxUser], - Level: ctx[ctxLevel], - CSRFToken: ctx[ctxCSRF], - Service: service, - Version: version, + Username: ctx[ctxUser], + Level: ctx[ctxLevel], + CSRFToken: ctx[ctxCSRF], + Service: service, + Version: version, TLSDebug: settingsmgr.DebugService(settings.ServiceTLS), AdminDebug: settingsmgr.DebugService(settings.ServiceAdmin), AdminDebugHTTP: settingsmgr.DebugHTTP(settings.ServiceAdmin), } } + +// Helper to send HTTP response +func apiHTTPResponse(w http.ResponseWriter, cType string, code int, data interface{}) { + if cType != "" { + w.Header().Set(contentType, cType) + } + content, err := json.Marshal(data) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Printf("error serializing response: %v", err) + content = []byte("error serializing response") + } + w.WriteHeader(code) + _, _ = w.Write(content) +} + +// Helper to handle admin error responses +func adminErrorResponse(w http.ResponseWriter, msg string, code int, err error) { + log.Printf("%s: %v", msg, err) + apiHTTPResponse(w, JSONApplicationUTF8, code, AdminResponse{Message: msg}) +} diff --git a/cmd/api/auth.go b/cmd/api/auth.go new file mode 100644 index 00000000..681d7d29 --- /dev/null +++ b/cmd/api/auth.go @@ -0,0 +1,82 @@ +package main + +import ( + "context" + "log" + "net/http" + "strings" + + "github.com/jmpsec/osctrl/pkg/settings" +) + +// contextValue to hold session data in the context +type contextValue map[string]string + +// contextKey to help with the context key, to pass session data +type contextKey string + +const ( + // Username to use when there is no authentication + usernameAPI string = "osctrl-api-user" + // Key to identify request context + contextAPI string = "osctrl-api-context" +) + +const ( + adminLevel string = "admin" + userLevel string = "user" +) + +// Helper to verify if user is an admin +func checkAdminLevel(level string) bool { + return (level == adminLevel) +} + +// Helper to extract token from header +func extractHeaderToken(r *http.Request) string { + reqToken := r.Header.Get("Authorization") + splitToken := strings.Split(reqToken, "Bearer") + if len(splitToken) != 2 { + return "" + } + return strings.TrimSpace(splitToken[1]) +} + +// Handler to check access to a resource based on the authentication enabled +func handlerAuthCheck(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch apiConfig.Auth { + case settings.AuthNone: + // Set middleware values + s := make(contextValue) + s["user"] = "admin" + ctx := context.WithValue(r.Context(), contextKey(contextAPI), s) + // Access granted + h.ServeHTTP(w, r.WithContext(ctx)) + case settings.AuthJWT: + // Set middleware values + //utils.DebugHTTPDump(r, true, true) + token := extractHeaderToken(r) + if token == "" { + http.Redirect(w, r, forbiddenPath, http.StatusForbidden) + return + } + claims, valid := apiUsers.CheckToken(jwtConfig.JWTSecret, token) + if !valid { + http.Redirect(w, r, forbiddenPath, http.StatusForbidden) + return + } + // Update metadata for the user + err := apiUsers.UpdateTokenIPAddress(r.Header.Get("X-Real-IP"), claims.Username) + if err != nil { + log.Printf("error updating token for user %s: %v", claims.Username, err) + } + // Set middleware values + s := make(contextValue) + s["user"] = claims.Username + ctx := context.WithValue(r.Context(), contextKey(contextAPI), s) + // Access granted + h.ServeHTTP(w, r.WithContext(ctx)) + } + }) +} diff --git a/cmd/api/db.go b/cmd/api/db.go new file mode 100644 index 00000000..9931a87f --- /dev/null +++ b/cmd/api/db.go @@ -0,0 +1,60 @@ +package main + +import ( + "fmt" + "log" + "time" + + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/postgres" + "github.com/jmpsec/osctrl/pkg/types" + "github.com/spf13/viper" +) + +// Function to load the DB configuration file and assign to variables +func loadDBConfiguration(file string) (types.JSONConfigurationDB, error) { + var config types.JSONConfigurationDB + log.Printf("Loading %s", file) + // Load file and read config + viper.SetConfigFile(file) + err := viper.ReadInConfig() + if err != nil { + return config, err + } + // Backend values + dbRaw := viper.Sub("db") + err = dbRaw.Unmarshal(&config) + if err != nil { + return config, err + } + // No errors! + return config, nil +} + +// Get PostgreSQL DB using GORM +func getDB(file string) *gorm.DB { + // Load DB configuration + dbConfig, err := loadDBConfiguration(file) + if err != nil { + log.Fatalf("Error loading DB configuration %v", err) + } + t := "host=%s port=%s dbname=%s user=%s password=%s sslmode=disable" + postgresDSN := fmt.Sprintf( + t, dbConfig.Host, dbConfig.Port, dbConfig.Name, dbConfig.Username, dbConfig.Password) + db, err := gorm.Open("postgres", postgresDSN) + if err != nil { + log.Fatalf("Failed to open database connection: %v", err) + } + // Performance settings for DB access + db.DB().SetMaxIdleConns(dbConfig.MaxIdleConns) + db.DB().SetMaxOpenConns(dbConfig.MaxOpenConns) + db.DB().SetConnMaxLifetime(time.Second * time.Duration(dbConfig.ConnMaxLifetime)) + + return db +} + +// Automigrate of tables +func automigrateDB() error { + var err error + return err +} diff --git a/cmd/api/handlers-environments.go b/cmd/api/handlers-environments.go new file mode 100644 index 00000000..91338c63 --- /dev/null +++ b/cmd/api/handlers-environments.go @@ -0,0 +1,69 @@ +package main + +import ( + "log" + "net/http" + + "github.com/gorilla/mux" + "github.com/jmpsec/osctrl/pkg/settings" + "github.com/jmpsec/osctrl/pkg/utils" +) + +// GET Handler to return one environment as JSON +func apiEnvironmentHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), false) + vars := mux.Vars(r) + // Extract name + name, ok := vars["name"] + if !ok { + incMetric(metricAPIErr) + apiErrorResponse(w, "error getting name", http.StatusInternalServerError, nil) + return + } + // Get context data and check access + ctx := r.Context().Value(contextKey(contextAPI)).(contextValue) + if !apiUsers.IsAdmin(ctx["user"]) { + log.Printf("attempt to use API by user %s", ctx["user"]) + apiErrorResponse(w, "no access", http.StatusForbidden, nil) + return + } + // Get environment by name + env, err := envs.Get(name) + if err != nil { + incMetric(metricAPIErr) + if err.Error() == "record not found" { + log.Printf("environment not found: %s", name) + apiErrorResponse(w, "environment not found", http.StatusNotFound, nil) + } else { + apiErrorResponse(w, "error getting environment", http.StatusInternalServerError, err) + } + return + } + // Header to serve JSON + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, env) + incMetric(metricAPIOK) +} + +// GET Handler to return all environments as JSON +func apiEnvironmentsHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), false) + // Get context data and check access + ctx := r.Context().Value(contextKey(contextAPI)).(contextValue) + if !apiUsers.IsAdmin(ctx["user"]) { + log.Printf("attempt to use API by user %s", ctx["user"]) + apiErrorResponse(w, "no access", http.StatusForbidden, nil) + return + } + // Get platforms + envAll, err := envs.All() + if err != nil { + incMetric(metricAPIErr) + apiErrorResponse(w, "error getting environments", http.StatusInternalServerError, err) + return + } + // Header to serve JSON + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, envAll) + incMetric(metricAPIOK) +} diff --git a/cmd/api/handlers-nodes.go b/cmd/api/handlers-nodes.go new file mode 100644 index 00000000..16e8245d --- /dev/null +++ b/cmd/api/handlers-nodes.go @@ -0,0 +1,75 @@ +package main + +import ( + "log" + "net/http" + + "github.com/gorilla/mux" + "github.com/jmpsec/osctrl/pkg/settings" + "github.com/jmpsec/osctrl/pkg/utils" +) + +// GET Handler for single JSON nodes +func apiNodeHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), false) + vars := mux.Vars(r) + // Extract uuid + uuid, ok := vars["uuid"] + if !ok { + incMetric(metricAPIErr) + apiErrorResponse(w, "error getting uuid", http.StatusInternalServerError, nil) + return + } + // Get context data and check access + ctx := r.Context().Value(contextKey(contextAPI)).(contextValue) + if !apiUsers.IsAdmin(ctx["user"]) { + log.Printf("attempt to use API by user %s", ctx["user"]) + apiErrorResponse(w, "no access", http.StatusForbidden, nil) + return + } + // Get node by UUID + node, err := nodesmgr.GetByUUID(uuid) + if err != nil { + incMetric(metricAPIErr) + if err.Error() == "record not found" { + log.Printf("node not found: %s", uuid) + apiErrorResponse(w, "node not found", http.StatusNotFound, nil) + } else { + apiErrorResponse(w, "error getting node", http.StatusInternalServerError, err) + } + return + } + // Serialize and serve JSON + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, node) + incMetric(metricAPIOK) +} + +// GET Handler for multiple JSON nodes +func apiNodesHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), false) + // Get context data and check access + ctx := r.Context().Value(contextKey(contextAPI)).(contextValue) + if !apiUsers.IsAdmin(ctx["user"]) { + log.Printf("attempt to use API by user %s", ctx["user"]) + apiErrorResponse(w, "no access", http.StatusForbidden, nil) + return + } + // Get nodes + nodes, err := nodesmgr.Gets("all", 0) + if err != nil { + incMetric(metricAPIErr) + apiErrorResponse(w, "error getting nodes", http.StatusInternalServerError, err) + return + } + if len(nodes) == 0 { + incMetric(metricAPIErr) + log.Printf("no nodes") + apiErrorResponse(w, "no nodes", http.StatusNotFound, nil) + return + } + // Serialize and serve JSON + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, nodes) + incMetric(metricAPIOK) +} diff --git a/cmd/api/handlers-platforms.go b/cmd/api/handlers-platforms.go new file mode 100644 index 00000000..cbc608bd --- /dev/null +++ b/cmd/api/handlers-platforms.go @@ -0,0 +1,32 @@ +package main + +import ( + "log" + "net/http" + + "github.com/jmpsec/osctrl/pkg/settings" + "github.com/jmpsec/osctrl/pkg/utils" +) + +// GET Handler for multiple JSON platforms +func apiPlatformsHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), false) + // Get context data and check access + ctx := r.Context().Value(contextKey(contextAPI)).(contextValue) + if !apiUsers.IsAdmin(ctx["user"]) { + log.Printf("attempt to use API by user %s", ctx["user"]) + apiErrorResponse(w, "no access", http.StatusForbidden, nil) + return + } + // Get platforms + platforms, err := nodesmgr.GetAllPlatforms() + if err != nil { + incMetric(metricAPIErr) + apiErrorResponse(w, "error getting platforms", http.StatusInternalServerError, err) + return + } + // Serialize and serve JSON + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, platforms) + incMetric(metricAPIOK) +} diff --git a/cmd/api/handlers-queries.go b/cmd/api/handlers-queries.go new file mode 100644 index 00000000..cb0e3eca --- /dev/null +++ b/cmd/api/handlers-queries.go @@ -0,0 +1,223 @@ +package main + +import ( + "encoding/json" + "log" + "net/http" + + "github.com/gorilla/mux" + "github.com/jmpsec/osctrl/pkg/queries" + "github.com/jmpsec/osctrl/pkg/settings" + "github.com/jmpsec/osctrl/pkg/utils" +) + +// GET Handler to return a single query in JSON +func apiQueryShowHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), false) + vars := mux.Vars(r) + // Extract name + name, ok := vars["name"] + if !ok { + incMetric(metricAPIErr) + apiErrorResponse(w, "error getting name", http.StatusInternalServerError, nil) + return + } + // Get context data and check access + ctx := r.Context().Value(contextKey(contextAPI)).(contextValue) + if !apiUsers.IsAdmin(ctx["user"]) { + log.Printf("attempt to use API by user %s", ctx["user"]) + apiErrorResponse(w, "no access", http.StatusForbidden, nil) + return + } + // Get query by name + query, err := queriesmgr.Get(name) + if err != nil { + incMetric(metricAPIErr) + if err.Error() == "record not found" { + log.Printf("query not found: %s", name) + apiErrorResponse(w, "query not found", http.StatusNotFound, nil) + } else { + apiErrorResponse(w, "error getting query", http.StatusInternalServerError, err) + } + return + } + // Serialize and serve JSON + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, query) + incMetric(metricAPIOK) +} + +// POST Handler to run a query +func apiQueriesRunHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), false) + // Get context data and check access + ctx := r.Context().Value(contextKey(contextAPI)).(contextValue) + if !apiUsers.IsAdmin(ctx["user"]) { + log.Printf("attempt to use API by user %s", ctx["user"]) + apiErrorResponse(w, "no access", http.StatusForbidden, nil) + return + } + var q DistributedQueryRequest + // Parse request JSON body + if err := json.NewDecoder(r.Body).Decode(&q); err != nil { + incMetric(metricAPIErr) + apiErrorResponse(w, "error parsing POST body", http.StatusInternalServerError, err) + return + } + // FIXME check validity of query + // Query can not be empty + if q.Query == "" { + apiErrorResponse(w, "query can not be empty", http.StatusInternalServerError, nil) + return + } + // Prepare and create new query + queryName := generateQueryName() + newQuery := queries.DistributedQuery{ + Query: q.Query, + Name: queryName, + Creator: ctx["user"], + Expected: 0, + Executions: 0, + Active: true, + Completed: false, + Deleted: false, + Type: queries.StandardQueryType, + } + if err := queriesmgr.Create(newQuery); err != nil { + apiErrorResponse(w, "error creating query", http.StatusInternalServerError, err) + return + } + // Temporary list of UUIDs to calculate Expected + var expected []string + // Create environment target + if len(q.Environments) > 0 { + for _, e := range q.Environments { + if (e != "") && envs.Exists(e) { + if err := queriesmgr.CreateTarget(queryName, queries.QueryTargetEnvironment, e); err != nil { + apiErrorResponse(w, "error creating query environment target", http.StatusInternalServerError, err) + return + } + nodes, err := nodesmgr.GetByEnv(e, "active", settingsmgr.InactiveHours()) + if err != nil { + apiErrorResponse(w, "error getting nodes by environment", http.StatusInternalServerError, err) + return + } + for _, n := range nodes { + expected = append(expected, n.UUID) + } + } + } + } + // Create platform target + if len(q.Platforms) > 0 { + for _, p := range q.Platforms { + if (p != "") && checkValidPlatform(p) { + if err := queriesmgr.CreateTarget(queryName, queries.QueryTargetPlatform, p); err != nil { + apiErrorResponse(w, "error creating query platform target", http.StatusInternalServerError, err) + return + } + nodes, err := nodesmgr.GetByPlatform(p, "active", settingsmgr.InactiveHours()) + if err != nil { + apiErrorResponse(w, "error getting nodes by platform", http.StatusInternalServerError, err) + return + } + for _, n := range nodes { + expected = append(expected, n.UUID) + } + } + } + } + // Create UUIDs target + if len(q.UUIDs) > 0 { + for _, u := range q.UUIDs { + if (u != "") && nodesmgr.CheckByUUID(u) { + if err := queriesmgr.CreateTarget(queryName, queries.QueryTargetUUID, u); err != nil { + apiErrorResponse(w, "error creating query UUID target", http.StatusInternalServerError, err) + return + } + expected = append(expected, u) + } + } + } + // Create hostnames target + if len(q.Hosts) > 0 { + for _, h := range q.Hosts { + if (h != "") && nodesmgr.CheckByHost(h) { + if err := queriesmgr.CreateTarget(queryName, queries.QueryTargetLocalname, h); err != nil { + apiErrorResponse(w, "error creating query hostname target", http.StatusInternalServerError, err) + return + } + expected = append(expected, h) + } + } + } + // Remove duplicates from expected + expectedClear := removeStringDuplicates(expected) + // Update value for expected + if err := queriesmgr.SetExpected(queryName, len(expectedClear)); err != nil { + apiErrorResponse(w, "error setting expected", http.StatusInternalServerError, err) + return + } + // Return query name as serialized response + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, ApiQueriesResponse{Name: newQuery.Name}) + incMetric(metricAPIOK) +} + +// GET Handler to return multiple queries in JSON +func apiQueriesShowHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), false) + // Get queries + queries, err := queriesmgr.GetQueries(queries.TargetAll) + if err != nil { + incMetric(metricAPIErr) + apiErrorResponse(w, "error getting queries", http.StatusInternalServerError, err) + return + } + if len(queries) == 0 { + incMetric(metricAPIErr) + log.Printf("no queries") + apiErrorResponse(w, "no queries", http.StatusNotFound, nil) + return + } + // Serialize and serve JSON + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, queries) + incMetric(metricAPIOK) +} + +// GET Handler to return a single query results in JSON +func apiQueryResultsHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), false) + vars := mux.Vars(r) + // Extract name + name, ok := vars["name"] + if !ok { + incMetric(metricAPIErr) + apiErrorResponse(w, "error getting name", http.StatusInternalServerError, nil) + return + } + // Get context data and check access + ctx := r.Context().Value(contextKey(contextAPI)).(contextValue) + if !apiUsers.IsAdmin(ctx["user"]) { + log.Printf("attempt to use API by user %s", ctx["user"]) + apiErrorResponse(w, "no access", http.StatusForbidden, nil) + return + } + // Get query by name + queryLogs, err := postgresQueryLogs(name) + if err != nil { + incMetric(metricAPIErr) + if err.Error() == "record not found" { + log.Printf("query not found: %s", name) + apiErrorResponse(w, "query not found", http.StatusNotFound, nil) + } else { + apiErrorResponse(w, "error getting results", http.StatusInternalServerError, err) + } + return + } + // Serialize and serve JSON + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, queryLogs) + incMetric(metricAPIOK) +} diff --git a/cmd/api/handlers.go b/cmd/api/handlers.go new file mode 100644 index 00000000..be59d432 --- /dev/null +++ b/cmd/api/handlers.go @@ -0,0 +1,64 @@ +package main + +import ( + "net/http" + + "github.com/jmpsec/osctrl/pkg/settings" + "github.com/jmpsec/osctrl/pkg/utils" +) + +const ( + metricAPIReq = "api-req" + metricAPIErr = "api-err" + metricAPIOK = "api-ok" + metricHealthReq = "health-req" + metricHealthOK = "health-ok" +) + +// JSONApplication for Content-Type headers +const JSONApplication string = "application/json" + +// ContentType for header key +const contentType string = "Content-Type" + +// JSONApplicationUTF8 for Content-Type headers, UTF charset +const JSONApplicationUTF8 string = JSONApplication + "; charset=UTF-8" + +var errorContent = []byte("❌") +var okContent = []byte("✅") + +// Handle health requests +func healthHTTPHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricHealthReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), true) + // Send response + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, okContent) + incMetric(metricHealthOK) +} + +// Handle root requests +func rootHTTPHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricHealthReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), true) + // Send response + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusOK, okContent) + incMetric(metricHealthOK) +} + +// Handle error requests +func errorHTTPHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAPI), true) + // Send response + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusInternalServerError, errorContent) + incMetric(metricAPIErr) +} + +// Handle forbidden error requests +func forbiddenHTTPHandler(w http.ResponseWriter, r *http.Request) { + incMetric(metricAPIReq) + utils.DebugHTTPDump(r, settingsmgr.DebugHTTP(settings.ServiceAdmin), true) + // Send response + apiHTTPResponse(w, JSONApplicationUTF8, http.StatusForbidden, errorContent) + incMetric(metricAPIErr) +} diff --git a/cmd/api/jwt.go b/cmd/api/jwt.go new file mode 100644 index 00000000..1600d721 --- /dev/null +++ b/cmd/api/jwt.go @@ -0,0 +1,30 @@ +package main + +import ( + "log" + + "github.com/jmpsec/osctrl/pkg/settings" + "github.com/jmpsec/osctrl/pkg/types" + "github.com/spf13/viper" +) + +// Function to load the configuration file +func loadJWTConfiguration(file string) (types.JSONConfigurationJWT, error) { + var cfg types.JSONConfigurationJWT + log.Printf("Loading %s", file) + // Load file and read config + viper.SetConfigFile(file) + err := viper.ReadInConfig() + if err != nil { + return cfg, err + } + // JWT values + headersRaw := viper.Sub(settings.AuthJWT) + err = headersRaw.Unmarshal(&cfg) + if err != nil { + return cfg, err + } + + // No errors! + return cfg, nil +} diff --git a/cmd/api/main.go b/cmd/api/main.go new file mode 100644 index 00000000..7ad375a3 --- /dev/null +++ b/cmd/api/main.go @@ -0,0 +1,287 @@ +package main + +import ( + "flag" + "fmt" + "log" + "net/http" + "time" + + "github.com/jmpsec/osctrl/pkg/carves" + "github.com/jmpsec/osctrl/pkg/environments" + "github.com/jmpsec/osctrl/pkg/metrics" + "github.com/jmpsec/osctrl/pkg/nodes" + "github.com/jmpsec/osctrl/pkg/queries" + "github.com/jmpsec/osctrl/pkg/settings" + "github.com/jmpsec/osctrl/pkg/types" + "github.com/jmpsec/osctrl/pkg/users" + + "github.com/gorilla/mux" + "github.com/jinzhu/gorm" + _ "github.com/jinzhu/gorm/dialects/postgres" + "github.com/spf13/viper" +) + +const ( + // Project name + projectName string = "osctrl" + // Service name + serviceName string = projectName + "-" + settings.ServiceAPI + // Service version + serviceVersion string = "0.1.9" + // Service description + serviceDescription string = "API service for osctrl" + // Application description + appDescription string = serviceDescription + ", a fast and efficient osquery management" + // Default service configuration file + configurationFile string = "config/" + settings.ServiceAPI + ".json" + // Default DB configuration file + dbConfigurationFile string = "config/db.json" + // Default JWT configuration file + jwtConfigurationFile string = "config/jwt.json" + // Default refreshing interval in seconds + defaultRefresh int = 300 +) + +// Paths +const ( + // HTTP health path + healthPath string = "/health" + // HTTP errors path + errorPath string = "/error" + forbiddenPath string = "/forbidden" + // API prefix path + apiPrefixPath string = "/api" + // API version path + apiVersionPath string = "/v1" + // API nodes path + apiNodesPath string = "/nodes" + // API queries path + apiQueriesPath string = "/queries" + // API carves path + apiCarvesPath string = "/carves" + // API platforms path + apiPlatformsPath string = "/platforms" + // API environments path + apiEnvironmentsPath string = "/environments" +) + +// Global variables +var ( + apiConfig types.JSONConfigurationService + jwtConfig types.JSONConfigurationJWT + db *gorm.DB + apiUsers *users.UserManager + settingsmgr *settings.Settings + envs *environments.Environment + envsmap environments.MapEnvironments + envsTicker *time.Ticker + settingsmap settings.MapSettings + settingsTicker *time.Ticker + nodesmgr *nodes.NodeManager + queriesmgr *queries.Queries + filecarves *carves.Carves + _metrics *metrics.Metrics +) + +// Variables for flags +var ( + versionFlag *bool + configFlag *string + dbFlag *string + jwtFlag *string +) + +// Valid values for auth and logging in configuration +var validAuth = map[string]bool{ + settings.AuthNone: true, + settings.AuthJWT: true, +} +var validLogging = map[string]bool{ + settings.LoggingNone: true, +} + +// Function to load the configuration file and assign to variables +func loadConfiguration(file string) (types.JSONConfigurationService, error) { + var cfg types.JSONConfigurationService + log.Printf("Loading %s", file) + // Load file and read config + viper.SetConfigFile(file) + err := viper.ReadInConfig() + if err != nil { + return cfg, err + } + // TLS endpoint values + tlsRaw := viper.Sub(settings.ServiceAPI) + err = tlsRaw.Unmarshal(&cfg) + if err != nil { + return cfg, err + } + // Check if values are valid + if !validAuth[cfg.Auth] { + return cfg, fmt.Errorf("Invalid auth method") + } + if !validLogging[cfg.Logging] { + return cfg, fmt.Errorf("Invalid logging method") + } + // No errors! + return cfg, nil +} + +// Initialization code +func init() { + var err error + // Command line flags + flag.Usage = apiUsage + // Define flags + versionFlag = flag.Bool("v", false, "Displays the binary version.") + configFlag = flag.String("c", configurationFile, "Service configuration JSON file to use.") + dbFlag = flag.String("D", dbConfigurationFile, "DB configuration JSON file to use.") + jwtFlag = flag.String("J", jwtConfigurationFile, "JWT configuration JSON file to use.") + // Parse all flags + flag.Parse() + if *versionFlag { + apiVersion() + } + // Logging format flags + log.SetFlags(log.Lshortfile) + // Load API configuration + apiConfig, err = loadConfiguration(*configFlag) + if err != nil { + log.Fatalf("Error loading %s - %s", *configFlag, err) + } + // Load JWT configuration + // Load configuration for JWT if enabled + if apiConfig.Auth == settings.AuthJWT { + jwtConfig, err = loadJWTConfiguration(*jwtFlag) + if err != nil { + log.Fatalf("Error loading %s - %s", *jwtFlag, err) + } + return + } +} + +// Go go! +func main() { + log.Println("Loading DB") + // Database handler + db = getDB(*dbFlag) + // Close when exit + //defer db.Close() + defer func() { + err := db.Close() + if err != nil { + log.Fatalf("Failed to close Database handler - %v", err) + } + }() + // Initialize users + apiUsers = users.CreateUserManager(db) + // Initialize environment + envs = environments.CreateEnvironment(db) + // Initialize settings + settingsmgr = settings.NewSettings(db) + // Initialize nodes + nodesmgr = nodes.CreateNodes(db) + // Initialize queries + queriesmgr = queries.CreateQueries(db) + // Initialize carves + filecarves = carves.CreateFileCarves(db) + // Initialize service settings + log.Println("Loading service settings") + loadingSettings() + + // multiple listeners channel + finish := make(chan bool) + + /////////////////////////// API + if settingsmgr.DebugService(settings.ServiceAPI) { + log.Println("DebugService: Creating router") + } + // Create router for API endpoint + routerAPI := mux.NewRouter() + // API: root + routerAPI.HandleFunc("/", rootHTTPHandler) + // API: testing + routerAPI.HandleFunc(healthPath, healthHTTPHandler).Methods("GET") + // API: error + routerAPI.HandleFunc(errorPath, errorHTTPHandler).Methods("GET") + // API: forbidden + routerAPI.HandleFunc(forbiddenPath, forbiddenHTTPHandler).Methods("GET") + + /////////////////////////// AUTHENTICATED + // API: nodes + routerAPI.Handle(_apiPath(apiNodesPath)+"/{uuid}", handlerAuthCheck(http.HandlerFunc(apiNodeHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiNodesPath)+"/{uuid}/", handlerAuthCheck(http.HandlerFunc(apiNodeHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiNodesPath), handlerAuthCheck(http.HandlerFunc(apiNodesHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiNodesPath)+"/", handlerAuthCheck(http.HandlerFunc(apiNodesHandler))).Methods("GET") + // API: queries + routerAPI.Handle(_apiPath(apiQueriesPath)+"/{name}", handlerAuthCheck(http.HandlerFunc(apiQueryShowHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiQueriesPath)+"/{name}/", handlerAuthCheck(http.HandlerFunc(apiQueryShowHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiQueriesPath), handlerAuthCheck(http.HandlerFunc(apiQueriesRunHandler))).Methods("POST") + routerAPI.Handle(_apiPath(apiQueriesPath)+"/", handlerAuthCheck(http.HandlerFunc(apiQueriesRunHandler))).Methods("POST") + routerAPI.Handle(_apiPath(apiQueriesPath), handlerAuthCheck(http.HandlerFunc(apiQueriesShowHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiQueriesPath)+"/", handlerAuthCheck(http.HandlerFunc(apiQueriesShowHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiQueriesPath)+"/results/{name}", handlerAuthCheck(http.HandlerFunc(apiQueryResultsHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiQueriesPath)+"/results/{name}/", handlerAuthCheck(http.HandlerFunc(apiQueryResultsHandler))).Methods("GET") + // API: platforms + routerAPI.Handle(_apiPath(apiPlatformsPath), handlerAuthCheck(http.HandlerFunc(apiPlatformsHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiPlatformsPath)+"/", handlerAuthCheck(http.HandlerFunc(apiPlatformsHandler))).Methods("GET") + // API: environments + routerAPI.Handle(_apiPath(apiEnvironmentsPath)+"/{name}", handlerAuthCheck(http.HandlerFunc(apiEnvironmentHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiEnvironmentsPath)+"/{name}/", handlerAuthCheck(http.HandlerFunc(apiEnvironmentHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiEnvironmentsPath), handlerAuthCheck(http.HandlerFunc(apiEnvironmentsHandler))).Methods("GET") + routerAPI.Handle(_apiPath(apiEnvironmentsPath)+"/", handlerAuthCheck(http.HandlerFunc(apiEnvironmentsHandler))).Methods("GET") + + // Ticker to reload environments + // FIXME Implement Redis cache + // FIXME splay this? + if settingsmgr.DebugService(settings.ServiceAPI) { + log.Println("DebugService: Environments ticker") + } + // Refresh environments as soon as service starts + go refreshEnvironments() + go func() { + _t := settingsmgr.RefreshEnvs(settings.ServiceAPI) + if _t == 0 { + _t = int64(defaultRefresh) + } + envsTicker = time.NewTicker(time.Duration(_t) * time.Second) + for { + select { + case <-envsTicker.C: + go refreshEnvironments() + } + } + }() + + // Ticker to reload settings + // FIXME Implement Redis cache + // FIXME splay this? + if settingsmgr.DebugService(settings.ServiceAPI) { + log.Println("DebugService: Settings ticker") + } + // Refresh settings as soon as the service starts + go refreshSettings() + go func() { + _t := settingsmgr.RefreshSettings(settings.ServiceAPI) + if _t == 0 { + _t = int64(defaultRefresh) + } + settingsTicker = time.NewTicker(time.Duration(_t) * time.Second) + for { + select { + case <-settingsTicker.C: + go refreshSettings() + } + } + }() + + // Launch HTTP server for TLS endpoint + go func() { + serviceListener := apiConfig.Listener + ":" + apiConfig.Port + log.Printf("%s v%s - HTTP listening %s", serviceName, serviceVersion, serviceListener) + log.Fatal(http.ListenAndServe(serviceListener, routerAPI)) + }() + + <-finish +} diff --git a/cmd/api/postgres.go b/cmd/api/postgres.go new file mode 100644 index 00000000..e7451bbd --- /dev/null +++ b/cmd/api/postgres.go @@ -0,0 +1,33 @@ +package main + +import ( + "encoding/json" + + "github.com/jinzhu/gorm" +) + +// OsqueryQueryData to log query data to database +type OsqueryQueryData struct { + gorm.Model + UUID string `gorm:"index"` + Environment string + Name string + Data json.RawMessage + Status int +} + +// APIQueryData to return query results from API +type APIQueryData map[string]json.RawMessage + +// Function to retrieve the query log by name +func postgresQueryLogs(name string) (APIQueryData, error) { + var logs []OsqueryQueryData + data := make(APIQueryData) + if err := db.Where("name = ?", name).Find(&logs).Error; err != nil { + return data, err + } + for _, l := range logs { + data[l.UUID] = l.Data + } + return data, nil +} diff --git a/cmd/api/settings.go b/cmd/api/settings.go new file mode 100644 index 00000000..a9d8a471 --- /dev/null +++ b/cmd/api/settings.go @@ -0,0 +1,56 @@ +package main + +import ( + "log" + + "github.com/jmpsec/osctrl/pkg/metrics" + "github.com/jmpsec/osctrl/pkg/settings" +) + +// Function to load all settings for the service +func loadingSettings() { + // Check if service settings for debug service is ready + if !settingsmgr.IsValue(settings.ServiceAPI, settings.DebugService) { + if err := settingsmgr.NewBooleanValue(settings.ServiceAPI, settings.DebugService, false); err != nil { + log.Fatalf("Failed to add %s to configuration: %v", settings.DebugService, err) + } + } + // Check if service settings for metrics is ready, initialize if so + if !settingsmgr.IsValue(settings.ServiceAPI, settings.ServiceMetrics) { + if err := settingsmgr.NewBooleanValue(settings.ServiceAPI, settings.ServiceMetrics, false); err != nil { + log.Printf("Failed to add %s to configuration: %v", settings.ServiceMetrics, err) + } + } else if settingsmgr.ServiceMetrics(settings.ServiceAPI) { + _mCfg, err := metrics.LoadConfiguration() + if err != nil { + if err := settingsmgr.SetBoolean(false, settings.ServiceAPI, settings.ServiceMetrics); err != nil { + log.Fatalf("Failed to disable metrics: %v", err) + } + log.Printf("Failed to initialize metrics: %v", err) + } else { + _metrics, err = metrics.CreateMetrics(_mCfg.Protocol, _mCfg.Host, _mCfg.Port, serviceName) + if err != nil { + log.Fatalf("Failed to initialize metrics: %v", err) + if err := settingsmgr.SetBoolean(false, settings.ServiceAPI, settings.ServiceMetrics); err != nil { + log.Fatalf("Failed to disable metrics: %v", err) + } + } + } + } + // Check if service settings for environments refresh is ready + if !settingsmgr.IsValue(settings.ServiceAPI, settings.RefreshEnvs) { + if err := settingsmgr.NewIntegerValue(settings.ServiceAPI, settings.RefreshEnvs, int64(defaultRefresh)); err != nil { + log.Fatalf("Failed to add %s to configuration: %v", settings.RefreshEnvs, err) + } + } + // Check if service settings for settings refresh is ready + if !settingsmgr.IsValue(settings.ServiceAPI, settings.RefreshSettings) { + if err := settingsmgr.NewIntegerValue(settings.ServiceAPI, settings.RefreshSettings, int64(defaultRefresh)); err != nil { + log.Fatalf("Failed to add %s to configuration: %v", settings.RefreshSettings, err) + } + } + // Write JSON config to settings + if err := settingsmgr.SetAllJSON(settings.ServiceAPI, apiConfig.Listener, apiConfig.Port, apiConfig.Host, apiConfig.Auth, apiConfig.Logging); err != nil { + log.Fatalf("Failed to add JSON values to configuration: %v", err) + } +} diff --git a/cmd/api/types-requests.go b/cmd/api/types-requests.go new file mode 100644 index 00000000..928acc96 --- /dev/null +++ b/cmd/api/types-requests.go @@ -0,0 +1,20 @@ +package main + +// DistributedQueryRequest to receive query requests +type DistributedQueryRequest struct { + Environments []string `json:"environment_list"` + Platforms []string `json:"platform_list"` + UUIDs []string `json:"uuid_list"` + Hosts []string `json:"host_list"` + Query string `json:"query"` +} + +// ApiErrorResponse to be returned to API requests with the error message +type ApiErrorResponse struct { + Error string `json:"error"` +} + +// ApiQueriesResponse to be returned to API requests for queries +type ApiQueriesResponse struct { + Name string `json:"query_name"` +} diff --git a/cmd/api/utils.go b/cmd/api/utils.go new file mode 100644 index 00000000..6ea755ad --- /dev/null +++ b/cmd/api/utils.go @@ -0,0 +1,133 @@ +package main + +import ( + "crypto/md5" + "crypto/rand" + "encoding/hex" + "encoding/json" + "flag" + "fmt" + "log" + "net/http" + "os" + + "github.com/jmpsec/osctrl/pkg/settings" +) + +// Helper to send metrics if it is enabled +func incMetric(name string) { + if _metrics != nil && settingsmgr.ServiceMetrics(settings.ServiceAPI) { + _metrics.Inc(name) + } +} + +// Helper to refresh the environments map until cache/Redis support is implemented +func refreshEnvironments() { + log.Printf("Refreshing environments...\n") + var err error + envsmap, err = envs.GetMap() + if err != nil { + log.Printf("error refreshing environments %v\n", err) + } +} + +// Helper to refresh the settings until cache/Redis support is implemented +func refreshSettings() { + log.Printf("Refreshing settings...\n") + var err error + settingsmap, err = settingsmgr.GetMap(settings.ServiceAPI) + if err != nil { + log.Printf("error refreshing settings %v\n", err) + } +} + +// Usage for service binary +func apiUsage() { + fmt.Printf("NAME:\n %s - %s\n\n", serviceName, serviceDescription) + fmt.Printf("USAGE: %s [global options] [arguments...]\n\n", serviceName) + fmt.Printf("VERSION:\n %s\n\n", serviceVersion) + fmt.Printf("DESCRIPTION:\n %s\n\n", appDescription) + fmt.Printf("GLOBAL OPTIONS:\n") + flag.PrintDefaults() + fmt.Printf("\n") +} + +// Display binary version +func apiVersion() { + fmt.Printf("%s v%s\n", serviceName, serviceVersion) + os.Exit(0) +} + +// Helper to compose paths for API +func _apiPath(target string) string { + return apiPrefixPath + apiVersionPath + target +} + +// Helper to verify if a platform is valid +func checkValidPlatform(platform string) bool { + platforms, err := nodesmgr.GetAllPlatforms() + if err != nil { + return false + } + for _, p := range platforms { + if p == platform { + return true + } + } + return false +} + +// Helper to remove duplicates from []string +func removeStringDuplicates(s []string) []string { + seen := make(map[string]struct{}, len(s)) + i := 0 + for _, v := range s { + if _, ok := seen[v]; ok { + continue + } + seen[v] = struct{}{} + s[i] = v + i++ + } + return s[:i] +} + +// Helper to send HTTP response +func apiHTTPResponse(w http.ResponseWriter, cType string, code int, data interface{}) { + if cType != "" { + w.Header().Set(contentType, cType) + } + content, err := json.Marshal(data) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + log.Printf("error serializing response: %v", err) + content = []byte("error serializing response") + } + w.WriteHeader(code) + _, _ = w.Write(content) +} + +// Helper to handle API error responses +func apiErrorResponse(w http.ResponseWriter, msg string, code int, err error) { + log.Printf("%s: %v", msg, err) + apiHTTPResponse(w, JSONApplicationUTF8, code, ApiErrorResponse{Error: msg}) +} + +// Helper to generate a random query name +func generateQueryName() string { + return "query_" + randomForNames() +} + +// Helper to generate a random carve name +func generateCarveName() string { + return "carve_" + randomForNames() +} + +// Helper to generate a random MD5 to be used with queries/carves +func randomForNames() string { + b := make([]byte, 32) + _, _ = rand.Read(b) + hasher := md5.New() + _, _ = hasher.Write([]byte(fmt.Sprintf("%x", b))) + return hex.EncodeToString(hasher.Sum(nil)) +} diff --git a/cmd/tls/handlers-tls.go b/cmd/tls/handlers-tls.go index 2c206e64..e848d357 100644 --- a/cmd/tls/handlers-tls.go +++ b/cmd/tls/handlers-tls.go @@ -12,6 +12,7 @@ import ( "github.com/jmpsec/osctrl/pkg/environments" "github.com/jmpsec/osctrl/pkg/nodes" "github.com/jmpsec/osctrl/pkg/queries" + "github.com/jmpsec/osctrl/pkg/settings" "github.com/jmpsec/osctrl/pkg/types" "github.com/jmpsec/osctrl/pkg/utils" ) @@ -38,8 +39,8 @@ const ( metricBlockReq = "block-req" metricBlockErr = "block-err" metricBlockOK = "block-ok" - metricHealthReq = "health-req" - metricHealthOK = "health-ok" + metricHealthReq = "health-req" + metricHealthOK = "health-ok" ) // JSONApplication for Content-Type headers @@ -431,7 +432,7 @@ func queryReadHandler(w http.ResponseWriter, r *http.Request) { log.Printf("error parsing POST body %v", err) return } - var nodeInvalid bool + var nodeInvalid, accelerate bool qs := make(queries.QueryReadQueries) // Lookup node by node_key node, err := nodesmgr.GetByKey(t.NodeKey) @@ -442,7 +443,7 @@ func queryReadHandler(w http.ResponseWriter, r *http.Request) { log.Printf("error updating IP Address %v", err) } nodeInvalid = false - qs, err = queriesmgr.NodeQueries(node) + qs, accelerate, err = queriesmgr.NodeQueries(node) if err != nil { incMetric(metricReadErr) log.Printf("error getting queries from db %v", err) @@ -455,9 +456,15 @@ func queryReadHandler(w http.ResponseWriter, r *http.Request) { } } else { nodeInvalid = true + accelerate = false + } + // Prepare response and serialize queries + if accelerate { + sAccelerate := int(settingsmap[settings.AcceleratedSeconds].Integer) + response, err = json.Marshal(types.AcceleratedQueryReadResponse{Queries: qs, Accelerated: sAccelerate, NodeInvalid: nodeInvalid}) + } else { + response, err = json.Marshal(types.QueryReadResponse{Queries: qs, NodeInvalid: nodeInvalid}) } - // Prepare response for invalid key - response, err = json.Marshal(types.QueryReadResponse{Queries: qs, NodeInvalid: nodeInvalid}) if err != nil { incMetric(metricReadErr) log.Printf("error formating response %v", err) diff --git a/cmd/tls/main.go b/cmd/tls/main.go index 7aa3b52a..d3174db4 100644 --- a/cmd/tls/main.go +++ b/cmd/tls/main.go @@ -42,6 +42,8 @@ const ( dbConfigurationFile string = "config/db.json" // Default refreshing interval in seconds defaultRefresh int = 300 + // Default accelerate interval in seconds + defaultAccelerate int = 300 ) // Global variables diff --git a/cmd/tls/settings.go b/cmd/tls/settings.go index 6e3cfae9..145a08c0 100644 --- a/cmd/tls/settings.go +++ b/cmd/tls/settings.go @@ -15,6 +15,12 @@ func loadingSettings() { log.Fatalf("Failed to add %s to configuration: %v", settings.DebugService, err) } } + // Check if service settings for accelerated seconds is ready + if !settingsmgr.IsValue(settings.ServiceTLS, settings.AcceleratedSeconds) { + if err := settingsmgr.NewIntegerValue(settings.ServiceTLS, settings.AcceleratedSeconds, int64(defaultAccelerate)); err != nil { + log.Fatalf("Failed to add %s to configuration: %v", settings.AcceleratedSeconds, err) + } + } // Check if service settings for metrics is ready, initialize if so if !settingsmgr.IsValue(settings.ServiceTLS, settings.ServiceMetrics) { if err := settingsmgr.NewBooleanValue(settings.ServiceTLS, settings.ServiceMetrics, false); err != nil { diff --git a/deploy/jwt.json b/deploy/jwt.json new file mode 100644 index 00000000..b4abcf2b --- /dev/null +++ b/deploy/jwt.json @@ -0,0 +1,6 @@ +{ + "jwt": { + "jwtSecret": "_JWT_SECRET", + "hoursToExpire": 3 + } +} diff --git a/deploy/provision.sh b/deploy/provision.sh index 271e9e14..b1f4e018 100755 --- a/deploy/provision.sh +++ b/deploy/provision.sh @@ -28,10 +28,14 @@ # Optional Parameters: # --public-tls-port PORT Port for the TLS endpoint service. Default is 443 # --public-admin-port PORT Port for the admin service. Default is 8443 +# --public-api-port PORT Port for the API service. Default is 8444 # --private-tls-port PORT Port for the TLS endpoint service. Default is 9000 # --private-admin-port PORT Port for the admin service. Default is 9001 +# --private-api-port PORT Port for the API service. Default is 9002 +# --all-hostname HOSTNAME Hostname for all the services. Default is 127.0.0.1 # --tls-hostname HOSTNAME Hostname for the TLS endpoint service. Default is 127.0.0.1 # --admin-hostname HOSTNAME Hostname for the admin service. Default is 127.0.0.1 +# --api-hostname HOSTNAME Hostname for the API service. Default is 127.0.0.1 # -X PASS --password Force the admin password for the admin interface. Default is random # -U --update Pull from master and sync files to the current folder # -k PATH --keyfile PATH Path to supplied TLS key file @@ -46,7 +50,7 @@ # -E --enroll Enroll the serve into itself using osquery. Default is disabled # # Examples: -# Provision service in development mode, code is in /vagrant and both admin and tls: +# Provision service in development mode, code is in /vagrant and all components (admin, tls, api): # provision.sh -m dev -s /vagrant -p all # Provision service in production mode using my own certificate and only with TLS endpoint: # provision.sh -m prod -t own -k /etc/certs/my.key -c /etc/certs/cert.crt -p tls @@ -95,10 +99,14 @@ function usage() { printf "\nOptional Parameters:\n" printf " --public-tls-port PORT \tPort for the TLS endpoint service. Default is 443\n" printf " --public-admin-port PORT \tPort for the admin service. Default is 8443\n" + printf " --public-api-port PORT \tPort for the API service. Default is 8444\n" printf " --private-tls-port PORT \tPort for the TLS endpoint service. Default is 9000\n" printf " --private-admin-port PORT \tPort for the admin service. Default is 9001\n" + printf " --private-api-port PORT \tPort for the API service. Default is 9002\n" + printf " --all-hostname HOSTNAME \tHostname for all the services. Default is 127.0.0.1\n" printf " --tls-hostname HOSTNAME \tHostname for the TLS endpoint service. Default is 127.0.0.1\n" printf " --admin-hostname HOSTNAME \tHostname for the admin service. Default is 127.0.0.1\n" + printf " --api-hostname HOSTNAME \tHostname for the API service. Default is 127.0.0.1\n" printf " -X PASS --password \tForce the admin password for the admin interface. Default is random\n" printf " -U --update \t\tPull from master and sync files to the current folder\n" printf " -c PATH --certfile PATH \tPath to supplied TLS server PEM certificate(s) bundle\n" @@ -111,7 +119,7 @@ function usage() { printf " -M --metrics \tInstall and configure all services for metrics (InfluxDB + Telegraf + Grafana)\n" printf " -E --enroll \tEnroll the serve into itself using osquery. Default is disabled\n" printf "\nExamples:\n" - printf " Provision service in development mode, code is in /vagrant and both admin and tls:\n" + printf " Provision service in development mode, code is in /vagrant and all components (admin, tls, api):\n" printf "\t%s -m dev -s /vagrant -p all\n" "${0}" printf " Provision service in production mode using my own certificate and only with TLS endpoint:\n" printf "\t%s -m prod -t own -k /etc/certs/my.key -c /etc/certs/cert.crt -p tls\n" "${0}" @@ -126,11 +134,15 @@ set -e # Values not intended to change TLS_COMPONENT="tls" ADMIN_COMPONENT="admin" +API_COMPONENT="api" TLS_CONF="$TLS_COMPONENT.json" ADMIN_CONF="$ADMIN_COMPONENT.json" +API_CONF="$API_COMPONENT.json" DB_CONF="db.json" +JWT_CONF="jwt.json" SERVICE_TEMPLATE="service.json" DB_TEMPLATE="db.json" +JWT_TEMPLATE="jwt.json" SYSTEMD_TEMPLATE="systemd.service" # Default values for arguments @@ -149,6 +161,7 @@ NGINX=false POSTGRES=false SOURCE_PATH=/vagrant DEST_PATH=/opt/osctrl +ALL_HOST="127.0.0.1" # Backend values _DB_HOST="localhost" @@ -161,28 +174,38 @@ _DB_PORT="5432" # TLS Service _T_INT_PORT="9000" _T_PUB_PORT="443" -_T_HOST="127.0.0.1" +_T_HOST="$ALL_HOST" _T_AUTH="none" _T_LOGGING="db" # Admin Service _A_INT_PORT="9001" _A_PUB_PORT="8443" -_A_HOST="127.0.0.1" +_A_HOST="$ALL_HOST" _A_AUTH="db" _A_LOGGING="db" +# API Service +_P_INT_PORT="9002" +_P_PUB_PORT="8444" +_P_HOST="$ALL_HOST" +_P_AUTH="jwt" +_P_LOGGING="none" + # Default admin credentials with random password _ADMIN_USER="admin" _ADMIN_PASS=$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 64 | head -n 1 | md5sum | cut -d " " -f1) +# Secret for API JWT +_JWT_SECRET="$(cat /dev/urandom | tr -dc 'a-zA-Z0-9' | fold -w 64 | head -n 1 | sha256sum | cut -d " " -f1)" + # Arrays with valid arguments VALID_MODE=("dev" "prod" "update") VALID_TYPE=("self" "own" "certbot") -VALID_PART=("$TLS_COMPONENT" "$ADMIN_COMPONENT" "all") +VALID_PART=("$TLS_COMPONENT" "$ADMIN_COMPONENT" "$API_COMPONENT" "all") # Extract arguments -ARGS=$(getopt -n "$0" -o hm:t:p:UPk:nMEc:d:e:s:S:X: -l "help,mode:,type:,part:,public-tls-port:,private-tls-port:,public-admin-port:,private-admin-port:,tls-hostname:,admin-hostname:,update,keyfile:,nginx,postgres,metrics,enroll,certfile:,domain:,email:,source:,dest:,password:" -- "$@") +ARGS=$(getopt -n "$0" -o hm:t:p:UPk:nMEc:d:e:s:S:X: -l "help,mode:,type:,part:,public-tls-port:,private-tls-port:,public-admin-port:,private-admin-port:,public-api-port:,private-api-port:,all-hostname:,tls-hostname:,admin-hostname:,api-hostname:,update,keyfile:,nginx,postgres,metrics,enroll,certfile:,domain:,email:,source:,dest:,password:" -- "$@") if [ $? != 0 ] ; then echo "Failed parsing options." >&2 ; exit 1 ; fi @@ -250,6 +273,16 @@ while true; do _A_INT_PORT=$2 shift 2 ;; + --public-api-port) + SHOW_USAGE=false + _P_PUB_PORT=$2 + shift 2 + ;; + --private-api-port) + SHOW_USAGE=false + _P_INT_PORT=$2 + shift 2 + ;; --tls-hostname) SHOW_USAGE=false _T_HOST=$2 @@ -260,6 +293,19 @@ while true; do _A_HOST=$2 shift 2 ;; + --api-hostname) + SHOW_USAGE=false + _P_HOST=$2 + shift 2 + ;; + --all-hostname) + SHOW_USAGE=false + ALL_HOST=$2 + _T_HOST=$ALL_HOST + _A_HOST=$ALL_HOST + _P_HOST=$ALL_HOST + shift 2 + ;; -U|--update) SHOW_USAGE=false UPDATE=true @@ -357,16 +403,25 @@ log "Provisioning [ osctrl ][ $PART ] for $DISTRO" log "" log " -> [ $MODE ] mode and with [ $TYPE ] certificate" log "" + if [[ "$PART" == "all" ]] || [[ "$PART" == "$TLS_COMPONENT" ]]; then log " -> Deploying TLS service for ports $_T_PUB_PORT:$_T_INT_PORT" log " -> Hostname for TLS endpoint: $_T_HOST" fi log "" + if [[ "$PART" == "all" ]] || [[ "$PART" == "$ADMIN_COMPONENT" ]]; then log " -> Deploying Admin service for ports $_A_PUB_PORT:$_A_INT_PORT" log " -> Hostname for admin: $_A_HOST" fi log "" + +if [[ "$PART" == "all" ]] || [[ "$PART" == "$API_COMPONENT" ]]; then + log " -> Deploying API service for ports $_P_PUB_PORT:$_P_INT_PORT" + log " -> Hostname for API: $_P_HOST" +fi +log "" + log "" # Update distro @@ -438,6 +493,9 @@ if [[ "$NGINX" == true ]]; then # Configuration for Admin service nginx_service "$SOURCE_PATH/deploy/nginx/ssl.conf" "$_cert_file" "$_key_file" "$_dh_file" "$_A_PUB_PORT" "$_A_INT_PORT" "admin.conf" "$NGINX_PATH" + # Configuration for API service + nginx_service "$SOURCE_PATH/deploy/nginx/ssl.conf" "$_cert_file" "$_key_file" "$_dh_file" "$_P_PUB_PORT" "$_P_INT_PORT" "api.conf" "$NGINX_PATH" + # Restart nginx sudo nginx -t sudo service nginx restart @@ -487,6 +545,9 @@ sudo mkdir -p "$DEST_PATH/config" # Generate DB configuration file for services configuration_db "$SOURCE_PATH/deploy/$DB_TEMPLATE" "$DEST_PATH/config/$DB_CONF" "$_DB_HOST" "$_DB_PORT" "$_DB_NAME" "$_DB_USER" "$_DB_PASS" "sudo" +# JWT configuration +cat "$SOURCE_PATH/deploy/$JWT_TEMPLATE" | sed "s|_JWT_SECRET|$_JWT_SECRET|g" | sudo tee "$DEST_PATH/config/$JWT_CONF" + # Build code cd "$SOURCE_PATH" make clean @@ -542,6 +603,17 @@ if [[ "$PART" == "all" ]] || [[ "$PART" == "$ADMIN_COMPONENT" ]]; then _systemd "osctrl" "osctrl" "osctrl-admin" "$SOURCE_PATH" "$DEST_PATH" fi +if [[ "$PART" == "all" ]] || [[ "$PART" == "$API_COMPONENT" ]]; then + # Build API service + make api + + # Configuration file generation for API service + configuration_service "$SOURCE_PATH/deploy/$SERVICE_TEMPLATE" "$DEST_PATH/config/$API_CONF" "$_P_HOST|$_P_INT_PORT" "$API_COMPONENT" "127.0.0.1" "$_P_AUTH" "$_P_LOGGING" "sudo" + + # Systemd configuration for API service + _systemd "osctrl" "osctrl" "osctrl-api" "$SOURCE_PATH" "$DEST_PATH" +fi + # Compile CLI make cli @@ -609,4 +681,4 @@ exit 0 # kthxbai # Standard deployment in a linux box would be like: -# ./deploy/provision.sh --nginx --postgres -p all --tls-hostname "dev.osctrl.net" --admin-hostname "dev.osctrl.net" -E +# ./deploy/provision.sh --nginx --postgres -p all --all-hostname "dev.osctrl.net" -E diff --git a/docker/admin/wait.sh b/docker/admin/wait.sh index 87cf62f2..1de367be 100644 --- a/docker/admin/wait.sh +++ b/docker/admin/wait.sh @@ -46,4 +46,4 @@ else fi # Run service -./bin/osctrl-admin +./bin/$NAME diff --git a/docker/api/Dockerfile b/docker/api/Dockerfile new file mode 100644 index 00000000..b379830c --- /dev/null +++ b/docker/api/Dockerfile @@ -0,0 +1,23 @@ +FROM golang:latest +LABEL maintainer="javuto" + +ENV GO111MODULE=on + +WORKDIR /osctrl-api + +#COPY /config/api.json config/ +#COPY /config/db.json config/ + +COPY go.mod . +COPY go.sum . + +COPY cmd/api/ cmd/api +COPY cmd/cli/ cmd/cli +COPY pkg/ pkg + +RUN go build -o bin/osctrl-api cmd/api/*.go +RUN go build -o bin/osctrl-cli cmd/cli/*.go + +COPY docker/api/wait.sh . + +CMD [ "/bin/sh", "/osctrl-api/wait.sh" ] diff --git a/docker/api/wait.sh b/docker/api/wait.sh new file mode 100644 index 00000000..4a86ac79 --- /dev/null +++ b/docker/api/wait.sh @@ -0,0 +1,21 @@ +#!/bin/sh +# +# [ osctrl 🎛 ]: Script to wait for database to initialize osctrl-api +# +# Usage: wait.sh + +NAME="osctrl-api" +WAIT=3 +CONFIG="config" +DB_JSON="$CONFIG/db.json" + +# Check if database is ready, otherwise commands will fail +until $(./bin/osctrl-cli -D "$DB_JSON" check); do + >&2 echo "Postgres is unavailable - Waiting..." + sleep $WAIT +done +>&2 echo "Postgres is up - Starting $NAME" +sleep $WAIT + +# Run service +./bin/$NAME diff --git a/docker/docker-compose.yml b/docker/docker-compose.yml index 1d561dd0..2650eb06 100644 --- a/docker/docker-compose.yml +++ b/docker/docker-compose.yml @@ -45,12 +45,28 @@ services: volumes: - ./docker/certs:/osctrl-admin/certs - ./docker/config:/osctrl-admin/config + osctrl-api: + container_name: osctrl-api + depends_on: + - "osctrl-db" + build: + context: . + dockerfile: "docker/api/Dockerfile" + links: + - "osctrl-db" + ports: + - "9002:9002" + networks: + - private-net + volumes: + - ./docker/config:/osctrl-api/config osctrl-nginx: image: "nginx:1.13.5" container_name: osctrl-nginx depends_on: - "osctrl-tls" - "osctrl-admin" + - "osctrl-api" ports: - "443:443" - "8443:8443" diff --git a/docker/dockerize.sh b/docker/dockerize.sh index 28dc673a..d6f9bd8e 100755 --- a/docker/dockerize.sh +++ b/docker/dockerize.sh @@ -71,6 +71,9 @@ COMPOSERFILE="$DOCKERDIR/docker-compose.yml" mkdir -p "$CERTSDIR" mkdir -p "$CONFIGDIR" +# Secret for API JWT +_JWT_SECRET="$(head -c64 < /dev/random | base64 | head -n 1 | openssl dgst -sha256 | cut -d " " -f1)" + # Default values for arguments SHOW_USAGE=true _BUILD=false @@ -178,7 +181,7 @@ CRT_DST="/etc/certs/$NAME.crt" KEY_DST="/etc/certs/$NAME.key" DH_DST="/etc/certs/dhparam.pem" -log "Preparing configuration" +log "Preparing configuration for nginx" TLS_CONF="$CONFIGDIR/tls.conf" if [[ -f "$TLS_CONF" && "$_FORCE" == false ]]; then @@ -194,6 +197,14 @@ else nginx_generate "$DEPLOYDIR/nginx/ssl.conf" "$CRT_DST" "$KEY_DST" "$DH_DST" "8443" "9001" "osctrl-admin" "$ADMIN_CONF" fi +API_CONF="$CONFIGDIR/api.conf" +if [[ -f "$API_CONF" && "$_FORCE" == false ]]; then + log "Using existing $API_CONF" +else + nginx_generate "$DEPLOYDIR/nginx/ssl.conf" "$CRT_DST" "$KEY_DST" "$DH_DST" "8444" "9002" "osctrl-api" "$API_CONF" +fi + +log "Preparing configuration for TLS" TLS_JSON="$CONFIGDIR/tls.json" if [[ -f "$TLS_JSON" && "$_FORCE" == false ]]; then log "Using existing $TLS_JSON" @@ -201,6 +212,7 @@ else configuration_service "$DEPLOYDIR/service.json" "$TLS_JSON" "localhost|9000" "tls" "0.0.0.0" "none" "db" fi +log "Preparing configuration for Admin" ADMIN_JSON="$CONFIGDIR/admin.json" if [[ -f "$ADMIN_JSON" && "$_FORCE" == false ]]; then log "Using existing $ADMIN_JSON" @@ -208,6 +220,23 @@ else configuration_service "$DEPLOYDIR/service.json" "$ADMIN_JSON" "localhost|9001" "admin" "0.0.0.0" "db" "db" fi +log "Preparing configuration for API" +API_JSON="$CONFIGDIR/api.json" +if [[ -f "$API_JSON" && "$_FORCE" == false ]]; then + log "Using existing $API_JSON" +else + configuration_service "$DEPLOYDIR/service.json" "$API_JSON" "localhost|9002" "api" "0.0.0.0" "jwt" "db" +fi + +log "Preparing configuration for JWT" +JWT_JSON="$CONFIGDIR/jwt.json" +if [[ -f "$JWT_JSON" && "$_FORCE" == false ]]; then + log "Using existing $JWT_JSON" +else + cat "$DEPLOYDIR/jwt.json" | sed "s|_JWT_SECRET|$_JWT_SECRET|g" | tee "$JWT_JSON" +fi + +log "Preparing configuration for backend" DB_JSON="$CONFIGDIR/db.json" if [[ -f "$DB_JSON" && "$_FORCE" == false ]]; then log "Using existing $DB_JSON" diff --git a/docker/tls/wait.sh b/docker/tls/wait.sh index 30e99256..0f9db2e9 100644 --- a/docker/tls/wait.sh +++ b/docker/tls/wait.sh @@ -18,4 +18,4 @@ done sleep $WAIT # Run service -./bin/osctrl-tls +./bin/$NAME diff --git a/go.mod b/go.mod index 3b20052f..19485fa9 100644 --- a/go.mod +++ b/go.mod @@ -19,10 +19,6 @@ require ( github.com/jmpsec/osctrl/pkg/types v0.1.9 github.com/jmpsec/osctrl/pkg/users v0.1.9 github.com/jmpsec/osctrl/pkg/utils v0.1.9 - github.com/jmpsec/osctrl/plugins/db_logging v0.1.9 // indirect - github.com/jmpsec/osctrl/plugins/graylog_logging v0.1.9 // indirect - github.com/jmpsec/osctrl/plugins/logging_dispatcher v0.1.9 // indirect - github.com/jmpsec/osctrl/plugins/splunk_logging v0.1.9 // indirect github.com/mattn/go-runewidth v0.0.4 // indirect github.com/olekukonko/tablewriter v0.0.1 github.com/russellhaering/goxmldsig v0.0.0-20180430223755-7acd5e4a6ef7 // indirect diff --git a/pkg/queries/queries.go b/pkg/queries/queries.go index 35b8ab27..b7f9b3bb 100644 --- a/pkg/queries/queries.go +++ b/pkg/queries/queries.go @@ -32,6 +32,19 @@ const ( StatusComplete string = "COMPLETE" ) +const ( + // TargetAll for all queries + TargetAll string = "all" + // TargetAllFull for all queries including hidden ones + TargetAllFull string = "all-full" + // TargetActive for active queries + TargetActive string = "active" + // TargetCompleted for completed queries + TargetCompleted string = "completed" + // TargetDeleted for deleted queries + TargetDeleted string = "deleted" +) + // DistributedQuery as abstraction of a distributed query type DistributedQuery struct { gorm.Model @@ -46,7 +59,6 @@ type DistributedQuery struct { Protected bool Completed bool Deleted bool - Repeat uint Type string Path string } @@ -67,7 +79,7 @@ type DistributedQueryExecution struct { Result int } -// QueryReadQueries to hold the on-demand queries +// QueryReadQueries to hold all the on-demand queries type QueryReadQueries map[string]string // Queries to handle on-demand queries @@ -97,48 +109,51 @@ func CreateQueries(backend *gorm.DB) *Queries { // NodeQueries to get all queries that belong to the provided node // FIXME this will impact the performance of the TLS endpoint due to being CPU and I/O hungry // FIMXE potential mitigation can be add a cache (Redis?) layer to store queries per node_key -func (q *Queries) NodeQueries(node nodes.OsqueryNode) (QueryReadQueries, error) { - // Get all current active queries and carvesccccccijgvvbighcllglrtditncrninnndegfhuurkgu - +func (q *Queries) NodeQueries(node nodes.OsqueryNode) (QueryReadQueries, bool, error) { + acelerate := false + // Get all current active queries and carves queries, err := q.GetActive() if err != nil { - return QueryReadQueries{}, err + return QueryReadQueries{}, false, err } // Iterate through active queries, see if they target this node and prepare data in the same loop qs := make(QueryReadQueries) for _, _q := range queries { targets, err := q.GetTargets(_q.Name) if err != nil { - return QueryReadQueries{}, err + return QueryReadQueries{}, false, err + } + if len(targets) == 1 { + acelerate = true } if isQueryTarget(node, targets) && q.NotYetExecuted(_q.Name, node.UUID) { qs[_q.Name] = _q.Query } } - return qs, nil + return qs, acelerate, nil } // Gets all queries by target (active/completed/all/all-full/deleted) func (q *Queries) Gets(target, qtype string) ([]DistributedQuery, error) { var queries []DistributedQuery switch target { - case "active": + case TargetActive: if err := q.DB.Where("active = ? AND completed = ? AND deleted = ? AND type = ?", true, false, false, qtype).Find(&queries).Error; err != nil { return queries, err } - case "completed": + case TargetCompleted: if err := q.DB.Where("active = ? AND completed = ? AND deleted = ? AND type = ?", false, true, false, qtype).Find(&queries).Error; err != nil { return queries, err } - case "all-full": + case TargetAllFull: if err := q.DB.Where("deleted = ? AND hidden = ? AND type = ?", false, true, qtype).Find(&queries).Error; err != nil { return queries, err } - case "all": + case TargetAll: if err := q.DB.Where("deleted = ? AND hidden = ? AND type = ?", false, false, qtype).Find(&queries).Error; err != nil { return queries, err } - case "deleted": + case TargetDeleted: if err := q.DB.Where("deleted = ? AND type = ?", true, qtype).Find(&queries).Error; err != nil { return queries, err } diff --git a/pkg/settings/settings.go b/pkg/settings/settings.go index 8a280307..51dda578 100644 --- a/pkg/settings/settings.go +++ b/pkg/settings/settings.go @@ -11,6 +11,7 @@ import ( const ( ServiceTLS string = "tls" ServiceAdmin string = "admin" + ServiceAPI string = "api" ) // Types of settings values @@ -27,6 +28,7 @@ const ( AuthDB string = "db" AuthSAML string = "saml" AuthHeaders string = "headers" + AuthJWT string = "jwt" ) // Types of logging @@ -42,17 +44,18 @@ const ( // Names for all possible settings values for services const ( - DebugHTTP string = "debug_http" - DebugService string = "debug_service" - RefreshEnvs string = "refresh_envs" - RefreshSettings string = "refresh_settings" - CleanupSessions string = "cleanup_sessions" - ServiceMetrics string = "service_metrics" - MetricsHost string = "metrics_host" - MetricsPort string = "metrics_port" - MetricsProtocol string = "metrics_protocol" - DefaultEnv string = "default_env" - InactiveHours string = "inactive_hours" + DebugHTTP string = "debug_http" + DebugService string = "debug_service" + RefreshEnvs string = "refresh_envs" + RefreshSettings string = "refresh_settings" + CleanupSessions string = "cleanup_sessions" + ServiceMetrics string = "service_metrics" + MetricsHost string = "metrics_host" + MetricsPort string = "metrics_port" + MetricsProtocol string = "metrics_protocol" + DefaultEnv string = "default_env" + InactiveHours string = "inactive_hours" + AcceleratedSeconds string = "accelerated_seconds" ) // Names for setting values for logging diff --git a/pkg/types/osquery.go b/pkg/types/osquery.go index f55103ee..36f4c366 100644 --- a/pkg/types/osquery.go +++ b/pkg/types/osquery.go @@ -169,6 +169,14 @@ type QueryReadResponse struct { NodeInvalid bool `json:"node_invalid"` } +// AcceleratedQueryReadResponse for accelerated on-demand queries from nodes +// https://github.com/osquery/osquery/blob/master/osquery/distributed/distributed.cpp#L219-L231 +type AcceleratedQueryReadResponse struct { + Queries queries.QueryReadQueries `json:"queries"` + NodeInvalid bool `json:"node_invalid"` + Accelerated int `json:"accelerated"` +} + // QueryWriteQueries to hold the on-demand queries results type QueryWriteQueries map[string]json.RawMessage diff --git a/pkg/types/types.go b/pkg/types/types.go index 7dc27169..7b4b78e0 100644 --- a/pkg/types/types.go +++ b/pkg/types/types.go @@ -21,7 +21,7 @@ type JSONConfigurationService struct { Logging string `json:"logging"` } -// JSONConfigurationHeaders to keep all SAML details for auth +// JSONConfigurationHeaders to keep all headers details for auth type JSONConfigurationHeaders struct { TrustedPrefix string `json:"trustedPrefix"` AdminGroup string `json:"adminGroup"` @@ -34,3 +34,9 @@ type JSONConfigurationHeaders struct { DistinguishedName string `json:"distinguishedName"` Groups string `json:"groups"` } + +// JSONConfigurationJWT to hold all JWT configuration values +type JSONConfigurationJWT struct { + JWTSecret string `json:"jwtSecret"` + HoursToExpire int `json:"hoursToExpire"` +} diff --git a/pkg/users/go.mod b/pkg/users/go.mod index 08d7693f..4922021d 100644 --- a/pkg/users/go.mod +++ b/pkg/users/go.mod @@ -3,6 +3,7 @@ module github.com/jmpsec/osctrl/pkg/users go 1.12 require ( + github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/jinzhu/gorm v1.9.8 golang.org/x/crypto v0.0.0-20190530122614-20be4c3c3ed5 ) diff --git a/pkg/users/go.sum b/pkg/users/go.sum index 5190c921..fbdbc6b0 100644 --- a/pkg/users/go.sum +++ b/pkg/users/go.sum @@ -13,6 +13,8 @@ github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDk github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/denisenkom/go-mssqldb v0.0.0-20190423183735-731ef375ac02 h1:PS3xfVPa8N84AzoWZHFCbA0+ikz4f4skktfjQoNMsgk= github.com/denisenkom/go-mssqldb v0.0.0-20190423183735-731ef375ac02/go.mod h1:zAg7JM8CkOJ43xKXIj7eRO9kmWm/TW578qo+oDO6tuM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible h1:7qlOGliEKZXTDg6OTjfoBKDXWrumCAMpl/TFQ4/5kLM= +github.com/dgrijalva/jwt-go v3.2.0+incompatible/go.mod h1:E3ru+11k8xSBh+hMPgOLZmtrrCbhqsmaPHjLKYnJCaQ= github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs= github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU= github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I= diff --git a/pkg/users/users.go b/pkg/users/users.go index 269aa928..70caee20 100644 --- a/pkg/users/users.go +++ b/pkg/users/users.go @@ -5,6 +5,7 @@ import ( "log" "time" + "github.com/dgrijalva/jwt-go" "github.com/jinzhu/gorm" "golang.org/x/crypto/bcrypt" ) @@ -16,10 +17,19 @@ type AdminUser struct { Email string Fullname string PassHash string + APIToken string + TokenExpire time.Time Admin bool LastIPAddress string LastUserAgent string LastAccess time.Time + LastTokenUse time.Time +} + +// TokenClaims to hold user claims when using JWT +type TokenClaims struct { + Username string `json:"username"` + jwt.StandardClaims } // UserManager have all users of the system @@ -38,9 +48,9 @@ func CreateUserManager(backend *gorm.DB) *UserManager { return u } -// HashPasswordWithSalt to hash a password before store it -func (m *UserManager) HashPasswordWithSalt(password string) (string, error) { - saltedBytes := []byte(password) +// HashTextWithSalt to hash text before store it +func (m *UserManager) HashTextWithSalt(text string) (string, error) { + saltedBytes := []byte(text) hashedBytes, err := bcrypt.GenerateFromPassword(saltedBytes, bcrypt.DefaultCost) if err != nil { return "", err @@ -49,6 +59,11 @@ func (m *UserManager) HashPasswordWithSalt(password string) (string, error) { return hash, nil } +// HashPasswordWithSalt to hash a password before store it +func (m *UserManager) HashPasswordWithSalt(password string) (string, error) { + return m.HashTextWithSalt(password) +} + // CheckLoginCredentials to check provided login credentials by matching hashes func (m *UserManager) CheckLoginCredentials(username, password string) (bool, AdminUser) { // Retrieve user @@ -66,6 +81,44 @@ func (m *UserManager) CheckLoginCredentials(username, password string) (bool, Ad return true, user } +// CreateToken to create a new JWT token for a given user +func (m *UserManager) CreateToken(username string, expireHours int, jwtSecret string) (string, time.Time, error) { + expirationTime := time.Now().Add(time.Hour * time.Duration(expireHours)) + // Create the JWT claims, which includes the username, level and expiry time + claims := &TokenClaims{ + Username: username, + StandardClaims: jwt.StandardClaims{ + // In JWT, the expiry time is expressed as unix milliseconds + ExpiresAt: expirationTime.Unix(), + }, + } + // Declare the token with the algorithm used for signing, and the claims + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + // Create the JWT string + tokenString, err := token.SignedString([]byte(jwtSecret)) + if err != nil { + return "", time.Now(), err + } + return tokenString, expirationTime, nil +} + +// CheckToken to verify if a token used is valid +func (m *UserManager) CheckToken(jwtSecret, tokenStr string) (TokenClaims, bool) { + claims := &TokenClaims{} + tkn, err := jwt.ParseWithClaims(tokenStr, claims, func(token *jwt.Token) (interface{}, error) { + return []byte(jwtSecret), nil + }) + if err != nil { + log.Printf("Error %v", err) + return *claims, false + } + if !tkn.Valid { + log.Println("Not valid") + return *claims, false + } + return *claims, true +} + // Get user by username func (m *UserManager) Get(username string) (AdminUser, error) { var user AdminUser @@ -172,6 +225,25 @@ func (m *UserManager) ChangePassword(username, password string) error { return nil } +// UpdateToken for user by username +func (m *UserManager) UpdateToken(username, token string, exp time.Time) error { + user, err := m.Get(username) + if err != nil { + return fmt.Errorf("error getting user %v", err) + } + if token != user.APIToken { + if err := m.DB.Model(&user).Updates( + AdminUser{ + APIToken: token, + TokenExpire: exp, + LastAccess: time.Now(), + }).Error; err != nil { + return fmt.Errorf("Update %v", err) + } + } + return nil +} + // ChangeEmail for user by username func (m *UserManager) ChangeEmail(username, email string) error { user, err := m.Get(username) @@ -206,15 +278,29 @@ func (m *UserManager) UpdateMetadata(ipaddress, useragent, username string) erro if err != nil { return fmt.Errorf("error getting user %v", err) } - if ipaddress != user.LastIPAddress || useragent != user.LastUserAgent { - if err := m.DB.Model(&user).Updates( - AdminUser{ - LastIPAddress: ipaddress, - LastUserAgent: useragent, - LastAccess: time.Now(), - }).Error; err != nil { - return fmt.Errorf("Update %v", err) - } + if err := m.DB.Model(&user).Updates( + AdminUser{ + LastIPAddress: ipaddress, + LastUserAgent: useragent, + LastAccess: time.Now(), + }).Error; err != nil { + return fmt.Errorf("Update %v", err) + } + return nil +} + +// UpdateTokenIPAddress updates IP and Last Access for a user's token +func (m *UserManager) UpdateTokenIPAddress(ipaddress, username string) error { + user, err := m.Get(username) + if err != nil { + return fmt.Errorf("error getting user %v", err) + } + if err := m.DB.Model(&user).Updates( + AdminUser{ + LastIPAddress: ipaddress, + LastTokenUse: time.Now(), + }).Error; err != nil { + return fmt.Errorf("Update %v", err) } return nil }