diff --git a/internal/injector/aspect/context/node.go b/internal/injector/aspect/context/node.go index 06046ec7..a1dda10d 100644 --- a/internal/injector/aspect/context/node.go +++ b/internal/injector/aspect/context/node.go @@ -56,6 +56,6 @@ func (n *NodeChain) PropertyName() string { return n.name } -func (n *NodeChain) Index() string { - return n.Index() +func (n *NodeChain) Index() int { + return n.index } diff --git a/internal/injector/injector.go b/internal/injector/injector.go index b0b6dc81..660352ef 100644 --- a/internal/injector/injector.go +++ b/internal/injector/injector.go @@ -11,8 +11,10 @@ package injector import ( "errors" "fmt" + "go/ast" "go/importer" "go/token" + "sync" "github.com/DataDog/orchestrion/internal/injector/aspect" "github.com/DataDog/orchestrion/internal/injector/aspect/context" @@ -84,26 +86,48 @@ func (i *Injector) InjectFiles(files []string) (map[string]InjectedFile, error) return nil, err } - decorator := decorator.NewDecoratorWithImports(fset, i.ImportPath, gotypes.New(uses)) - dstFiles := make([]*dst.File, len(astFiles)) + var ( + wg sync.WaitGroup + errs []error + errsMu sync.Mutex + result = make(map[string]InjectedFile, len(astFiles)) + resultMu sync.Mutex + ) + + wg.Add(len(astFiles)) for idx, astFile := range astFiles { - dstFiles[idx], err = decorator.DecorateFile(astFile) - if err != nil { - return nil, err - } - } + go func(astFile *ast.File) { + defer wg.Done() - result := make(map[string]InjectedFile, len(files)) - for idx, dstFile := range dstFiles { - res, err := i.injectFile(decorator, dstFile) - if err != nil { - return nil, err - } - if res.Modified { + decorator := decorator.NewDecoratorWithImports(fset, i.ImportPath, gotypes.New(uses)) + dstFile, err := decorator.DecorateFile(astFile) + if err != nil { + errsMu.Lock() + defer errsMu.Unlock() + errs = append(errs, err) + return + } + + res, err := i.injectFile(decorator, dstFile) + if err != nil { + errsMu.Lock() + defer errsMu.Unlock() + errs = append(errs, err) + return + } + + if !res.Modified { + return + } + + resultMu.Lock() + defer resultMu.Unlock() result[files[idx]] = res.InjectedFile - } + }(astFile) } - return result, nil + wg.Wait() + + return result, errors.Join(errs...) } func (i *Injector) validate() error { @@ -114,6 +138,10 @@ func (i *Injector) validate() error { if i.Lookup == nil { err = errors.Join(err, fmt.Errorf("invalid %T: missing Lookup", i)) } + + // Initialize the restorerResolver field, too... + i.restorerResolver = &lookupResolver{lookup: i.Lookup} + return err } diff --git a/internal/injector/restorer.go b/internal/injector/restorer.go index 2dac42dd..2a9d495b 100644 --- a/internal/injector/restorer.go +++ b/internal/injector/restorer.go @@ -10,6 +10,7 @@ import ( "go/importer" "go/token" "go/types" + "sync" "github.com/dave/dst/decorator" "golang.org/x/tools/go/gcexportdata" @@ -20,13 +21,11 @@ type lookupResolver struct { fset *token.FileSet imports map[string]*types.Package + + mu sync.Mutex } func (i *Injector) newRestorer(filename string) *decorator.FileRestorer { - if i.restorerResolver == nil { - i.restorerResolver = &lookupResolver{lookup: i.Lookup} - } - return &decorator.FileRestorer{ Restorer: decorator.NewRestorerWithImports(i.ImportPath, i.restorerResolver), Name: filename, @@ -39,6 +38,9 @@ func (r *lookupResolver) ResolvePackage(path string) (string, error) { return "unsafe", nil } + r.mu.Lock() + defer r.mu.Unlock() + // If this is present in "cache", we can return right away! if pkg, ok := r.imports[path]; ok { return pkg.Name(), nil diff --git a/internal/jobserver/pkgs/resolve.go b/internal/jobserver/pkgs/resolve.go index d17d9580..3db52ecd 100644 --- a/internal/jobserver/pkgs/resolve.go +++ b/internal/jobserver/pkgs/resolve.go @@ -139,7 +139,7 @@ func (s *service) resolve(req *ResolveRequest) (ResolveResponse, error) { return nil, err } - var resp ResolveResponse + resp := make(ResolveResponse) for _, pkg := range pkgs { resp.mergeFrom(pkg) } @@ -152,16 +152,6 @@ func (s *service) resolve(req *ResolveRequest) (ResolveResponse, error) { return resp, nil } -func hashArray(items []string) string { - h := sha512.New512_224() - - for idx, item := range items { - _, _ = fmt.Fprintf(h, "\x01%d\x02%s\x03", idx, item) - } - - return base64.URLEncoding.EncodeToString(h.Sum(nil)) -} - func (r *ResolveRequest) canonicalize() { if r.canonical { return @@ -186,15 +176,14 @@ func (r *ResolveRequest) hash() (string, error) { return base64.URLEncoding.EncodeToString(hash.Sum(sum[:0])), nil } -func (r *ResolveResponse) mergeFrom(pkg *packages.Package) { - if pkg.PkgPath == "" || pkg.PkgPath == "unsafe" { +func (r ResolveResponse) mergeFrom(pkg *packages.Package) { + if pkg.PkgPath == "" || pkg.PkgPath == "unsafe" || r[pkg.PkgPath] != "" { + // Ignore the "unsafe" package (no archive file, ever), packages with an empty import path + // (standard library), and those already present in the map (already processed previously). return } - if *r == nil { - *r = make(ResolveResponse) - } - (*r)[pkg.PkgPath] = pkg.ExportFile + r[pkg.PkgPath] = pkg.ExportFile for _, dep := range pkg.Imports { r.mergeFrom(dep) }