diff --git a/go/vt/throttler/max_replication_lag_module.go b/go/vt/throttler/max_replication_lag_module.go index bd4666ec92f..e492764e443 100644 --- a/go/vt/throttler/max_replication_lag_module.go +++ b/go/vt/throttler/max_replication_lag_module.go @@ -302,6 +302,12 @@ func (m *MaxReplicationLagModule) recalculateRate(lagRecordNow replicationLagRec if lagRecordNow.isZero() { panic("rate recalculation was triggered with a zero replication lag record") } + + // Protect against nil stats + if lagRecordNow.Stats == nil { + return + } + now := lagRecordNow.time lagNow := lagRecordNow.lag() diff --git a/go/vt/throttler/max_replication_lag_module_test.go b/go/vt/throttler/max_replication_lag_module_test.go index f0324df192c..6379b067412 100644 --- a/go/vt/throttler/max_replication_lag_module_test.go +++ b/go/vt/throttler/max_replication_lag_module_test.go @@ -22,6 +22,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/vt/log" "vitess.io/vitess/go/vt/discovery" @@ -83,6 +85,12 @@ func (tf *testFixture) process(lagRecord replicationLagRecord) { tf.m.processRecord(lagRecord) } +// recalculateRate does the same thing as MaxReplicationLagModule.recalculateRate() does +// for a new "lagRecord". +func (tf *testFixture) recalculateRate(lagRecord replicationLagRecord) { + tf.m.recalculateRate(lagRecord) +} + func (tf *testFixture) checkState(state state, rate int64, lastRateChange time.Time) error { if got, want := tf.m.currentState, state; got != want { return fmt.Errorf("module in wrong state. got = %v, want = %v", got, want) @@ -96,6 +104,47 @@ func (tf *testFixture) checkState(state state, rate int64, lastRateChange time.T return nil } +func TestNewMaxReplicationLagModule_recalculateRate(t *testing.T) { + testCases := []struct { + name string + lagRecord replicationLagRecord + expectPanic bool + }{ + { + name: "Zero lag", + lagRecord: replicationLagRecord{ + time: time.Time{}, + TabletHealth: discovery.TabletHealth{Stats: nil}, + }, + expectPanic: true, + }, + { + name: "nil lag record stats", + lagRecord: replicationLagRecord{ + time: time.Now(), + TabletHealth: discovery.TabletHealth{Stats: nil}, + }, + expectPanic: false, + }, + } + + for _, aTestCase := range testCases { + theCase := aTestCase + + t.Run(theCase.name, func(t *testing.T) { + t.Parallel() + + fixture, err := newTestFixtureWithMaxReplicationLag(5) + assert.NoError(t, err) + + if theCase.expectPanic { + assert.Panics(t, func() { fixture.recalculateRate(theCase.lagRecord) }) + } + }, + ) + } +} + func TestMaxReplicationLagModule_RateNotZeroWhenDisabled(t *testing.T) { tf, err := newTestFixtureWithMaxReplicationLag(ReplicationLagModuleDisabled) if err != nil {