Skip to content

Commit

Permalink
chore: add unmarshal functions for task config policies (#9896)
Browse files Browse the repository at this point in the history
  • Loading branch information
kkunapuli authored Sep 6, 2024
1 parent d842383 commit 4a28c10
Show file tree
Hide file tree
Showing 8 changed files with 444 additions and 97 deletions.
19 changes: 8 additions & 11 deletions harness/determined/common/api/bindings.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

41 changes: 9 additions & 32 deletions master/internal/api_config_policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,13 @@ import (
"google.golang.org/protobuf/types/known/structpb"
"gopkg.in/yaml.v3"

"github.com/determined-ai/determined/master/internal/configpolicy"
"github.com/determined-ai/determined/proto/pkg/apiv1"
)

// workload types are used to distinguish configuration for experiments and NTSCs.
const (
// Should not be used.
WorkloadTypeUnspecified = "UNKNOWN"
// Configuration policies for experiments.
WorkloadTypeExperiment = "EXPERIMENT"
// Configuration policies for NTSC.
WorkloadTypeNTSC = "NTSC"
)

func validWorkloadEnum(val string) bool {
switch val {
case WorkloadTypeExperiment, WorkloadTypeNTSC:
return true
default:
return false
}
}

func stubData() (*structpb.Struct, error) {
const yamlString = `
invariant_configs:
invariant_config:
description: "test"
constraints:
resources:
Expand All @@ -44,19 +26,14 @@ constraints:
return nil, fmt.Errorf("unable to unmarshal yaml: %w", err)
}

// convert map to protobuf struct
yamlStruct, err := structpb.NewStruct(yamlMap)
if err != nil {
return nil, fmt.Errorf("unable to convert map to protobuf struct: %w", err)
}
return yamlStruct, nil
return configpolicy.MarshalConfigPolicy(yamlMap), nil
}

// Add or update workspace task config policies.
func (*apiServer) PutWorkspaceConfigPolicies(
ctx context.Context, req *apiv1.PutWorkspaceConfigPoliciesRequest,
) (*apiv1.PutWorkspaceConfigPoliciesResponse, error) {
if !validWorkloadEnum(req.WorkloadType) {
if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
}
data, err := stubData()
Expand All @@ -67,7 +44,7 @@ func (*apiServer) PutWorkspaceConfigPolicies(
func (*apiServer) PutGlobalConfigPolicies(
ctx context.Context, req *apiv1.PutGlobalConfigPoliciesRequest,
) (*apiv1.PutGlobalConfigPoliciesResponse, error) {
if !validWorkloadEnum(req.WorkloadType) {
if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
}
data, err := stubData()
Expand All @@ -78,7 +55,7 @@ func (*apiServer) PutGlobalConfigPolicies(
func (*apiServer) GetWorkspaceConfigPolicies(
ctx context.Context, req *apiv1.GetWorkspaceConfigPoliciesRequest,
) (*apiv1.GetWorkspaceConfigPoliciesResponse, error) {
if !validWorkloadEnum(req.WorkloadType) {
if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
}
data, err := stubData()
Expand All @@ -89,7 +66,7 @@ func (*apiServer) GetWorkspaceConfigPolicies(
func (*apiServer) GetGlobalConfigPolicies(
ctx context.Context, req *apiv1.GetGlobalConfigPoliciesRequest,
) (*apiv1.GetGlobalConfigPoliciesResponse, error) {
if !validWorkloadEnum(req.WorkloadType) {
if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
}
data, err := stubData()
Expand All @@ -100,7 +77,7 @@ func (*apiServer) GetGlobalConfigPolicies(
func (*apiServer) DeleteWorkspaceConfigPolicies(
ctx context.Context, req *apiv1.DeleteWorkspaceConfigPoliciesRequest,
) (*apiv1.DeleteWorkspaceConfigPoliciesResponse, error) {
if !validWorkloadEnum(req.WorkloadType) {
if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
}
return &apiv1.DeleteWorkspaceConfigPoliciesResponse{}, nil
Expand All @@ -110,7 +87,7 @@ func (*apiServer) DeleteWorkspaceConfigPolicies(
func (*apiServer) DeleteGlobalConfigPolicies(
ctx context.Context, req *apiv1.DeleteGlobalConfigPoliciesRequest,
) (*apiv1.DeleteGlobalConfigPoliciesResponse, error) {
if !validWorkloadEnum(req.WorkloadType) {
if !configpolicy.ValidWorkloadType(req.WorkloadType) {
return nil, fmt.Errorf("invalid workload type: %s", req.WorkloadType)
}
return &apiv1.DeleteGlobalConfigPoliciesResponse{}, nil
Expand Down
4 changes: 2 additions & 2 deletions master/internal/configpolicy/task_config_policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type Constraints struct {
// Submitted experiments whose constraint fields vary from the respective Constraint fields set
// within a given scope are rejected.
type ExperimentConfigPolicy struct {
InvariantConfig *expconf.ExperimentConfig `json:"config"`
InvariantConfig *expconf.ExperimentConfig `json:"invariant_config"`
Constraints *Constraints `json:"constraints"`
}

Expand All @@ -36,6 +36,6 @@ type ExperimentConfigPolicy struct {
// Submitted NTSC tasks whose constraint fields vary from the respective Constraint fields set
// within a given scope are rejected.
type NTSCConfigPolicy struct {
InvariantConfig *model.CommandConfig `json:"config"`
InvariantConfig *model.CommandConfig `json:"invariant_config"`
Constraints *Constraints `json:"constraints"`
}
68 changes: 68 additions & 0 deletions master/internal/configpolicy/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package configpolicy

import (
"bytes"
"encoding/json"
"fmt"

"github.com/ghodss/yaml"
"google.golang.org/protobuf/types/known/structpb"

"github.com/determined-ai/determined/master/pkg/model"
"github.com/determined-ai/determined/master/pkg/protoutils"
)

// ValidWorkloadType checks if the string is an accepted WorkloadType.
func ValidWorkloadType(val string) bool {
switch val {
case string(model.ExperimentType), string(model.NTSCType):
return true
default:
return false
}
}

// UnmarshalExperimentConfigPolicy unpacks a string into ExperimentConfigPolicy struct.
func UnmarshalExperimentConfigPolicy(str string) (*ExperimentConfigPolicy, error) {
var expConfigPolicy ExperimentConfigPolicy
var err error

dec := json.NewDecoder(bytes.NewReader([]byte(str)))
dec.DisallowUnknownFields()
if err = dec.Decode(&expConfigPolicy); err == nil {
// valid JSON input
return &expConfigPolicy, nil
}

if err = yaml.Unmarshal([]byte(str), &expConfigPolicy, yaml.DisallowUnknownFields); err == nil {
// valid Yaml input
return &expConfigPolicy, nil
}
// invalid JSON & invalid Yaml input
return nil, fmt.Errorf("invalid ExperimentConfigPolicy input: %w", err)
}

// UnmarshalNTSCConfigPolicy unpacks a string into NTSCConfigPolicy struct.
func UnmarshalNTSCConfigPolicy(str string) (*NTSCConfigPolicy, error) {
var ntscConfigPolicy NTSCConfigPolicy
var err error

dec := json.NewDecoder(bytes.NewReader([]byte(str)))
dec.DisallowUnknownFields()
if err = dec.Decode(&ntscConfigPolicy); err == nil {
// valid JSON input
return &ntscConfigPolicy, nil
}

if err = yaml.Unmarshal([]byte(str), &ntscConfigPolicy, yaml.DisallowUnknownFields); err == nil {
// valid Yaml input
return &ntscConfigPolicy, nil
}
// invalid JSON & Yaml input
return nil, fmt.Errorf("invalid ExperimentConfigPolicy input: %w", err)
}

// MarshalConfigPolicy packs a config policy into a proto struct.
func MarshalConfigPolicy(configPolicy interface{}) *structpb.Struct {
return protoutils.ToStruct(configPolicy)
}
Loading

0 comments on commit 4a28c10

Please sign in to comment.