Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add GET global config policies API #9952

Merged
merged 5 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 41 additions & 16 deletions master/internal/api_config_policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,52 @@ func (a *apiServer) GetWorkspaceConfigPolicies(
return nil, err
}

if !configpolicy.ValidWorkloadType(req.WorkloadType) {
errMessage := fmt.Sprintf("invalid workload type: %s.", req.WorkloadType)
if len(req.WorkloadType) == 0 {
resp, err := a.GetConfigPolicies(ctx, ptrs.Ptr(int(req.WorkspaceId)), req.WorkloadType)
amandavialva01 marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return nil, err
}

return &apiv1.GetWorkspaceConfigPoliciesResponse{ConfigPolicies: resp}, nil
}

// Get global task config policies.
func (a *apiServer) GetGlobalConfigPolicies(
ctx context.Context, req *apiv1.GetGlobalConfigPoliciesRequest,
) (*apiv1.GetGlobalConfigPoliciesResponse, error) {
license.RequireLicense("manage config policies")

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
}

permErr, err := cluster.AuthZProvider.Get().CanViewGlobalConfigPolicies(ctx, curUser)
if err != nil {
return nil, err
} else if permErr != nil {
return nil, permErr
}
resp, err := a.GetConfigPolicies(ctx, nil, req.WorkloadType)
if err != nil {
return nil, err
}

return &apiv1.GetGlobalConfigPoliciesResponse{ConfigPolicies: resp}, nil
}

func (*apiServer) GetConfigPolicies(
salonig23 marked this conversation as resolved.
Show resolved Hide resolved
ctx context.Context, workspaceID *int, workloadType string,
) (*structpb.Struct, error) {
if !configpolicy.ValidWorkloadType(workloadType) {
errMessage := fmt.Sprintf("invalid workload type: %s.", workloadType)
if len(workloadType) == 0 {
errMessage = noWorkloadErr
}
return nil, status.Errorf(codes.InvalidArgument, errMessage)
}

configPolicies, err := configpolicy.GetTaskConfigPolicies(
ctx, ptrs.Ptr(int(req.WorkspaceId)), req.WorkloadType)
ctx, workspaceID, workloadType)
if err != nil {
return nil, err
}
Expand All @@ -111,18 +147,7 @@ func (a *apiServer) GetWorkspaceConfigPolicies(
}
policyMap["constraints"] = constraintsMap
}
return &apiv1.GetWorkspaceConfigPoliciesResponse{ConfigPolicies: configpolicy.MarshalConfigPolicy(policyMap)}, nil
}

// Get global task config policies.
func (*apiServer) GetGlobalConfigPolicies(
ctx context.Context, req *apiv1.GetGlobalConfigPoliciesRequest,
) (*apiv1.GetGlobalConfigPoliciesResponse, error) {
if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
}
data, err := stubData()
return &apiv1.GetGlobalConfigPoliciesResponse{ConfigPolicies: data}, err
return configpolicy.MarshalConfigPolicy(policyMap), nil
}

// Delete workspace task config policies.
Expand Down
116 changes: 64 additions & 52 deletions master/internal/api_config_policies_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -213,101 +213,113 @@ func TestBasicRBACConfigPolicyPerms(t *testing.T) {
}
}

func TestGetWorkspaceConfigPolicies(t *testing.T) {
func TestGetConfigPolicies(t *testing.T) {
api, curUser, ctx := setupAPITest(t, nil)
testutils.MustLoadLicenseAndKeyFromFilesystem("../../")

wkspResp, err := api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()})
require.NoError(t, err)
workspaceID1 := wkspResp.Workspace.Id
workspaceID1 := ptrs.Ptr(int(wkspResp.Workspace.Id))
wkspResp, err = api.PostWorkspace(ctx, &apiv1.PostWorkspaceRequest{Name: uuid.New().String()})
require.NoError(t, err)
workspaceID2 := wkspResp.Workspace.Id
workspaceID2 := ptrs.Ptr(int(wkspResp.Workspace.Id))

// set only config policy
// set only constraints policy for workspace 1
taskConfigPolicies := &model.TaskConfigPolicies{
WorkspaceID: ptrs.Ptr(int(workspaceID1)),
WorkloadType: model.NTSCType,
LastUpdatedBy: curUser.ID,
InvariantConfig: ptrs.Ptr(configpolicy.DefaultInvariantConfigStr),
}
err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies)
require.NoError(t, err)

// set only constraints policy
taskConfigPolicies = &model.TaskConfigPolicies{
WorkspaceID: ptrs.Ptr(int(workspaceID1)),
WorkspaceID: workspaceID1,
WorkloadType: model.ExperimentType,
LastUpdatedBy: curUser.ID,
Constraints: ptrs.Ptr(configpolicy.DefaultConstraintsStr),
}
err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies)
require.NoError(t, err)

