Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

파라미터 추출 기능 개선 #408

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 116 additions & 7 deletions internal/policy-template/policy-template-rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/types"
"github.com/openinfradev/tks-api/internal/model"
"github.com/openinfradev/tks-api/pkg/domain"
"golang.org/x/exp/maps"
Expand All @@ -16,16 +17,85 @@ const (
input_param_prefix = "input.parameters"
input_extract_pattern = `input(\.parameters|\[\"parameters\"\])((\[\"[\w\-]+\"\])|(\[_\])|(\.\w+))*` //(\.\w+)*` // (\.\w+\[\"\w+\"\])|(\.\w+\[\w+\])|(\.\w+))*`
// input_extract_pattern = `input\.parameters((\[\".+\"\])?(\.\w+\[\"\w+\"\])|(\.\w+\[\w+\])|(\.\w+))+`
obj_get_pattern = `object\.get\((input|input\.parameters|input\.parameters\.[^,]+)\, \"*([^,\"]+)\"*, [^\)]+\)`
package_name_regex = `package ([\w\.]+)[\n\r]+`
import_regex = `import ([\w\.]+)[\n\r]+`
obj_get_list_pattern = `object\.get\((input|input\.parameters|input\.parameters\.[^,]+), \"([^\"]+)\", \[\]\)`
obj_get_pattern = `object\.get\((input|input\.parameters|input\.parameters\.[^,]+), \"([^\"]+)\", [^\)]+\)`
package_name_regex = `package ([\w\.]+)[\n\r]+`
import_regex = `import ([\w\.]+)[\n\r]+`
)

var (
input_extract_regex = regexp.MustCompile(input_extract_pattern)
obj_get_list_regex = regexp.MustCompile(obj_get_list_pattern)
obj_get_regex = regexp.MustCompile(obj_get_pattern)
array_param_map = buildArrayParameterMap()
)

// OPA 내장 함수 중 array 인자를 가진 함수와 array 인자의 위치를 담은 자료구조를 생성한다.
// // OPA 내장 함수 목록은 정책에 상관없으므로 처음 로딩될 때 한 번만 호출해 변수에 담아두고 사용하면 된다.
// OPA 엔진 버전에 따라 달라진다. 0.62 버전 기준으로 이 함수의 결과 값은 다음과 같다.
// map[all:[true] any:[true] array.concat:[true true] array.reverse:[true] array.slice:[true false false] concat:[false true] count:[true] glob.match:[false true false] graph.reachable:[false true] graph.reachable_paths:[false true] internal.print:[true] json.filter:[false true] json.patch:[false true] json.remove:[false true] max:[true] min:[true] net.cidr_contains_matches:[true true] net.cidr_merge:[true] object.filter:[false true] object.remove:[false true] object.subset:[true true] object.union_n:[true] product:[true] sort:[true] sprintf:[false true] strings.any_prefix_match:[true true] strings.any_suffix_match:[true true] sum:[true] time.clock:[true] time.date:[true] time.diff:[true true] time.format:[true] time.weekday:[true]]
func buildArrayParameterMap() map[string][]bool {
compiler := ast.NewCompiler()

// 아주 단순한 rego 코드를 컴파일해도 컴파일러의 모든 Built-In 함수 정보를 컴파일 할 수 있음
mod, err := ast.ParseModuleWithOpts("hello", "package hello\n hello {input.message = \"world\"}", ast.ParserOptions{})
if err != nil {
return nil
}

// 컴파일을 수행해야 Built-in 함수 정보 로딩할 수 있음
modules := map[string]*ast.Module{}
modules["hello"] = mod
compiler.Compile(modules)

return getArrayParameterMap(compiler)
}

