Skip to content

Commit

Permalink
Merge pull request #408 from sangkenlee/enhanced-extract-param
Browse files Browse the repository at this point in the history
파라미터 추출 기능 개선
  • Loading branch information
ktkfree authored Apr 22, 2024
2 parents 0abb93b + 4468e9c commit 067e636
Showing 1 changed file with 116 additions and 7 deletions.
123 changes: 116 additions & 7 deletions internal/policy-template/policy-template-rego.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"strings"

"github.com/open-policy-agent/opa/ast"
"github.com/open-policy-agent/opa/types"
"github.com/openinfradev/tks-api/internal/model"
"github.com/openinfradev/tks-api/pkg/domain"
"golang.org/x/exp/maps"
Expand All @@ -16,16 +17,85 @@ const (
input_param_prefix = "input.parameters"
input_extract_pattern = `input(\.parameters|\[\"parameters\"\])((\[\"[\w\-]+\"\])|(\[_\])|(\.\w+))*` //(\.\w+)*` // (\.\w+\[\"\w+\"\])|(\.\w+\[\w+\])|(\.\w+))*`
// input_extract_pattern = `input\.parameters((\[\".+\"\])?(\.\w+\[\"\w+\"\])|(\.\w+\[\w+\])|(\.\w+))+`
obj_get_pattern = `object\.get\((input|input\.parameters|input\.parameters\.[^,]+)\, \"*([^,\"]+)\"*, [^\)]+\)`
package_name_regex = `package ([\w\.]+)[\n\r]+`
import_regex = `import ([\w\.]+)[\n\r]+`
obj_get_list_pattern = `object\.get\((input|input\.parameters|input\.parameters\.[^,]+), \"([^\"]+)\", \[\]\)`
obj_get_pattern = `object\.get\((input|input\.parameters|input\.parameters\.[^,]+), \"([^\"]+)\", [^\)]+\)`
package_name_regex = `package ([\w\.]+)[\n\r]+`
import_regex = `import ([\w\.]+)[\n\r]+`
)

var (
input_extract_regex = regexp.MustCompile(input_extract_pattern)
obj_get_list_regex = regexp.MustCompile(obj_get_list_pattern)
obj_get_regex = regexp.MustCompile(obj_get_pattern)
array_param_map = buildArrayParameterMap()
)

// OPA 내장 함수 중 array 인자를 가진 함수와 array 인자의 위치를 담은 자료구조를 생성한다.
// // OPA 내장 함수 목록은 정책에 상관없으므로 처음 로딩될 때 한 번만 호출해 변수에 담아두고 사용하면 된다.
// OPA 엔진 버전에 따라 달라진다. 0.62 버전 기준으로 이 함수의 결과 값은 다음과 같다.
// map[all:[true] any:[true] array.concat:[true true] array.reverse:[true] array.slice:[true false false] concat:[false true] count:[true] glob.match:[false true false] graph.reachable:[false true] graph.reachable_paths:[false true] internal.print:[true] json.filter:[false true] json.patch:[false true] json.remove:[false true] max:[true] min:[true] net.cidr_contains_matches:[true true] net.cidr_merge:[true] object.filter:[false true] object.remove:[false true] object.subset:[true true] object.union_n:[true] product:[true] sort:[true] sprintf:[false true] strings.any_prefix_match:[true true] strings.any_suffix_match:[true true] sum:[true] time.clock:[true] time.date:[true] time.diff:[true true] time.format:[true] time.weekday:[true]]
func buildArrayParameterMap() map[string][]bool {
compiler := ast.NewCompiler()

// 아주 단순한 rego 코드를 컴파일해도 컴파일러의 모든 Built-In 함수 정보를 컴파일 할 수 있음
mod, err := ast.ParseModuleWithOpts("hello", "package hello\n hello {input.message = \"world\"}", ast.ParserOptions{})
if err != nil {
return nil
}

// 컴파일을 수행해야 Built-in 함수 정보 로딩할 수 있음
modules := map[string]*ast.Module{}
modules["hello"] = mod
compiler.Compile(modules)

return getArrayParameterMap(compiler)
}

