diff --git a/internal/policy-template/policy-template-rego.go b/internal/policy-template/policy-template-rego.go index e5f83937..8a9e3731 100644 --- a/internal/policy-template/policy-template-rego.go +++ b/internal/policy-template/policy-template-rego.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/open-policy-agent/opa/ast" + "github.com/open-policy-agent/opa/types" "github.com/openinfradev/tks-api/internal/model" "github.com/openinfradev/tks-api/pkg/domain" "golang.org/x/exp/maps" @@ -16,16 +17,85 @@ const ( input_param_prefix = "input.parameters" input_extract_pattern = `input(\.parameters|\[\"parameters\"\])((\[\"[\w\-]+\"\])|(\[_\])|(\.\w+))*` //(\.\w+)*` // (\.\w+\[\"\w+\"\])|(\.\w+\[\w+\])|(\.\w+))*` // input_extract_pattern = `input\.parameters((\[\".+\"\])?(\.\w+\[\"\w+\"\])|(\.\w+\[\w+\])|(\.\w+))+` - obj_get_pattern = `object\.get\((input|input\.parameters|input\.parameters\.[^,]+)\, \"*([^,\"]+)\"*, [^\)]+\)` - package_name_regex = `package ([\w\.]+)[\n\r]+` - import_regex = `import ([\w\.]+)[\n\r]+` + obj_get_list_pattern = `object\.get\((input|input\.parameters|input\.parameters\.[^,]+), \"([^\"]+)\", \[\]\)` + obj_get_pattern = `object\.get\((input|input\.parameters|input\.parameters\.[^,]+), \"([^\"]+)\", [^\)]+\)` + package_name_regex = `package ([\w\.]+)[\n\r]+` + import_regex = `import ([\w\.]+)[\n\r]+` ) var ( input_extract_regex = regexp.MustCompile(input_extract_pattern) + obj_get_list_regex = regexp.MustCompile(obj_get_list_pattern) obj_get_regex = regexp.MustCompile(obj_get_pattern) + array_param_map = buildArrayParameterMap() ) +// OPA 내장 함수 중 array 인자를 가진 함수와 array 인자의 위치를 담은 자료구조를 생성한다. +// // OPA 내장 함수 목록은 정책에 상관없으므로 처음 로딩될 때 한 번만 호출해 변수에 담아두고 사용하면 된다. +// OPA 엔진 버전에 따라 달라진다. 0.62 버전 기준으로 이 함수의 결과 값은 다음과 같다. +// map[all:[true] any:[true] array.concat:[true true] array.reverse:[true] array.slice:[true false false] concat:[false true] count:[true] glob.match:[false true false] graph.reachable:[false true] graph.reachable_paths:[false true] internal.print:[true] json.filter:[false true] json.patch:[false true] json.remove:[false true] max:[true] min:[true] net.cidr_contains_matches:[true true] net.cidr_merge:[true] object.filter:[false true] object.remove:[false true] object.subset:[true true] object.union_n:[true] product:[true] sort:[true] sprintf:[false true] strings.any_prefix_match:[true true] strings.any_suffix_match:[true true] sum:[true] time.clock:[true] time.date:[true] time.diff:[true true] time.format:[true] time.weekday:[true]] +func buildArrayParameterMap() map[string][]bool { + compiler := ast.NewCompiler() + + // 아주 단순한 rego 코드를 컴파일해도 컴파일러의 모든 Built-In 함수 정보를 컴파일 할 수 있음 + mod, err := ast.ParseModuleWithOpts("hello", "package hello\n hello {input.message = \"world\"}", ast.ParserOptions{}) + if err != nil { + return nil + } + + // 컴파일을 수행해야 Built-in 함수 정보 로딩할 수 있음 + modules := map[string]*ast.Module{} + modules["hello"] = mod + compiler.Compile(modules) + + return getArrayParameterMap(compiler) +} + +func getArrayParameterMap(compiler *ast.Compiler) map[string][]bool { + capabilities := compiler.Capabilities() + + var result = map[string][]bool{} + + if capabilities != nil { + for _, builtin := range capabilities.Builtins { + args := builtin.Decl.FuncArgs().Args + isArrayParam := make([]bool, len(args)) + + arrayCount := 0 + for i, typeVal := range args { + isArrayParam[i] = IsArray(typeVal) + + if isArrayParam[i] { + arrayCount += 1 + } + } + + if arrayCount > 0 { + result[builtin.Name] = isArrayParam + } + } + } + + return result +} + +func IsArray(t types.Type) bool { + switch specific_type := t.(type) { + case *types.Array: + return true + case types.Any: + { + for _, anyType := range specific_type { + if IsArray(anyType) { + return true + } + } + } + } + + return false +} + func extractInputExprFromModule(module *ast.Module) []string { rules := module.Rules @@ -69,6 +139,9 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str exprs := rule.Body localAssignMap := map[string]string{} + // if strings.Contains(rule.String(), "statefulset_vct_noname_msg") { + // fmt.Printf("1111111 %+ v%+v\n", rule.Head.Name, rule.Head.Key.String()) + // } for i, param := range passedParams { if isSubstitutionRequired(param) { argName := rule.Head.Args[i].String() @@ -95,7 +168,7 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str if expr.IsCall() { call, _ := expr.Terms.([]*ast.Term) - if len(call) > 2 { + if len(call) > 1 { ruleName := call[0].String() args := call[1:] @@ -103,6 +176,18 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str inputPassed, passingParams := processingInputArgs(args, localAssignMap) if inputPassed { + if is_arrays, ok := array_param_map[ruleName]; ok { + for i, passingParam := range passingParams { + is_array := is_arrays[i] + + if is_array && strings.HasPrefix(passingParam, "input.parameters.") && + !strings.HasSuffix(passingParam, "[_]") { + + paramRefs[passingParam+"[_]"] = "1" + } + } + } + for _, nvrule := range nonViolatonRule { if ruleName == nvrule.Head.Name.String() { @@ -163,6 +248,17 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str return false }) } + + headKey := rule.Head.Key + + if headKey != nil { + for _, nvrule := range nonViolatonRule { + ruleName := nvrule.Head.Name.String() + if strings.Contains(headKey.String(), ruleName) { + processRule(nvrule, paramRefs, []string{}, nonViolatonRule) + } + } + } } // object.get(object.get(input, "parameters", {}), "exemptImages", [])) -> input.parameters.exemptImages와 같은 패턴 변환 @@ -171,8 +267,10 @@ func replaceAllObjectGet(expr string) string { return expr } - result := obj_get_regex.ReplaceAllString(expr, "$1.$2") + result := obj_get_list_regex.ReplaceAllString(expr, "$1.$2"+`[_]`) + result = obj_get_regex.ReplaceAllString(result, "$1.$2") + // 정규식의 영향 없음 그냥 리턴 if result == expr { return expr } @@ -194,8 +292,9 @@ func processingInputArgs(args []*ast.Term, localAssignMap map[string]string) (bo if isSubstitutionRequired(arg) { passingParams = append(passingParams, arg) inputPassed = true + } else { + passingParams = append(passingParams, "") } - passingParams = append(passingParams, "") } } return inputPassed, passingParams @@ -333,6 +432,16 @@ func findKey(defs []*domain.ParameterDef, key string) *domain.ParameterDef { return nil } +func cutTrailingArrayNote(val string) string { + cut, found := strings.CutSuffix(val, "[_]") + + if found { + return cutTrailingArrayNote(cut) + } + + return val +} + func createKey(key string, isLast bool) *domain.ParameterDef { var finalType string @@ -340,7 +449,7 @@ func createKey(key string, isLast bool) *domain.ParameterDef { isArray := false if strings.HasSuffix(pKey, "[_]") { - pKey, _ = strings.CutSuffix(pKey, "[_]") + pKey = cutTrailingArrayNote(pKey) isArray = true }