func getArrayParameterMap(compiler *ast.Compiler) map[string][]bool {
capabilities := compiler.Capabilities()

var result = map[string][]bool{}

if capabilities != nil {
for _, builtin := range capabilities.Builtins {
args := builtin.Decl.FuncArgs().Args
isArrayParam := make([]bool, len(args))

arrayCount := 0
for i, typeVal := range args {
isArrayParam[i] = IsArray(typeVal)

if isArrayParam[i] {
arrayCount += 1
}
}

if arrayCount > 0 {
result[builtin.Name] = isArrayParam
}
}
}

return result
}

func IsArray(t types.Type) bool {
switch specific_type := t.(type) {
case *types.Array:
return true
case types.Any:
{
for _, anyType := range specific_type {
if IsArray(anyType) {
return true
}
}
}
}

return false
}

func extractInputExprFromModule(module *ast.Module) []string {
rules := module.Rules

Expand Down Expand Up @@ -69,6 +139,9 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str
exprs := rule.Body
localAssignMap := map[string]string{}

// if strings.Contains(rule.String(), "statefulset_vct_noname_msg") {
// fmt.Printf("1111111 %+ v%+v\n", rule.Head.Name, rule.Head.Key.String())
// }
for i, param := range passedParams {
if isSubstitutionRequired(param) {
argName := rule.Head.Args[i].String()
Expand All @@ -95,14 +168,26 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str

if expr.IsCall() {
call, _ := expr.Terms.([]*ast.Term)
if len(call) > 2 {
if len(call) > 1 {
ruleName := call[0].String()

args := call[1:]

inputPassed, passingParams := processingInputArgs(args, localAssignMap)

if inputPassed {
if is_arrays, ok := array_param_map[ruleName]; ok {
for i, passingParam := range passingParams {
is_array := is_arrays[i]

if is_array && strings.HasPrefix(passingParam, "input.parameters.") &&
!strings.HasSuffix(passingParam, "[_]") {

paramRefs[passingParam+"[_]"] = "1"
}
}
}

for _, nvrule := range nonViolatonRule {
if ruleName == nvrule.Head.Name.String() {

Expand Down Expand Up @@ -163,6 +248,17 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str
return false
})
}

headKey := rule.Head.Key

if headKey != nil {
for _, nvrule := range nonViolatonRule {
ruleName := nvrule.Head.Name.String()
if strings.Contains(headKey.String(), ruleName) {
processRule(nvrule, paramRefs, []string{}, nonViolatonRule)
}
}
}
}

// object.get(object.get(input, "parameters", {}), "exemptImages", [])) -> input.parameters.exemptImages와 같은 패턴 변환
Expand All @@ -171,8 +267,10 @@ func replaceAllObjectGet(expr string) string {
return expr
}

result := obj_get_regex.ReplaceAllString(expr, "$1.$2")
result := obj_get_list_regex.ReplaceAllString(expr, "$1.$2"+`[_]`)
result = obj_get_regex.ReplaceAllString(result, "$1.$2")

// 정규식의 영향 없음 그냥 리턴
if result == expr {
return expr
}
Expand All @@ -194,8 +292,9 @@ func processingInputArgs(args []*ast.Term, localAssignMap map[string]string) (bo
if isSubstitutionRequired(arg) {
passingParams = append(passingParams, arg)
inputPassed = true
} else {
passingParams = append(passingParams, "")
}
passingParams = append(passingParams, "")
}
}
return inputPassed, passingParams
Expand Down Expand Up @@ -333,14 +432,24 @@ func findKey(defs []*domain.ParameterDef, key string) *domain.ParameterDef {
return nil
}

func cutTrailingArrayNote(val string) string {
cut, found := strings.CutSuffix(val, "[_]")

if found {
return cutTrailingArrayNote(cut)
}

return val
}

func createKey(key string, isLast bool) *domain.ParameterDef {
var finalType string

pKey := key
isArray := false

if strings.HasSuffix(pKey, "[_]") {
pKey, _ = strings.CutSuffix(pKey, "[_]")
pKey = cutTrailingArrayNote(pKey)
isArray = true
}

Expand Down
Loading