// set both config and constraints policy
taskConfigPolicies = &model.TaskConfigPolicies{
WorkspaceID: ptrs.Ptr(int(workspaceID2)),
WorkloadType: model.NTSCType,
LastUpdatedBy: curUser.ID,
InvariantConfig: ptrs.Ptr(configpolicy.DefaultInvariantConfigStr),
Constraints: ptrs.Ptr(configpolicy.DefaultConstraintsStr),
}
// set only config policy for workspace 1
taskConfigPolicies.WorkloadType = model.NTSCType
taskConfigPolicies.Constraints = nil
taskConfigPolicies.InvariantConfig = ptrs.Ptr(configpolicy.DefaultInvariantConfigStr)
err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies)
require.NoError(t, err)

// set both config and constraints policy for workspace 2
taskConfigPolicies.WorkspaceID = workspaceID2
taskConfigPolicies.Constraints = ptrs.Ptr(configpolicy.DefaultConstraintsStr)
err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies)
require.NoError(t, err)

// set both config and constraints policy globally
taskConfigPolicies.WorkspaceID = nil
salonig23 marked this conversation as resolved.
Show resolved Hide resolved
err = configpolicy.SetTaskConfigPolicies(ctx, taskConfigPolicies)
require.NoError(t, err)

cases := []struct {
name string
req *apiv1.GetWorkspaceConfigPoliciesRequest
workspaceID *int
workloadType string
err error
hasConfig bool
hasConstraints bool
}{
{
"invalid workload type",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID1,
WorkloadType: "bad workload type",
},
workspaceID1,
"bad workload type",
fmt.Errorf("invalid workload type"),
false,
false,
},
{
"empty workload type",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID1,
WorkloadType: "",
},
workspaceID1,
"",
fmt.Errorf(noWorkloadErr),
false,
false,
},
{
"valid request only config",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID1,
WorkloadType: model.NTSCType,
},
"valid workspace request, only config",
workspaceID1,
model.NTSCType,
nil,
true,
false,
},
{
"valid request only constraints",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID1,
WorkloadType: model.ExperimentType,
},
"valid workspace request, only constraints",
workspaceID1,
model.ExperimentType,
nil,
false,
true,
},
{
"valid request both configs and constraints",
&apiv1.GetWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID2,
WorkloadType: model.NTSCType,
},
"valid workspace request both configs and constraints",
workspaceID2,
model.NTSCType,
nil,
true,
true,
},
{
"valid workspace request, only config",
salonig23 marked this conversation as resolved.
Show resolved Hide resolved
workspaceID1,
model.NTSCType,
nil,
true,
false,
},
{
"valid workspace request, only constraints",
salonig23 marked this conversation as resolved.
Show resolved Hide resolved
workspaceID1,
model.ExperimentType,
nil,
false,
true,
},
{
"valid global request both configs and constraints",
nil,
model.NTSCType,
nil,
true,
true,
Expand All @@ -316,7 +328,7 @@ func TestGetWorkspaceConfigPolicies(t *testing.T) {

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
resp, err := api.GetWorkspaceConfigPolicies(ctx, test.req)
resp, err := api.GetConfigPolicies(ctx, test.workspaceID, test.workloadType)
salonig23 marked this conversation as resolved.
Show resolved Hide resolved
if test.err != nil {
require.ErrorContains(t, err, test.err.Error())
return
Expand All @@ -325,15 +337,15 @@ func TestGetWorkspaceConfigPolicies(t *testing.T) {
require.NotNil(t, resp)

if test.hasConfig {
require.Contains(t, resp.ConfigPolicies.String(), "config")
require.Contains(t, resp.String(), "config")
} else {
require.NotContains(t, resp.ConfigPolicies.String(), "config")
require.NotContains(t, resp.String(), "config")
}

if test.hasConstraints {
require.Contains(t, resp.ConfigPolicies.String(), "constraints")
require.Contains(t, resp.String(), "constraints")
} else {
require.NotContains(t, resp.ConfigPolicies.String(), "constraints")
require.NotContains(t, resp.String(), "constraints")
}
})
}
Expand Down
5 changes: 5 additions & 0 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@ package configpolicy

import (
"context"
"database/sql"
"fmt"
"strings"

"github.com/pkg/errors"
"github.com/uptrace/bun"

"github.com/determined-ai/determined/master/internal/db"
Expand Down Expand Up @@ -82,6 +84,9 @@ func GetTaskConfigPolicies(ctx context.Context,
Where("workload_type = ?", workloadType).
Scan(ctx)
if err != nil {
if errors.Is(err, db.ErrNotFound) || errors.Cause(err) == sql.ErrNoRows {
return &model.TaskConfigPolicies{}, nil
salonig23 marked this conversation as resolved.
Show resolved Hide resolved
}
return nil, fmt.Errorf("error retrieving %v task config policies for "+
"workspace with ID %d: %w", workloadType, scope, err)
}
Expand Down
Loading