diff --git a/internal/enginenetx/statsmanager.go b/internal/enginenetx/statsmanager.go index 287434c68a..5846b50832 100644 --- a/internal/enginenetx/statsmanager.go +++ b/internal/enginenetx/statsmanager.go @@ -211,7 +211,7 @@ func statsContainerRemoveOldEntries(input *statsContainer) (output *statsContain // At the name implies, this function MUST be called while holding the [*statsManager] mutex. func (c *statsContainer) GetStatsTacticLocked(tactic *httpsDialerTactic) (*statsTactic, bool) { domainEpntRecord, found := c.DomainEndpoints[tactic.domainEndpointKey()] - if !found { + if !found || domainEpntRecord == nil { return nil, false } tacticRecord, found := domainEpntRecord.Tactics[tactic.tacticSummaryKey()] diff --git a/internal/enginenetx/statsmanager_test.go b/internal/enginenetx/statsmanager_test.go index 4ab30b206d..e5f4886bdc 100644 --- a/internal/enginenetx/statsmanager_test.go +++ b/internal/enginenetx/statsmanager_test.go @@ -1016,3 +1016,30 @@ func TestStatsSafeIncrementMapStringInt64(t *testing.T) { } }) } + +func TestStatsContainer(t *testing.T) { + t.Run("GetStatsTacticLocked", func(t *testing.T) { + t.Run("is robust with respect to c.DomainEndpoints containing a nil entry", func(t *testing.T) { + sc := &statsContainer{ + DomainEndpoints: map[string]*statsDomainEndpoint{ + "api.ooni.io:443": nil, + }, + Version: statsContainerVersion, + } + tactic := &httpsDialerTactic{ + Address: "162.55.247.208", + InitialDelay: 0, + Port: "443", + SNI: "www.example.com", + VerifyHostname: "api.ooni.io", + } + record, good := sc.GetStatsTacticLocked(tactic) + if good { + t.Fatal("expected not good") + } + if record != nil { + t.Fatal("expected nil") + } + }) + }) +} diff --git a/internal/enginenetx/statspolicy.go b/internal/enginenetx/statspolicy.go index 9b886e1dbe..b70d0facc2 100644 --- a/internal/enginenetx/statspolicy.go +++ b/internal/enginenetx/statspolicy.go @@ -60,7 +60,7 @@ func (p *statsPolicy) LookupTactics(ctx context.Context, domain string, port str } // give priority to what we know from stats - for _, t := range p.statsLookupTactics(domain, port) { + for _, t := range statsPolicyPostProcessTactics(p.Stats.LookupTactics(domain, port)) { maybeEmitTactic(t) } @@ -73,29 +73,30 @@ func (p *statsPolicy) LookupTactics(ctx context.Context, domain string, port str return out } -func (p *statsPolicy) statsLookupTactics(domain string, port string) (out []*httpsDialerTactic) { - - // obtain information from the stats--here the result may be false if the - // stats do not contain any information about the domain and port - tactics, good := p.Stats.LookupTactics(domain, port) +func statsPolicyPostProcessTactics(tactics []*statsTactic, good bool) (out []*httpsDialerTactic) { + // when good is false, it means p.Stats.LookupTactics failed if !good { return } - // successRate is a convenience function for computing the success rate - successRate := func(t *statsTactic) (rate float64) { - if t.CountStarted > 0 { + // nilSafeSuccessRate is a convenience function for computing the success rate + // which returns zero as the success rate if CountStarted is zero + // + // for robustness, be paranoid about nils here because the stats are + // written on the disk and a user could potentially edit them + nilSafeSuccessRate := func(t *statsTactic) (rate float64) { + if t != nil && t.CountStarted > 0 { rate = float64(t.CountSuccess) / float64(t.CountStarted) } return } - // Implementation note: the function should implement the "less" semantics - // but we want descending sorting not ascending, so we're using a "more" semantics + // Implementation note: the function should implement the "less" semantics for + // ascending sorting, but we want descending sorting, so we use `>` instead sort.SliceStable(tactics, func(i, j int) bool { // TODO(bassosimone): should we also consider the number of samples // we have and how recent a sample is? - return successRate(tactics[i]) > successRate(tactics[j]) + return nilSafeSuccessRate(tactics[i]) > nilSafeSuccessRate(tactics[j]) }) for _, t := range tactics { @@ -103,9 +104,9 @@ func (p *statsPolicy) statsLookupTactics(domain string, port string) (out []*htt // to return what we already know it's not working and it will be the purpose of the // fallback policy to generate new tactics to test // - // additionally, as a precautionary and defensive measure, make sure t.Tactic - // is not nil before adding the real tactic to the return list - if t.CountSuccess > 0 && t.Tactic != nil { + // additionally, as a precautionary and defensive measure, make sure t and t.Tactic + // are not nil before adding a malformed tactic to the return list + if t != nil && t.CountSuccess > 0 && t.Tactic != nil { out = append(out, t.Tactic) } } diff --git a/internal/enginenetx/statspolicy_test.go b/internal/enginenetx/statspolicy_test.go index 68530f226d..32be8e8a3a 100644 --- a/internal/enginenetx/statspolicy_test.go +++ b/internal/enginenetx/statspolicy_test.go @@ -13,6 +13,7 @@ import ( "github.com/ooni/probe-cli/v3/internal/netemx" "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/ooni/probe-cli/v3/internal/runtimex" + "github.com/ooni/probe-cli/v3/internal/testingx" ) func TestStatsPolicyWorkingAsIntended(t *testing.T) { @@ -314,3 +315,43 @@ var _ httpsDialerPolicy = &mocksPolicy{} func (p *mocksPolicy) LookupTactics(ctx context.Context, domain string, port string) <-chan *httpsDialerTactic { return p.MockLookupTactics(ctx, domain, port) } + +func TestStatsPolicyPostProcessTactics(t *testing.T) { + t.Run("we do nothing when good is false", func(t *testing.T) { + tactics := statsPolicyPostProcessTactics(nil, false) + if len(tactics) != 0 { + t.Fatal("expected zero-lenght return value") + } + }) + + t.Run("we filter out cases in which t or t.Tactic are nil", func(t *testing.T) { + expected := &statsTactic{} + ff := &testingx.FakeFiller{} + ff.Fill(&expected) + + input := []*statsTactic{nil, { + CountStarted: 0, + CountTCPConnectError: 0, + CountTCPConnectInterrupt: 0, + CountTLSHandshakeError: 0, + CountTLSHandshakeInterrupt: 0, + CountTLSVerificationError: 0, + CountSuccess: 0, + HistoTCPConnectError: map[string]int64{}, + HistoTLSHandshakeError: map[string]int64{}, + HistoTLSVerificationError: map[string]int64{}, + LastUpdated: time.Time{}, + Tactic: nil, + }, nil, expected} + + got := statsPolicyPostProcessTactics(input, true) + + if len(got) != 1 { + t.Fatal("expected just one element") + } + + if diff := cmp.Diff(expected.Tactic, got[0]); diff != "" { + t.Fatal(diff) + } + }) +}