Skip to content

Commit

Permalink
fix(locations): make source info access concurrent safe (#1433)
Browse files Browse the repository at this point in the history
* fix(locations): make source info access concurrent safe

* follow mutex hat pattern

* add test and run unit tests with -race
  • Loading branch information
noahdietz authored Sep 23, 2024
1 parent a6ba5f3 commit 223aa5b
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version: "1.20"
- run: go test -p 1 ./...
- run: go test -race ./...
lint:
runs-on: ubuntu-latest
steps:
Expand Down
47 changes: 37 additions & 10 deletions locations/locations.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
package locations

import (
"sync"

"github.com/jhump/protoreflect/desc"
dpb "google.golang.org/protobuf/types/descriptorpb"
)
Expand All @@ -37,12 +39,25 @@ func pathLocation(d desc.Descriptor, path ...int) *dpb.SourceCodeInfo_Location {
return sourceInfoRegistry.sourceInfo(d.GetFile()).findLocation(fullPath)
}

type sourceInfo map[string]*dpb.SourceCodeInfo_Location
type sourceInfo struct {
// infoMu protects the info map
infoMu sync.Mutex
info map[string]*dpb.SourceCodeInfo_Location
}

func newSourceInfo() *sourceInfo {
return &sourceInfo{
info: map[string]*dpb.SourceCodeInfo_Location{},
}
}

// findLocation returns the Location for a given path.
func (si sourceInfo) findLocation(path []int32) *dpb.SourceCodeInfo_Location {
func (si *sourceInfo) findLocation(path []int32) *dpb.SourceCodeInfo_Location {
si.infoMu.Lock()
defer si.infoMu.Unlock()

// If the path exists in the source info registry, return that object.
if loc, ok := si[strPath(path)]; ok {
if loc, ok := si.info[strPath(path)]; ok {
return loc
}

Expand All @@ -53,7 +68,17 @@ func (si sourceInfo) findLocation(path []int32) *dpb.SourceCodeInfo_Location {
// The source map registry is a singleton that computes a source map for
// any file descriptor that it is given, but then caches it to avoid computing
// the source map for the same file descriptors over and over.
type sourceInfoRegistryType map[*desc.FileDescriptor]sourceInfo
type sourceInfoRegistryType struct {
// registryMu protects the registry map
registryMu sync.Mutex
registry map[*desc.FileDescriptor]*sourceInfo
}

func newSourceInfoRegistryType() *sourceInfoRegistryType {
return &sourceInfoRegistryType{
registry: map[*desc.FileDescriptor]*sourceInfo{},
}
}

// Each location has a path defined as an []int32, but we can not
// use slices as keys, so compile them into a string.
Expand All @@ -70,22 +95,24 @@ func strPath(segments []int32) (p string) {
// sourceInfo compiles the source info object for a given file descriptor.
// It also caches this into a registry, so subsequent calls using the same
// descriptor will return the same object.
func (sir sourceInfoRegistryType) sourceInfo(fd *desc.FileDescriptor) sourceInfo {
answer, ok := sir[fd]
func (sir *sourceInfoRegistryType) sourceInfo(fd *desc.FileDescriptor) *sourceInfo {
sir.registryMu.Lock()
defer sir.registryMu.Unlock()
answer, ok := sir.registry[fd]
if !ok {
answer = sourceInfo{}
answer = newSourceInfo()

// This file descriptor does not yet have a source info map.
// Compile one.
for _, loc := range fd.AsFileDescriptorProto().GetSourceCodeInfo().GetLocation() {
answer[strPath(loc.Path)] = loc
answer.info[strPath(loc.Path)] = loc
}

// Now that we calculated all of this, cache it on the registry so it
// does not need to be calculated again.
sir[fd] = answer
sir.registry[fd] = answer
}
return answer
}

var sourceInfoRegistry = sourceInfoRegistryType{}
var sourceInfoRegistry = newSourceInfoRegistryType()
28 changes: 28 additions & 0 deletions locations/locations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package locations

import (
"strings"
"sync"
"testing"

"github.com/jhump/protoreflect/desc"
Expand Down Expand Up @@ -47,3 +48,30 @@ func parse(t *testing.T, s string) *desc.FileDescriptor {
}
return fds[0]
}

func TestSourceInfo_Concurrency(t *testing.T) {
fd := parse(t, `
syntax = "proto3";
package foo.bar;
`)

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
FileSyntax(fd)
}()

wg.Add(1)
go func() {
defer wg.Done()
FilePackage(fd)
}()

wg.Add(1)
go func() {
defer wg.Done()
FileImport(fd, 0)
}()
wg.Wait()
}

0 comments on commit 223aa5b

Please sign in to comment.