Skip to content

Commit

Permalink
chore: standardize status api errors for task config policies (#10119)
Browse files Browse the repository at this point in the history
  • Loading branch information
carolinaecalderon authored Oct 24, 2024
1 parent e834302 commit 4394f29
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 40 deletions.
73 changes: 37 additions & 36 deletions master/internal/api_config_policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@ import (
)

const (
noWorkloadErr = "no workload type specified."
noPoliciesErr = "no specified config policies."
noWorkloadErr = "no workload type specified"
noPoliciesErr = "no specified config policies"
globalPriorityErr = "global priority limit already exists"
invalidWorkloadTypeErr = "invalid workload type"
settingPoliciesErr = "error setting task config policies"
)

func (a *apiServer) validatePoliciesAndWorkloadType(
Expand All @@ -34,16 +35,16 @@ func (a *apiServer) validatePoliciesAndWorkloadType(
) error {
// Validate workload type
if !configpolicy.ValidWorkloadType(workloadType) {
errMessage := fmt.Sprintf(invalidWorkloadTypeErr+": %s.", workloadType)
errMessage := fmt.Errorf(invalidWorkloadTypeErr+": %s", workloadType)
if len(workloadType) == 0 {
errMessage = noWorkloadErr
errMessage = fmt.Errorf(noWorkloadErr)
}
return status.Errorf(codes.InvalidArgument, errMessage)
return errMessage
}

// Validate policies.
if len(configPolicies) == 0 {
return status.Errorf(codes.InvalidArgument, noPoliciesErr)
return fmt.Errorf(noPoliciesErr)
}

// Validate the input config based on workload type.
Expand All @@ -56,15 +57,15 @@ func (a *apiServer) validatePoliciesAndWorkloadType(
return configpolicy.ValidateNTSCConfig(globalConfigPolicies, configPolicies,
priorityEnabledErr)
default:
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(invalidWorkloadTypeErr+": %s.", workloadType))
return fmt.Errorf(invalidWorkloadTypeErr+": %s", workloadType)
}
}

func parseConfigPolicies(configAndConstraints string) (
tcps map[string]interface{}, invariantConfig *string, constraints *string, err error,
) {
if len(configAndConstraints) == 0 {
return nil, nil, nil, status.Error(codes.InvalidArgument, "nothing to parse, empty "+
return nil, nil, nil, fmt.Errorf("nothing to parse, empty " +
"config and constraints input")
}
// Standardize to JSON policies file format.
Expand Down Expand Up @@ -111,33 +112,33 @@ func (a *apiServer) PutWorkspaceConfigPolicies(
// Request Validation
curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
return nil, status.Errorf(codes.NotFound, err.Error())
}

w, err := a.GetWorkspaceByID(ctx, req.WorkspaceId, *curUser, false)
if err != nil {
return nil, err
return nil, status.Errorf(codes.NotFound, err.Error())
}

err = configpolicy.AuthZProvider.Get().CanModifyWorkspaceConfigPolicies(ctx, *curUser, w)
if err != nil {
return nil, err
return nil, status.Errorf(codes.PermissionDenied, err.Error())
}

globalConfigPolicies, err := configpolicy.GetTaskConfigPolicies(ctx, nil, req.WorkloadType)
if err != nil {
return nil, err
return nil, status.Errorf(codes.NotFound, err.Error())
}

err = a.validatePoliciesAndWorkloadType(globalConfigPolicies, req.WorkloadType,
req.ConfigPolicies)
if err != nil {
return nil, err
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}

configPolicies, invariantConfig, constraints, err := parseConfigPolicies(req.ConfigPolicies)
if err != nil {
return nil, err
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}

err = configpolicy.SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
Expand All @@ -149,13 +150,13 @@ func (a *apiServer) PutWorkspaceConfigPolicies(
Constraints: constraints,
})
if err != nil {
return nil, fmt.Errorf("error setting task config policies: %w", err)
return nil, status.Errorf(codes.Internal, settingPoliciesErr+": "+err.Error())
}

return &apiv1.PutWorkspaceConfigPoliciesResponse{
ConfigPolicies: configpolicy.MarshalConfigPolicy(configPolicies),
},
err
nil
}

// Add or update global task config policies.
Expand All @@ -166,22 +167,22 @@ func (a *apiServer) PutGlobalConfigPolicies(

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
return nil, status.Errorf(codes.PermissionDenied, err.Error())
}

err = configpolicy.AuthZProvider.Get().CanModifyGlobalConfigPolicies(ctx, curUser)
if err != nil {
return nil, err
return nil, status.Errorf(codes.PermissionDenied, err.Error())
}

err = a.validatePoliciesAndWorkloadType(nil, req.WorkloadType, req.ConfigPolicies)
if err != nil {
return nil, err
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}

configPolicies, invariantConfig, constraints, err := parseConfigPolicies(req.ConfigPolicies)
if err != nil {
return nil, err
return nil, status.Errorf(codes.InvalidArgument, err.Error())
}

err = configpolicy.SetTaskConfigPolicies(ctx, &model.TaskConfigPolicies{
Expand All @@ -192,13 +193,13 @@ func (a *apiServer) PutGlobalConfigPolicies(
Constraints: constraints,
})
if err != nil {
return nil, fmt.Errorf("error setting task config policies: %w", err)
return nil, status.Errorf(codes.Internal, settingPoliciesErr+": "+err.Error())
}

return &apiv1.PutGlobalConfigPoliciesResponse{
ConfigPolicies: configpolicy.MarshalConfigPolicy(configPolicies),
},
err
nil
}

// Get workspace task config policies.
Expand All @@ -209,22 +210,22 @@ func (a *apiServer) GetWorkspaceConfigPolicies(

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
return nil, status.Errorf(codes.NotFound, err.Error())
}

w, err := a.GetWorkspaceByID(ctx, req.WorkspaceId, *curUser, false)
if err != nil {
return nil, err
return nil, status.Errorf(codes.NotFound, err.Error())
}

err = configpolicy.AuthZProvider.Get().CanViewWorkspaceConfigPolicies(ctx, *curUser, w)
if err != nil {
return nil, err
return nil, status.Errorf(codes.PermissionDenied, err.Error())
}

resp, err := a.getConfigPolicies(ctx, ptrs.Ptr(int(req.WorkspaceId)), req.WorkloadType)
if err != nil {
return nil, err
return nil, status.Errorf(codes.Internal, err.Error())
}

return &apiv1.GetWorkspaceConfigPoliciesResponse{ConfigPolicies: resp}, nil
Expand All @@ -238,17 +239,17 @@ func (a *apiServer) GetGlobalConfigPolicies(

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
return nil, status.Errorf(codes.NotFound, err.Error())
}

err = configpolicy.AuthZProvider.Get().CanViewGlobalConfigPolicies(ctx, curUser)
if err != nil {
return nil, err
return nil, status.Errorf(codes.PermissionDenied, err.Error())
}

resp, err := a.getConfigPolicies(ctx, nil, req.WorkloadType)
if err != nil {
return nil, err
return nil, status.Errorf(codes.Internal, err.Error())
}

return &apiv1.GetGlobalConfigPoliciesResponse{ConfigPolicies: resp}, nil
Expand Down Expand Up @@ -296,17 +297,17 @@ func (a *apiServer) DeleteWorkspaceConfigPolicies(

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
return nil, status.Errorf(codes.NotFound, err.Error())
}

w, err := a.GetWorkspaceByID(ctx, req.WorkspaceId, *curUser, false)
if err != nil {
return nil, err
return nil, status.Errorf(codes.NotFound, err.Error())
}

err = configpolicy.AuthZProvider.Get().CanModifyWorkspaceConfigPolicies(ctx, *curUser, w)
if err != nil {
return nil, err
return nil, status.Error(codes.PermissionDenied, err.Error())
}

if !configpolicy.ValidWorkloadType(req.WorkloadType) {
Expand All @@ -320,7 +321,7 @@ func (a *apiServer) DeleteWorkspaceConfigPolicies(
err = configpolicy.DeleteConfigPolicies(ctx, ptrs.Ptr(int(req.WorkspaceId)),
req.WorkloadType)
if err != nil {
return nil, err
return nil, status.Errorf(codes.Internal, err.Error())
}
return &apiv1.DeleteWorkspaceConfigPoliciesResponse{}, nil
}
Expand All @@ -333,12 +334,12 @@ func (a *apiServer) DeleteGlobalConfigPolicies(

curUser, _, err := grpcutil.GetUser(ctx)
if err != nil {
return nil, err
return nil, status.Errorf(codes.NotFound, err.Error())
}

err = configpolicy.AuthZProvider.Get().CanModifyGlobalConfigPolicies(ctx, curUser)
if err != nil {
return nil, err
return nil, status.Errorf(codes.PermissionDenied, err.Error())
}

if !configpolicy.ValidWorkloadType(req.WorkloadType) {
Expand All @@ -351,7 +352,7 @@ func (a *apiServer) DeleteGlobalConfigPolicies(

err = configpolicy.DeleteConfigPolicies(ctx, nil, req.WorkloadType)
if err != nil {
return nil, err
return nil, status.Errorf(codes.Internal, err.Error())
}
return &apiv1.DeleteGlobalConfigPoliciesResponse{}, nil
}
10 changes: 6 additions & 4 deletions master/internal/api_config_policies_intg_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import (
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"

"github.com/determined-ai/determined/master/internal/configpolicy"
"github.com/determined-ai/determined/master/internal/db"
Expand Down Expand Up @@ -504,15 +506,15 @@ func TestAuthZCanModifyConfigPolicies(t *testing.T) {
WorkspaceId: workspaceID,
WorkloadType: model.NTSCType,
})
require.Equal(t, expectedErr, err)
require.Equal(t, status.Error(codes.PermissionDenied, expectedErr.Error()), err)

_, err = api.PutWorkspaceConfigPolicies(ctx,
&apiv1.PutWorkspaceConfigPoliciesRequest{
WorkspaceId: workspaceID,
WorkloadType: model.NTSCType,
ConfigPolicies: validConstraintsPolicyYAML,
})
require.Equal(t, expectedErr, err)
require.Equal(t, status.Error(codes.PermissionDenied, expectedErr.Error()), err)

// (Workspace-level) Nil error returns whatever the request returned.
workspaceAuthZ.On("CanGetWorkspace", mock.Anything, mock.Anything, mock.Anything).
Expand Down Expand Up @@ -542,14 +544,14 @@ func TestAuthZCanModifyConfigPolicies(t *testing.T) {

_, err = api.DeleteGlobalConfigPolicies(ctx,
&apiv1.DeleteGlobalConfigPoliciesRequest{WorkloadType: model.NTSCType})
require.Equal(t, expectedErr, err)
require.Equal(t, status.Error(codes.PermissionDenied, expectedErr.Error()), err)

_, err = api.PutGlobalConfigPolicies(ctx,
&apiv1.PutGlobalConfigPoliciesRequest{
WorkloadType: model.NTSCType,
ConfigPolicies: validNTSCConfigPolicyYAML,
})
require.Equal(t, expectedErr, err)
require.Equal(t, status.Error(codes.PermissionDenied, expectedErr.Error()), err)

// (Global) Nil error returns whatever the request returned.
configPolicyAuthZ.On("CanModifyGlobalConfigPolicies", mock.Anything, mock.Anything).
Expand Down

0 comments on commit 4394f29

Please sign in to comment.