Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
djaglowski committed Sep 24, 2024
1 parent 8d07555 commit 35fb391
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 47 deletions.
2 changes: 1 addition & 1 deletion pkg/ottl/ottlfuncs/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1293,7 +1293,7 @@ Examples:
The `RemoveXML` Converter returns an edited version of an XML string with selected elements removed.

`target` is a Getter that returns a string. This string should be in XML format.
If `target` is not a string, nil, or cannot be parsed as XML, `ParseXML` will return an error.
If `target` is not a string, nil, or is not valid xml, `RemoveXML` will return an error.

`xpath` is a string that specifies an [XPath](https://www.w3.org/TR/1999/REC-xpath-19991116/) expression that
selects one or more elements to remove from the XML document.
Expand Down
58 changes: 34 additions & 24 deletions pkg/ottl/ottlfuncs/func_remove_xml.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,36 +30,24 @@ func createRemoveXMLFunction[K any](_ ottl.FunctionContext, oArgs ottl.Arguments
return nil, fmt.Errorf("RemoveXML args must be of type *RemoveXMLAguments[K]")
}

if err := validateXPath(args.XPath); err != nil {
return nil, err
}

return removeXML(args.Target, args.XPath), nil
}

// removeXML returns a `pcommon.String` that is a result of renaming the target XML or s
// removeXML returns a XML formatted string that is a result of removing all matching nodes from the target XML.
// This currently supports removal of elements, attributes, text values, comments, and CharData.
func removeXML[K any](target ottl.StringGetter[K], xPath string) ottl.ExprFunc[K] {
return func(ctx context.Context, tCtx K) (any, error) {
targetVal, err := target.Get(ctx, tCtx)
if err != nil {
var doc *xmlquery.Node
if targetVal, err := target.Get(ctx, tCtx); err != nil {
return nil, err
} else if doc, err = parseNodesXML(targetVal); err != nil {
return nil, err
}

// Validate the xpath
_, err = xpath.Compile(xPath)
if err != nil {
return nil, fmt.Errorf("invalid xpath: %w", err)
}

// Check if the xml starts with a declaration node
preserveDeclearation := strings.HasPrefix(targetVal, "<?xml")

top, err := xmlquery.Parse(strings.NewReader(targetVal))
if err != nil {
return nil, fmt.Errorf("parse xml: %w", err)
}

if !preserveDeclearation {
xmlquery.RemoveFromTree(top.FirstChild)
}

xmlquery.FindEach(top, xPath, func(_ int, n *xmlquery.Node) {
xmlquery.FindEach(doc, xPath, func(_ int, n *xmlquery.Node) {
switch n.Type {
case xmlquery.ElementNode:
xmlquery.RemoveFromTree(n)
Expand All @@ -73,7 +61,29 @@ func removeXML[K any](target ottl.StringGetter[K], xPath string) ottl.ExprFunc[K
xmlquery.RemoveFromTree(n)
}
})
return doc.OutputXML(false), nil
}
}

return top.OutputXML(false), nil
func validateXPath(xPath string) error {
_, err := xpath.Compile(xPath)
if err != nil {
return fmt.Errorf("invalid xpath: %w", err)
}
return nil
}

// Aside from parsing the XML document, this function also ensures that
// the XML declaration is included in the result only if it was present in
// the original document.
func parseNodesXML(targetVal string) (*xmlquery.Node, error) {
preserveDeclearation := strings.HasPrefix(targetVal, "<?xml")
top, err := xmlquery.Parse(strings.NewReader(targetVal))
if err != nil {
return nil, fmt.Errorf("parse xml: %w", err)
}
if !preserveDeclearation {
xmlquery.RemoveFromTree(top.FirstChild)
}
return top, nil
}
50 changes: 28 additions & 22 deletions pkg/ottl/ottlfuncs/func_remove_xml_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,33 +150,39 @@ func Test_RemoveXML(t *testing.T) {
}
}

func Test_RemoveXML_InvalidXML(t *testing.T) {
target := ottl.StandardStringGetter[any]{
Getter: func(_ context.Context, _ any) (any, error) {
return `<a>>>>>>>`, nil
},
}
exprFunc := removeXML(target, "/foo")
_, err := exprFunc(context.Background(), nil)
func TestCreateRemoveXMLFunc(t *testing.T) {
factory := NewRemoveXMLFactory[any]()
fCtx := ottl.FunctionContext{}

// Invalid arg type
exprFunc, err := factory.CreateFunction(fCtx, nil)
assert.Error(t, err)
}
assert.Nil(t, exprFunc)

func Test_RemoveXML_InvalidXPath(t *testing.T) {
target := ottl.StandardStringGetter[any]{
Getter: func(_ context.Context, _ any) (any, error) {
return `<a></a>`, nil
},
}
exprFunc := removeXML(target, "!")
_, err := exprFunc(context.Background(), nil)
// Invalid XPath should error on function creation
exprFunc, err = factory.CreateFunction(
fCtx, &RemoveXMLArguments[any]{
XPath: "!",
})
assert.Error(t, err)
}
assert.Nil(t, exprFunc)

func TestCreateRemoveXMLFunc(t *testing.T) {
exprFunc, err := createRemoveXMLFunction[any](ottl.FunctionContext{}, &RemoveXMLArguments[any]{})
// Invalid XML should error on function execution
exprFunc, err = factory.CreateFunction(
fCtx, &RemoveXMLArguments[any]{
Target: invalidXMLGetter(),
XPath: "/",
})
assert.NoError(t, err)
assert.NotNil(t, exprFunc)

_, err = createRemoveXMLFunction[any](ottl.FunctionContext{}, nil)
_, err = exprFunc(context.Background(), nil)
assert.Error(t, err)
}

func invalidXMLGetter() ottl.StandardStringGetter[any] {
return ottl.StandardStringGetter[any]{
Getter: func(_ context.Context, _ any) (any, error) {
return `<a>>>>>>>`, nil
},
}
}

0 comments on commit 35fb391

Please sign in to comment.