func getArrayParameterMap(compiler *ast.Compiler) map[string][]bool {
capabilities := compiler.Capabilities()

var result = map[string][]bool{}

if capabilities != nil {
for _, builtin := range capabilities.Builtins {
args := builtin.Decl.FuncArgs().Args
isArrayParam := make([]bool, len(args))

arrayCount := 0
for i, typeVal := range args {
isArrayParam[i] = IsArray(typeVal)

if isArrayParam[i] {
arrayCount += 1
}
}

if arrayCount > 0 {
result[builtin.Name] = isArrayParam
}
}
}

return result
}

func IsArray(t types.Type) bool {
switch specific_type := t.(type) {
case *types.Array:
return true
case types.Any:
{
for _, anyType := range specific_type {
if IsArray(anyType) {
return true
}
}
}
}

return false
}

func extractInputExprFromModule(module *ast.Module) []string {
rules := module.Rules

Expand Down Expand Up @@ -69,6 +139,9 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str
exprs := rule.Body
localAssignMap := map[string]string{}

// if strings.Contains(rule.String(), "statefulset_vct_noname_msg") {
// fmt.Printf("1111111 %+ v%+v\n", rule.Head.Name, rule.Head.Key.String())
// }
for i, param := range passedParams {
if isSubstitutionRequired(param) {
argName := rule.Head.Args[i].String()
Expand All @@ -95,14 +168,26 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str

if expr.IsCall() {
call, _ := expr.Terms.([]*ast.Term)
if len(call) > 2 {
if len(call) > 1 {
ruleName := call[0].String()

args := call[1:]

inputPassed, passingParams := processingInputArgs(args, localAssignMap)

if inputPassed {
if is_arrays, ok := array_param_map[ruleName]; ok {
for i, passingParam := range passingParams {
is_array := is_arrays[i]

if is_array && strings.HasPrefix(passingParam, "input.parameters.") &&
!strings.HasSuffix(passingParam, "[_]") {

paramRefs[passingParam+"[_]"] = "1"
}
}
}

for _, nvrule := range nonViolatonRule {
if ruleName == nvrule.Head.Name.String() {

Expand Down Expand Up @@ -163,6 +248,17 @@ func processRule(rule *ast.Rule, paramRefs map[string]string, passedParams []str
return false
})
}

headKey := rule.Head.Key

if headKey != nil {
for _, nvrule := range nonViolatonRule {
ruleName := nvrule.Head.Name.String()
if strings.Contains(headKey.String(), ruleName) {
processRule(nvrule, paramRefs, []string{}, nonViolatonRule)
}
}
}
}

// object.get(object.get(input, "parameters", {}), "exemptImages", [])) -> input.parameters.exemptImages와 같은 패턴 변환
Expand All @@ -171,8 +267,10 @@ func replaceAllObjectGet(expr string) string {
return expr
}

result := obj_get_regex.ReplaceAllString(expr, "$1.$2")
result := obj_get_list_regex.ReplaceAllString(expr, "$1.$2"+`[_]`)
result = obj_get_regex.ReplaceAllString(result, "$1.$2")

// 정규식의 영향 없음 그냥 리턴
if result == expr {
return expr
}
Expand All @@ -194,8 +292,9 @@ func processingInputArgs(args []*ast.Term, localAssignMap map[string]string) (bo
if isSubstitutionRequired(arg) {
passingParams = append(passingParams, arg)
inputPassed = true
} else {
passingParams = append(passingParams, "")
}
passingParams = append(passingParams, "")
}
}
return inputPassed, passingParams
Expand Down Expand Up @@ -333,14 +432,24 @@ func findKey(defs []*domain.ParameterDef, key string) *domain.ParameterDef {
return nil
}

func cutTrailingArrayNote(val string) string {
cut, found := strings.CutSuffix(val, "[_]")

if found {
return cutTrailingArrayNote(cut)
}

return val
}

func createKey(key string, isLast bool) *domain.ParameterDef {
var finalType string

pKey := key
isArray := false

if strings.HasSuffix(pKey, "[_]") {
pKey, _ = strings.CutSuffix(pKey, "[_]")
pKey = cutTrailingArrayNote(pKey)
isArray = true
}

Expand Down

0 comments on commit 067e636

Please sign in to comment.