diff --git a/internal/instrument/chiv5.go b/internal/instrument/chiv5.go index 93ebc4d1..c769dba0 100644 --- a/internal/instrument/chiv5.go +++ b/internal/instrument/chiv5.go @@ -9,6 +9,17 @@ import ( "github.com/dave/dst" ) +func instrumentChiV5(stmt *dst.AssignStmt) []dst.Stmt { + if !isChiV5(stmt) { + return nil + } + stmt.Decorations().Start.Prepend(dd_instrumented) + return []dst.Stmt{ + stmt, + chiV5Middleware(stmt), + } +} + func isChiV5(stmt *dst.AssignStmt) bool { rhs := stmt.Rhs[0] f, ok := funcIdent(rhs) @@ -21,7 +32,7 @@ func chiV5Middleware(got *dst.AssignStmt) dst.Stmt { return nil } stmt := useMiddleware(iden.Name, "ChiV5Middleware") - wrap(stmt) + markAsInstrumented(stmt) return stmt } diff --git a/internal/instrument/chiv5_test.go b/internal/instrument/chiv5_test.go new file mode 100644 index 00000000..ed36e636 --- /dev/null +++ b/internal/instrument/chiv5_test.go @@ -0,0 +1,96 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2023-present Datadog, Inc. + +package instrument + +import ( + "fmt" + "io" + "strings" + "testing" + + "github.com/datadog/orchestrion/internal/config" + + "github.com/stretchr/testify/require" +) + +func TestChiV5(t *testing.T) { + var codeTpl = `package main + +import %s + +func register() { + %s +} +` + var wantTpl = `package main + +import ( + "github.com/datadog/orchestrion/instrument" + %s +) + +func register() { + //dd:instrumented + %s + //dd:startinstrument + %s + //dd:endinstrument +} +` + + tests := []struct { + pkg string + stmt string + want string + tmpl string + }{ + {pkg: `"github.com/go-chi/chi/v5"`, stmt: `r := chi.NewRouter()`, want: `r.Use(instrument.ChiV5Middleware())`, tmpl: wantTpl}, + {pkg: `chi "github.com/go-chi/chi/v5"`, stmt: `r := chi.NewRouter()`, want: `r.Use(instrument.ChiV5Middleware())`, tmpl: wantTpl}, + {pkg: `chiv5 "github.com/go-chi/chi/v5"`, stmt: `r := chiv5.NewRouter()`, want: `r.Use(instrument.ChiV5Middleware())`, tmpl: wantTpl}, + } + + for i, tc := range tests { + t.Run(fmt.Sprintf("tc-%d", i), func(t *testing.T) { + code := fmt.Sprintf(codeTpl, tc.pkg, tc.stmt) + reader, err := InstrumentFile("test", strings.NewReader(code), config.Config{}) + require.Nil(t, err) + got, err := io.ReadAll(reader) + require.Nil(t, err) + want := fmt.Sprintf(tc.tmpl, tc.pkg, tc.stmt, tc.want) + require.Equal(t, want, string(got)) + + reader, err = UninstrumentFile("test", strings.NewReader(want), config.Config{}) + require.Nil(t, err) + orig, err := io.ReadAll(reader) + require.Nil(t, err) + require.Equal(t, code, string(orig)) + }) + } +} + +func TestChiV5Duplicates(t *testing.T) { + var tpl = `package main + +import ( + "github.com/datadog/orchestrion/instrument" + "github.com/go-chi/chi/v5" +) + +func chiV5Server() { + //dd:instrumented + r := chi.NewRouter() + //dd:startinstrument + r.Use(instrument.ChiV5Middleware()) + //dd:endinstrument +} +` + + reader, err := InstrumentFile("test", strings.NewReader(tpl), config.Config{}) + require.Nil(t, err) + got, err := io.ReadAll(reader) + require.Nil(t, err) + require.Equal(t, tpl, string(got)) +} diff --git a/internal/instrument/echov4.go b/internal/instrument/echov4.go index ad945339..b1102780 100644 --- a/internal/instrument/echov4.go +++ b/internal/instrument/echov4.go @@ -7,6 +7,17 @@ package instrument import "github.com/dave/dst" +func instrumentEchoV4(stmt *dst.AssignStmt) []dst.Stmt { + if !isEchoV4(stmt) { + return nil + } + stmt.Decorations().Start.Prepend(dd_instrumented) + return []dst.Stmt{ + stmt, + echoV4Middleware(stmt), + } +} + func isEchoV4(stmt *dst.AssignStmt) bool { rhs := stmt.Rhs[0] f, ok := funcIdent(rhs) @@ -19,7 +30,7 @@ func echoV4Middleware(got *dst.AssignStmt) dst.Stmt { return nil } stmt := useMiddleware(iden.Name, "EchoV4Middleware") - wrap(stmt) + markAsInstrumented(stmt) return stmt } diff --git a/internal/instrument/echov4_test.go b/internal/instrument/echov4_test.go new file mode 100644 index 00000000..804ebcf5 --- /dev/null +++ b/internal/instrument/echov4_test.go @@ -0,0 +1,96 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2023-present Datadog, Inc. + +package instrument + +import ( + "fmt" + "io" + "strings" + "testing" + + "github.com/datadog/orchestrion/internal/config" + + "github.com/stretchr/testify/require" +) + +func TestEchoV4(t *testing.T) { + var codeTpl = `package main + +import %s + +func register() { + %s +} +` + var wantTpl = `package main + +import ( + "github.com/datadog/orchestrion/instrument" + %s +) + +func register() { + //dd:instrumented + %s + //dd:startinstrument + %s + //dd:endinstrument +} +` + + tests := []struct { + pkg string + stmt string + want string + tmpl string + }{ + {pkg: `"github.com/labstack/echo/v4"`, stmt: `r := echo.New()`, want: `r.Use(instrument.EchoV4Middleware())`, tmpl: wantTpl}, + {pkg: `echo "github.com/labstack/echo/v4"`, stmt: `r := echo.New()`, want: `r.Use(instrument.EchoV4Middleware())`, tmpl: wantTpl}, + {pkg: `echov4 "github.com/labstack/echo/v4"`, stmt: `r := echov4.New()`, want: `r.Use(instrument.EchoV4Middleware())`, tmpl: wantTpl}, + } + + for i, tc := range tests { + t.Run(fmt.Sprintf("tc-%d", i), func(t *testing.T) { + code := fmt.Sprintf(codeTpl, tc.pkg, tc.stmt) + reader, err := InstrumentFile("test", strings.NewReader(code), config.Config{}) + require.Nil(t, err) + got, err := io.ReadAll(reader) + require.Nil(t, err) + want := fmt.Sprintf(tc.tmpl, tc.pkg, tc.stmt, tc.want) + require.Equal(t, want, string(got)) + + reader, err = UninstrumentFile("test", strings.NewReader(want), config.Config{}) + require.Nil(t, err) + orig, err := io.ReadAll(reader) + require.Nil(t, err) + require.Equal(t, code, string(orig)) + }) + } +} + +func TestEchoV4Duplicates(t *testing.T) { + var tpl = `package main + +import ( + "github.com/datadog/orchestrion/instrument" + "github.com/labstack/echo/v4" +) + +func echoV4Server() { + //dd:instrumented + r := echo.New() + //dd:startinstrument + r.Use(instrument.EchoV4Middleware()) + //dd:endinstrument +} +` + + reader, err := InstrumentFile("test", strings.NewReader(tpl), config.Config{}) + require.Nil(t, err) + got, err := io.ReadAll(reader) + require.Nil(t, err) + require.Equal(t, tpl, string(got)) +} diff --git a/internal/instrument/gin.go b/internal/instrument/gin.go index 3708d466..db75cff9 100644 --- a/internal/instrument/gin.go +++ b/internal/instrument/gin.go @@ -7,6 +7,17 @@ package instrument import "github.com/dave/dst" +func instrumentGin(stmt *dst.AssignStmt) []dst.Stmt { + if !isGin(stmt) { + return nil + } + stmt.Decorations().Start.Prepend(dd_instrumented) + return []dst.Stmt{ + stmt, + ginMiddleware(stmt), + } +} + func isGin(stmt *dst.AssignStmt) bool { rhs := stmt.Rhs[0] f, ok := funcIdent(rhs) @@ -19,7 +30,7 @@ func ginMiddleware(got *dst.AssignStmt) dst.Stmt { return nil } stmt := useMiddleware(iden.Name, "GinMiddleware") - wrap(stmt) + markAsInstrumented(stmt) return stmt } diff --git a/internal/instrument/gin_test.go b/internal/instrument/gin_test.go new file mode 100644 index 00000000..5c840612 --- /dev/null +++ b/internal/instrument/gin_test.go @@ -0,0 +1,102 @@ +// Unless explicitly stated otherwise all files in this repository are licensed +// under the Apache License Version 2.0. +// This product includes software developed at Datadog (https://www.datadoghq.com/). +// Copyright 2023-present Datadog, Inc. + +package instrument + +import ( + "fmt" + "io" + "strings" + "testing" + + "github.com/datadog/orchestrion/internal/config" + + "github.com/stretchr/testify/require" +) + +func TestGin(t *testing.T) { + var codeTpl = `package main + +import "github.com/gin-gonic/gin" + +func register() { + %s +} +` + var wantTpl = `package main + +import ( + "github.com/datadog/orchestrion/instrument" + "github.com/gin-gonic/gin" +) + +func register() { + //dd:instrumented + %s + //dd:startinstrument + %s + //dd:endinstrument +} +` + + tests := []struct { + in string + want string + tmpl string + }{ + {in: `g := gin.New()`, want: `g.Use(instrument.GinMiddleware())`, tmpl: wantTpl}, + {in: `g := gin.Default()`, want: `g.Use(instrument.GinMiddleware())`, tmpl: wantTpl}, + } + + for i, tc := range tests { + t.Run(fmt.Sprintf("tc-%d", i), func(t *testing.T) { + code := fmt.Sprintf(codeTpl, tc.in) + reader, err := InstrumentFile("test", strings.NewReader(code), config.Config{}) + require.Nil(t, err) + got, err := io.ReadAll(reader) + require.Nil(t, err) + want := fmt.Sprintf(tc.tmpl, tc.in, tc.want) + require.Equal(t, want, string(got)) + + reader, err = UninstrumentFile("test", strings.NewReader(want), config.Config{}) + require.Nil(t, err) + orig, err := io.ReadAll(reader) + require.Nil(t, err) + require.Equal(t, code, string(orig)) + }) + } +} + +func TestGinDuplicates(t *testing.T) { + var tpl = `package main + +import ( + "net/http" + + "github.com/datadog/orchestrion/instrument" + "github.com/gin-gonic/gin" +) + +func ginServer() { + //dd:instrumented + r := gin.Default() + //dd:startinstrument + r.Use(instrument.GinMiddleware()) + //dd:endinstrument + r.GET("/ping", func(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "message": "pong", + }) + }) + r.Run() +} +` + + reader, err := InstrumentFile("test", strings.NewReader(tpl), config.Config{}) + require.Nil(t, err) + got, err := io.ReadAll(reader) + require.Nil(t, err) + require.Equal(t, tpl, string(got)) +} diff --git a/internal/instrument/gorilla.go b/internal/instrument/gorilla.go index 0f34f53a..2a8b61b9 100644 --- a/internal/instrument/gorilla.go +++ b/internal/instrument/gorilla.go @@ -16,7 +16,7 @@ func wrapGorillaMux(stmt *dst.AssignStmt) { if !(f.Path == "github.com/gorilla/mux" && f.Name == "NewRouter") { return } - wrap(stmt) + markAsWrap(stmt) call := rhs.(*dst.CallExpr) call.Args = []dst.Expr{ &dst.CallExpr{ diff --git a/internal/instrument/grpc.go b/internal/instrument/grpc.go index 9d7a63f0..51ae3719 100644 --- a/internal/instrument/grpc.go +++ b/internal/instrument/grpc.go @@ -28,7 +28,7 @@ func wrapGRPC(stmt *dst.AssignStmt) { if !(iden.Name == targetName && iden.Path == "google.golang.org/grpc") { return } - wrap(stmt) + markAsWrap(stmt) for _, opt := range opts { fun.Args = append(fun.Args, &dst.CallExpr{Fun: &dst.Ident{Name: opt, Path: "github.com/datadog/orchestrion/instrument"}}, diff --git a/internal/instrument/http.go b/internal/instrument/http.go index 762799d8..fb070598 100644 --- a/internal/instrument/http.go +++ b/internal/instrument/http.go @@ -101,7 +101,7 @@ func wrapHandlerFromAssign(stmt *dst.AssignStmt, tc *typechecker.TypeChecker) bo if !(ok && k.Name == "Handler" && tc.OfType(k, "net/http.Handler")) { continue } - wrap(kve) + markAsWrap(kve) kve.Value = &dst.CallExpr{ Fun: &dst.Ident{Name: "WrapHandler", Path: "github.com/datadog/orchestrion/instrument"}, Args: []dst.Expr{kve.Value}, @@ -124,7 +124,7 @@ func wrapClientFromAssign(stmt *dst.AssignStmt, tc *typechecker.TypeChecker) boo if !(ok && tc.OfType(iden, "*net/http.Client")) { return false } - wrap(stmt) + markAsWrap(stmt) stmt.Rhs[0] = &dst.CallExpr{ Fun: &dst.Ident{Name: "WrapHTTPClient", Path: "github.com/datadog/orchestrion/instrument"}, Args: []dst.Expr{stmt.Rhs[0]}, @@ -160,7 +160,7 @@ func wrapHandlerFromExpr(stmt *dst.ExprStmt, tc *typechecker.TypeChecker) bool { default: return false } - wrap(fun) + markAsWrap(fun) fun.Args[1] = &dst.CallExpr{ Fun: &dst.Ident{Name: wrapper, Path: "github.com/datadog/orchestrion/instrument"}, Args: []dst.Expr{fun.Args[1]}, diff --git a/internal/instrument/instrument.go b/internal/instrument/instrument.go index 56329e93..e49b4d27 100644 --- a/internal/instrument/instrument.go +++ b/internal/instrument/instrument.go @@ -122,7 +122,7 @@ func newResolver() guess.RestorerResolver { } func addSpanCodeToFunction(comment string, decl *dst.FuncDecl, tc *typechecker.TypeChecker) *dst.FuncDecl { - //check if magic comment is attached to first line + // check if magic comment is attached to first line if len(decl.Body.List) > 0 { decs := decl.Body.List[0].Decorations().Start for _, v := range decs.All() { @@ -299,20 +299,17 @@ func addInFunctionCode(list []dst.Stmt, tc *typechecker.TypeChecker, conf config wrapSqlOpenFromAssign(stmt) wrapGRPC(stmt) wrapGorillaMux(stmt) - if isGin(stmt) { - out = append(out, stmt) + if r := instrumentGin(stmt); r != nil { appendStmt = false - out = append(out, ginMiddleware(stmt)) + out = append(out, r...) } - if isEchoV4(stmt) { - out = append(out, stmt) + if r := instrumentEchoV4(stmt); r != nil { appendStmt = false - out = append(out, echoV4Middleware(stmt)) + out = append(out, r...) } - if isChiV5(stmt) { - out = append(out, stmt) + if r := instrumentChiV5(stmt); r != nil { appendStmt = false - out = append(out, chiV5Middleware(stmt)) + out = append(out, r...) } // Recurse when there is a function literal on the RHS of the assignment. @@ -415,7 +412,7 @@ func hasLabel(label string, decs []string) bool { } func addInit(decl *dst.FuncDecl) *dst.FuncDecl { - //check if magic comment is attached to first line + // check if magic comment is attached to first line if len(decl.Body.List) > 0 { decs := decl.Body.List[0].Decorations().Start for _, v := range decs.All() { @@ -572,7 +569,12 @@ func useMiddleware(pkg, middleware string) *dst.ExprStmt { return stmt } -func wrap(stmt dst.Node) { +func markAsWrap(stmt dst.Node) { stmt.Decorations().Start.Append(dd_startwrap) stmt.Decorations().End.Append("\n", dd_endwrap) } + +func markAsInstrumented(stmt dst.Node) { + stmt.Decorations().Start.Append(dd_startinstrument) + stmt.Decorations().End.Append("\n", dd_endinstrument) +} diff --git a/internal/instrument/instrument_test.go b/internal/instrument/instrument_test.go index 0bd07207..e155d9b3 100644 --- a/internal/instrument/instrument_test.go +++ b/internal/instrument/instrument_test.go @@ -823,163 +823,3 @@ func init() { }) } - -func TestGin(t *testing.T) { - var codeTpl = `package main - -import "github.com/gin-gonic/gin" - -func register() { - %s -} -` - var wantTpl = `package main - -import ( - "github.com/datadog/orchestrion/instrument" - "github.com/gin-gonic/gin" -) - -func register() { - %s - //dd:startwrap - %s - //dd:endwrap -} -` - - tests := []struct { - in string - want string - tmpl string - }{ - {in: `g := gin.New()`, want: `g.Use(instrument.GinMiddleware())`, tmpl: wantTpl}, - {in: `g := gin.Default()`, want: `g.Use(instrument.GinMiddleware())`, tmpl: wantTpl}, - } - - for i, tc := range tests { - t.Run(fmt.Sprintf("tc-%d", i), func(t *testing.T) { - code := fmt.Sprintf(codeTpl, tc.in) - reader, err := InstrumentFile("test", strings.NewReader(code), config.Config{}) - require.Nil(t, err) - got, err := io.ReadAll(reader) - require.Nil(t, err) - want := fmt.Sprintf(tc.tmpl, tc.in, tc.want) - require.Equal(t, want, string(got)) - - reader, err = UninstrumentFile("test", strings.NewReader(want), config.Config{}) - require.Nil(t, err) - orig, err := io.ReadAll(reader) - require.Nil(t, err) - require.Equal(t, code, string(orig)) - }) - } -} - -func TestEchoV4(t *testing.T) { - var codeTpl = `package main - -import %s - -func register() { - %s -} -` - var wantTpl = `package main - -import ( - "github.com/datadog/orchestrion/instrument" - %s -) - -func register() { - %s - //dd:startwrap - %s - //dd:endwrap -} -` - - tests := []struct { - pkg string - stmt string - want string - tmpl string - }{ - {pkg: `"github.com/labstack/echo/v4"`, stmt: `r := echo.New()`, want: `r.Use(instrument.EchoV4Middleware())`, tmpl: wantTpl}, - {pkg: `echo "github.com/labstack/echo/v4"`, stmt: `r := echo.New()`, want: `r.Use(instrument.EchoV4Middleware())`, tmpl: wantTpl}, - {pkg: `echov4 "github.com/labstack/echo/v4"`, stmt: `r := echov4.New()`, want: `r.Use(instrument.EchoV4Middleware())`, tmpl: wantTpl}, - } - - for i, tc := range tests { - t.Run(fmt.Sprintf("tc-%d", i), func(t *testing.T) { - code := fmt.Sprintf(codeTpl, tc.pkg, tc.stmt) - reader, err := InstrumentFile("test", strings.NewReader(code), config.Config{}) - require.Nil(t, err) - got, err := io.ReadAll(reader) - require.Nil(t, err) - want := fmt.Sprintf(tc.tmpl, tc.pkg, tc.stmt, tc.want) - require.Equal(t, want, string(got)) - - reader, err = UninstrumentFile("test", strings.NewReader(want), config.Config{}) - require.Nil(t, err) - orig, err := io.ReadAll(reader) - require.Nil(t, err) - require.Equal(t, code, string(orig)) - }) - } -} - -func TestChiV5(t *testing.T) { - var codeTpl = `package main - -import %s - -func register() { - %s -} -` - var wantTpl = `package main - -import ( - "github.com/datadog/orchestrion/instrument" - %s -) - -func register() { - %s - //dd:startwrap - %s - //dd:endwrap -} -` - - tests := []struct { - pkg string - stmt string - want string - tmpl string - }{ - {pkg: `"github.com/go-chi/chi/v5"`, stmt: `r := chi.NewRouter()`, want: `r.Use(instrument.ChiV5Middleware())`, tmpl: wantTpl}, - {pkg: `chi "github.com/go-chi/chi/v5"`, stmt: `r := chi.NewRouter()`, want: `r.Use(instrument.ChiV5Middleware())`, tmpl: wantTpl}, - {pkg: `chiv5 "github.com/go-chi/chi/v5"`, stmt: `r := chiv5.NewRouter()`, want: `r.Use(instrument.ChiV5Middleware())`, tmpl: wantTpl}, - } - - for i, tc := range tests { - t.Run(fmt.Sprintf("tc-%d", i), func(t *testing.T) { - code := fmt.Sprintf(codeTpl, tc.pkg, tc.stmt) - reader, err := InstrumentFile("test", strings.NewReader(code), config.Config{}) - require.Nil(t, err) - got, err := io.ReadAll(reader) - require.Nil(t, err) - want := fmt.Sprintf(tc.tmpl, tc.pkg, tc.stmt, tc.want) - require.Equal(t, want, string(got)) - - reader, err = UninstrumentFile("test", strings.NewReader(want), config.Config{}) - require.Nil(t, err) - orig, err := io.ReadAll(reader) - require.Nil(t, err) - require.Equal(t, code, string(orig)) - }) - } -} diff --git a/internal/instrument/sql.go b/internal/instrument/sql.go index b9c66077..68e1eef1 100644 --- a/internal/instrument/sql.go +++ b/internal/instrument/sql.go @@ -23,7 +23,7 @@ func wrapSqlReturnCall(stmt *dst.ReturnStmt) *dst.ReturnStmt { continue } if wrapSqlCall(fun) { - wrap(stmt) + markAsWrap(stmt) stmt.Decorations().Before = dst.NewLine } } @@ -46,7 +46,7 @@ func wrapSqlOpenFromAssign(stmt *dst.AssignStmt) bool { return false } if wrapSqlCall(fun) { - wrap(stmt) + markAsWrap(stmt) return true } return false