From 56cfa0ee47ab4cbfa1c8e686580f62553322e04f Mon Sep 17 00:00:00 2001 From: Juan Antonio Osorio Date: Fri, 14 Jun 2024 13:25:48 +0300 Subject: [PATCH] Take into account hierarchy when dealing with rule types Rule types are usable by projects down the hierarchy, this allows someone in a child project to create a profile that uses rule types approved in a parent project. This is very handy since folks with permission to write profiles don't need to have permissions to write rule types, they can simply use some already approved and created. This also adds the constraint that rule type names are unique down the hierarchy. That is, a user cannot create a rule type name that overrides a rule type that exists in a parent project. Signed-off-by: Juan Antonio Osorio --- database/query/rule_types.sql | 2 +- internal/controlplane/handlers_ruletype.go | 5 ++-- internal/db/rule_types.sql.go | 9 ++++---- internal/engine/executor.go | 9 ++------ internal/engine/executor_test.go | 4 ++-- internal/profiles/service.go | 14 +++++++---- internal/profiles/validator.go | 27 ++++++++++++++-------- internal/ruletypes/service.go | 16 +++++++++---- 8 files changed, 50 insertions(+), 36 deletions(-) diff --git a/database/query/rule_types.sql b/database/query/rule_types.sql index fde6b419cc..5b3227ad08 100644 --- a/database/query/rule_types.sql +++ b/database/query/rule_types.sql @@ -26,7 +26,7 @@ SELECT * FROM rule_type WHERE project_id = $1; SELECT * FROM rule_type WHERE id = $1; -- name: GetRuleTypeByName :one -SELECT * FROM rule_type WHERE project_id = sqlc.arg(project_id) AND lower(name) = lower(sqlc.arg(name)); +SELECT * FROM rule_type WHERE project_id = ANY(sqlc.arg(projects)::uuid[]) AND lower(name) = lower(sqlc.arg(name)); -- name: DeleteRuleType :exec DELETE FROM rule_type WHERE id = $1; diff --git a/internal/controlplane/handlers_ruletype.go b/internal/controlplane/handlers_ruletype.go index 60a74ee7d3..cbd77c5406 100644 --- a/internal/controlplane/handlers_ruletype.go +++ b/internal/controlplane/handlers_ruletype.go @@ -83,8 +83,9 @@ func (s *Server) GetRuleTypeByName( resp := &minderv1.GetRuleTypeByNameResponse{} rtdb, err := s.store.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{ - ProjectID: entityCtx.Project.ID, - Name: in.GetName(), + // TODO: Add option to fetch rule types from parent projects too + Projects: []uuid.UUID{entityCtx.Project.ID}, + Name: in.GetName(), }) if errors.Is(err, sql.ErrNoRows) { return nil, util.UserVisibleError(codes.NotFound, "rule type %s not found", in.GetName()) diff --git a/internal/db/rule_types.sql.go b/internal/db/rule_types.sql.go index b4cd360dcb..f5735727ff 100644 --- a/internal/db/rule_types.sql.go +++ b/internal/db/rule_types.sql.go @@ -10,6 +10,7 @@ import ( "encoding/json" "github.com/google/uuid" + "github.com/lib/pq" ) const createRuleType = `-- name: CreateRuleType :one @@ -110,16 +111,16 @@ func (q *Queries) GetRuleTypeByID(ctx context.Context, id uuid.UUID) (RuleType, } const getRuleTypeByName = `-- name: GetRuleTypeByName :one -SELECT id, name, provider, project_id, description, guidance, definition, created_at, updated_at, severity_value, provider_id, subscription_id, display_name FROM rule_type WHERE project_id = $1 AND lower(name) = lower($2) +SELECT id, name, provider, project_id, description, guidance, definition, created_at, updated_at, severity_value, provider_id, subscription_id, display_name FROM rule_type WHERE project_id = ANY($1::uuid[]) AND lower(name) = lower($2) ` type GetRuleTypeByNameParams struct { - ProjectID uuid.UUID `json:"project_id"` - Name string `json:"name"` + Projects []uuid.UUID `json:"projects"` + Name string `json:"name"` } func (q *Queries) GetRuleTypeByName(ctx context.Context, arg GetRuleTypeByNameParams) (RuleType, error) { - row := q.db.QueryRowContext(ctx, getRuleTypeByName, arg.ProjectID, arg.Name) + row := q.db.QueryRowContext(ctx, getRuleTypeByName, pq.Array(arg.Projects), arg.Name) var i RuleType err := row.Scan( &i.ID, diff --git a/internal/engine/executor.go b/internal/engine/executor.go index 2ec37b1785..942bf2aac9 100644 --- a/internal/engine/executor.go +++ b/internal/engine/executor.go @@ -273,17 +273,12 @@ func (e *Executor) getEvaluator( return nil, nil, fmt.Errorf("error creating eval status params: %w", err) } - // NOTE: We're only using the first project in the hierarchy for now. - // This means that a rule type must exist in the same project as the profile. - // This will be revisited in the future. - projID := hierarchy[0] - // Load Rule Class from database // TODO(jaosorior): Rule types should be cached in memory so // we don't have to query the database for each rule. dbrt, err := e.querier.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{ - ProjectID: projID, - Name: rule.Type, + Projects: hierarchy, + Name: rule.Type, }) if err != nil { return nil, nil, fmt.Errorf("error getting rule type when traversing profile %s: %w", params.ProfileID, err) diff --git a/internal/engine/executor_test.go b/internal/engine/executor_test.go index 21754a5a2c..3299d80cfb 100644 --- a/internal/engine/executor_test.go +++ b/internal/engine/executor_test.go @@ -220,8 +220,8 @@ default allow = true`, mockStore.EXPECT(). GetRuleTypeByName(gomock.Any(), db.GetRuleTypeByNameParams{ - ProjectID: projectID, - Name: passthroughRuleType, + Projects: []uuid.UUID{projectID}, + Name: passthroughRuleType, }).Return(db.RuleType{ ID: ruleTypeID, Name: passthroughRuleType, diff --git a/internal/profiles/service.go b/internal/profiles/service.go index b987c85751..ccc60b7be1 100644 --- a/internal/profiles/service.go +++ b/internal/profiles/service.go @@ -563,12 +563,16 @@ func (_ *profileService) getRulesFromProfile( // track them in the db later. rulesInProf := make(RuleMapping) - err := TraverseAllRulesForPipeline(profile, func(r *minderv1.Profile_Rule) error { - // TODO: This will need to be updated to support - // the hierarchy tree once that's settled in. + // Allows a profile to use rules from parent projects + projects, err := qtx.GetParentProjects(ctx, projectID) + if err != nil { + return nil, fmt.Errorf("error getting parent projects: %w", err) + } + + err = TraverseAllRulesForPipeline(profile, func(r *minderv1.Profile_Rule) error { rtdb, err := qtx.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{ - ProjectID: projectID, - Name: r.GetType(), + Projects: projects, + Name: r.GetType(), }) if err != nil { return fmt.Errorf("error getting rule type %s: %w", r.GetType(), err) diff --git a/internal/profiles/validator.go b/internal/profiles/validator.go index 62093b512a..9de2cd19c7 100644 --- a/internal/profiles/validator.go +++ b/internal/profiles/validator.go @@ -97,12 +97,16 @@ func (_ *Validator) validateRuleParams( // track them in the db later. rulesInProfile := make(RuleMapping) - err := TraverseAllRulesForPipeline(prof, func(profileRule *minderv1.Profile_Rule) error { - // TODO: This will need to be updated to support - // the hierarchy tree once that's settled in. + // Allows a profile to use rules from parent projects + projects, err := qtx.GetParentProjects(ctx, projectID) + if err != nil { + return nil, fmt.Errorf("error getting parent projects: %w", err) + } + + err = TraverseAllRulesForPipeline(prof, func(profileRule *minderv1.Profile_Rule) error { ruleType, err := qtx.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{ - ProjectID: projectID, - Name: profileRule.GetType(), + Projects: projects, + Name: profileRule.GetType(), }) if err != nil { @@ -272,13 +276,16 @@ func validateEntities( profile *minderv1.Profile, projectID uuid.UUID, ) error { + projects, err := qtx.GetParentProjects(ctx, projectID) + if err != nil { + return fmt.Errorf("error getting parent projects: %w", err) + } + // validate that the entities in the profile match the entities in the project - err := TraverseRuleTypesForEntities(profile, func(entity minderv1.Entity, rule *minderv1.Profile_Rule) error { - // TODO: This will need to be updated to support - // the hierarchy tree once that's settled in. + err = TraverseRuleTypesForEntities(profile, func(entity minderv1.Entity, rule *minderv1.Profile_Rule) error { ruleType, err := qtx.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{ - ProjectID: projectID, - Name: rule.GetType(), + Projects: projects, + Name: rule.GetType(), }) if err != nil { diff --git a/internal/ruletypes/service.go b/internal/ruletypes/service.go index 0e0dc9308c..b02c6e4eef 100644 --- a/internal/ruletypes/service.go +++ b/internal/ruletypes/service.go @@ -113,9 +113,14 @@ func (_ *ruleTypeService) CreateRuleType( ruleTypeName := ruleType.GetName() ruleTypeDef := ruleType.GetDef() - _, err := qtx.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{ - ProjectID: projectID, - Name: ruleTypeName, + projects, err := qtx.GetParentProjects(ctx, projectID) + if err != nil { + return nil, fmt.Errorf("failed to get parent projects: %w", err) + } + + _, err = qtx.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{ + Projects: projects, + Name: ruleTypeName, }) if err == nil { return nil, ErrRuleAlreadyExists @@ -177,8 +182,9 @@ func (_ *ruleTypeService) UpdateRuleType( ruleTypeDef := ruleType.GetDef() oldRuleType, err := qtx.GetRuleTypeByName(ctx, db.GetRuleTypeByNameParams{ - ProjectID: projectID, - Name: ruleTypeName, + // we only need to check the project that the rule type is in + Projects: []uuid.UUID{projectID}, + Name: ruleTypeName, }) if err != nil { if errors.Is(err, sql.ErrNoRows) {