Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
amandavialva01 committed Oct 18, 2024
1 parent ac2c3d1 commit f566a0c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 32 deletions.
5 changes: 2 additions & 3 deletions master/internal/configpolicy/postgres_task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,8 @@ func SetTaskConfigPoliciesTx(ctx context.Context, tx *bun.Tx,
) error {
q := db.Bun().NewInsert().Model(tcp)

q = q.Set("last_updated_by = ?, last_updated_time = ?", tcp.LastUpdatedBy, tcp.LastUpdatedTime)
q = q.Set("invariant_config = ?", tcp.InvariantConfig)
q = q.Set("constraints = ?", tcp.Constraints)
q = q.Set("last_updated_by = ?, last_updated_time = ?, invariant_config = ?, constraints = ?",
tcp.LastUpdatedBy, tcp.LastUpdatedTime, tcp.InvariantConfig, tcp.Constraints)

if tcp.WorkspaceID == nil {
q = q.On("CONFLICT (workload_type) WHERE workspace_id IS NULL DO UPDATE")
Expand Down
88 changes: 60 additions & 28 deletions master/internal/configpolicy/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (

"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/protoutils"
"github.com/determined-ai/determined/master/pkg/schemas/expconf"
)

const (
Expand Down Expand Up @@ -84,15 +85,32 @@ func ValidateExperimentConfig(
// policies (since these fields will be overridden by the respective fields in the global
// policies).
var globalConstraints *model.Constraints
var globalConfig *expconf.ExperimentConfig
if globalConfigPolicies != nil {
globalConstraints, err = UnmarshalConfigPolicy[model.Constraints](*globalConfigPolicies.Constraints,
InvalidExperimentConfigPolicyErr)
if err != nil {
ConfigPolicyWarning(err.Error())
return nil
if globalConfigPolicies.Constraints != nil {
globalConstraints, err = UnmarshalConfigPolicy[model.Constraints](
*globalConfigPolicies.Constraints,
InvalidExperimentConfigPolicyErr,
)
if err != nil {
ConfigPolicyWarning(err.Error())
return err
}
}

if globalConfigPolicies.InvariantConfig != nil {
globalConfig, err = UnmarshalConfigPolicy[expconf.ExperimentConfig](
*globalConfigPolicies.InvariantConfig,
InvalidExperimentConfigPolicyErr,
)
if err != nil {
ConfigPolicyWarning(err.Error())
return err
}
}
configPolicyOverlap(globalConfigPolicies.Constraints, cp.Constraints)
configPolicyOverlap(globalConfigPolicies.InvariantConfig, cp.InvariantConfig)

warnConfigPolicyOverlap(globalConstraints, cp.Constraints)
warnConfigPolicyOverlap(globalConfig, cp.InvariantConfig)
}

if cp.Constraints != nil {
Expand All @@ -110,11 +128,9 @@ func ValidateExperimentConfig(
}

// Verify the workspace invariant config doesn't conflict with global constraints.
if globalConstraints != nil {
if err := checkConstraintConflicts(globalConstraints, cp.InvariantConfig.RawResources.RawMaxSlots,
cp.InvariantConfig.RawResources.RawSlotsPerTrial, cp.InvariantConfig.RawResources.RawPriority); err != nil {
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(InvalidExperimentConfigPolicyErr+": %s.", err))
}
if err := checkConstraintConflicts(globalConstraints, cp.InvariantConfig.RawResources.RawMaxSlots,
cp.InvariantConfig.RawResources.RawSlotsPerTrial, cp.InvariantConfig.RawResources.RawPriority); err != nil {
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(InvalidExperimentConfigPolicyErr+": %s.", err))
}
}
}
Expand Down Expand Up @@ -142,16 +158,27 @@ func ValidateNTSCConfig(
// policies (since these fields will be overridden by the respective fields in the global
// policies).
var globalConstraints *model.Constraints
var globalConfig *model.CommandConfig
if globalConfigPolicies != nil {
globalConstraints, err = UnmarshalConfigPolicy[model.Constraints](*globalConfigPolicies.Constraints,
InvalidNTSCConfigPolicyErr)
if err != nil {
ConfigPolicyWarning(err.Error())
return nil
if globalConfigPolicies.Constraints != nil {
globalConstraints, err = UnmarshalConfigPolicy[model.Constraints](*globalConfigPolicies.Constraints,
InvalidNTSCConfigPolicyErr)
if err != nil {
ConfigPolicyWarning(err.Error())
return err
}
}
configPolicyOverlap(globalConfigPolicies.Constraints, cp.Constraints)
configPolicyOverlap(
globalConfigPolicies.InvariantConfig, cp.InvariantConfig)
if globalConfigPolicies.InvariantConfig != nil {
globalConfig, err = UnmarshalConfigPolicy[model.CommandConfig](*globalConfigPolicies.InvariantConfig,
InvalidNTSCConfigPolicyErr)
if err != nil {
ConfigPolicyWarning(err.Error())
return err
}
}

warnConfigPolicyOverlap(globalConstraints, cp.Constraints)
warnConfigPolicyOverlap(globalConfig, cp.InvariantConfig)
}

if cp.Constraints != nil {
Expand All @@ -175,12 +202,10 @@ func ValidateNTSCConfig(
}

// Verify the workspace invariant config conflict with global constraints.
if globalConstraints != nil {
if err := checkConstraintConflicts(globalConstraints,
cp.InvariantConfig.Resources.MaxSlots, slots,
cp.InvariantConfig.Resources.Priority); err != nil {
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(InvalidNTSCConfigPolicyErr+": %s.", err))
}
if err := checkConstraintConflicts(globalConstraints,
cp.InvariantConfig.Resources.MaxSlots, slots,
cp.InvariantConfig.Resources.Priority); err != nil {
return status.Errorf(codes.InvalidArgument, fmt.Sprintf(InvalidNTSCConfigPolicyErr+": %s.", err))
}
}

Expand Down Expand Up @@ -217,9 +242,16 @@ func checkConstraintConflicts(constraints *model.Constraints, maxSlots, slots, p
return nil
}

// configPolicyOverlap compares two different configurations and
// warnConfigPolicyOverlap compares two different configurations and
// warns the user when both configurations define the same field.
func configPolicyOverlap(config1, config2 interface{}) {
func warnConfigPolicyOverlap(config1, config2 interface{}) {
if reflect.ValueOf(config1).Type() != reflect.ValueOf(config2).Type() &&
reflect.ValueOf(config1).Type() != reflect.ValueOf(&model.Constraints{}).Type() &&
reflect.ValueOf(config1).Type() != reflect.ValueOf(&model.CommandConfig{}).Type() &&
reflect.ValueOf(config1).Type() != reflect.ValueOf(&expconf.ExperimentConfig{}).Type() {
return
}

v1 := reflect.ValueOf(config1)
v2 := reflect.ValueOf(config2)

Expand Down
2 changes: 1 addition & 1 deletion master/internal/configpolicy/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ resources:
`
invariant_config:
resources:
max_slots: 15
max_slots: 15
`, nil,
)
require.NoError(t, err)
Expand Down

0 comments on commit f566a0c

Please sign in to comment.