diff --git a/go/mathstats/beta.go b/go/mathstats/beta.go new file mode 100644 index 00000000000..f70565a28bb --- /dev/null +++ b/go/mathstats/beta.go @@ -0,0 +1,87 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mathstats + +import "math" + +func lgamma(x float64) float64 { + y, _ := math.Lgamma(x) + return y +} + +// mathBetaInc returns the value of the regularized incomplete beta +// function Iₓ(a, b). +// +// This is not to be confused with the "incomplete beta function", +// which can be computed as BetaInc(x, a, b)*Beta(a, b). +// +// If x < 0 or x > 1, returns NaN. +func mathBetaInc(x, a, b float64) float64 { + // Based on Numerical Recipes in C, section 6.4. This uses the + // continued fraction definition of I: + // + // (xᵃ*(1-x)ᵇ)/(a*B(a,b)) * (1/(1+(d₁/(1+(d₂/(1+...)))))) + // + // where B(a,b) is the beta function and + // + // d_{2m+1} = -(a+m)(a+b+m)x/((a+2m)(a+2m+1)) + // d_{2m} = m(b-m)x/((a+2m-1)(a+2m)) + if x < 0 || x > 1 { + return math.NaN() + } + bt := 0.0 + if 0 < x && x < 1 { + // Compute the coefficient before the continued + // fraction. + bt = math.Exp(lgamma(a+b) - lgamma(a) - lgamma(b) + + a*math.Log(x) + b*math.Log(1-x)) + } + if x < (a+1)/(a+b+2) { + // Compute continued fraction directly. + return bt * betacf(x, a, b) / a + } else { + // Compute continued fraction after symmetry transform. + return 1 - bt*betacf(1-x, b, a)/b + } +} + +// betacf is the continued fraction component of the regularized +// incomplete beta function Iₓ(a, b). +func betacf(x, a, b float64) float64 { + const maxIterations = 200 + const epsilon = 3e-14 + + raiseZero := func(z float64) float64 { + if math.Abs(z) < math.SmallestNonzeroFloat64 { + return math.SmallestNonzeroFloat64 + } + return z + } + + c := 1.0 + d := 1 / raiseZero(1-(a+b)*x/(a+1)) + h := d + for m := 1; m <= maxIterations; m++ { + mf := float64(m) + + // Even step of the recurrence. + numer := mf * (b - mf) * x / ((a + 2*mf - 1) * (a + 2*mf)) + d = 1 / raiseZero(1+numer*d) + c = raiseZero(1 + numer/c) + h *= d * c + + // Odd step of the recurrence. + numer = -(a + mf) * (a + b + mf) * x / ((a + 2*mf) * (a + 2*mf + 1)) + d = 1 / raiseZero(1+numer*d) + c = raiseZero(1 + numer/c) + hfac := d * c + h *= hfac + + if math.Abs(hfac-1) < epsilon { + return h + } + } + panic("betainc: a or b too big; failed to converge") +} diff --git a/go/mathstats/beta_test.go b/go/mathstats/beta_test.go new file mode 100644 index 00000000000..2878493a57d --- /dev/null +++ b/go/mathstats/beta_test.go @@ -0,0 +1,28 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mathstats + +import ( + "testing" +) + +func TestBetaInc(t *testing.T) { + // Example values from MATLAB betainc documentation. + testFunc(t, "I_0.5(%v, 3)", + func(a float64) float64 { return mathBetaInc(0.5, a, 3) }, + map[float64]float64{ + 0: 1.00000000000000, + 1: 0.87500000000000, + 2: 0.68750000000000, + 3: 0.50000000000000, + 4: 0.34375000000000, + 5: 0.22656250000000, + 6: 0.14453125000000, + 7: 0.08984375000000, + 8: 0.05468750000000, + 9: 0.03271484375000, + 10: 0.01928710937500, + }) +} diff --git a/go/mathstats/sample.go b/go/mathstats/sample.go new file mode 100644 index 00000000000..d645ee1a7f3 --- /dev/null +++ b/go/mathstats/sample.go @@ -0,0 +1,235 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mathstats + +import ( + "math" + "sort" +) + +// Sample is a collection of possibly weighted data points. +type Sample struct { + // Xs is the slice of sample values. + Xs []float64 + + // Sorted indicates that Xs is sorted in ascending order. + Sorted bool +} + +// Bounds returns the minimum and maximum values of xs. +func Bounds(xs []float64) (min float64, max float64) { + if len(xs) == 0 { + return math.NaN(), math.NaN() + } + min, max = xs[0], xs[0] + for _, x := range xs { + if x < min { + min = x + } + if x > max { + max = x + } + } + return +} + +// Bounds returns the minimum and maximum values of the Sample. +// +// If the Sample is weighted, this ignores samples with zero weight. +// +// This is constant time if s.Sorted and there are no zero-weighted +// values. +func (s Sample) Bounds() (min float64, max float64) { + if len(s.Xs) == 0 || !s.Sorted { + return Bounds(s.Xs) + } + return s.Xs[0], s.Xs[len(s.Xs)-1] +} + +// vecSum returns the sum of xs. +func vecSum(xs []float64) float64 { + sum := 0.0 + for _, x := range xs { + sum += x + } + return sum +} + +// Sum returns the (possibly weighted) sum of the Sample. +func (s Sample) Sum() float64 { + return vecSum(s.Xs) +} + +// Weight returns the total weight of the Sasmple. +func (s Sample) Weight() float64 { + return float64(len(s.Xs)) +} + +// Mean returns the arithmetic mean of xs. +func Mean(xs []float64) float64 { + if len(xs) == 0 { + return math.NaN() + } + m := 0.0 + for i, x := range xs { + m += (x - m) / float64(i+1) + } + return m +} + +// Mean returns the arithmetic mean of the Sample. +func (s Sample) Mean() float64 { + return Mean(s.Xs) +} + +// GeoMean returns the geometric mean of xs. xs must be positive. +func GeoMean(xs []float64) float64 { + if len(xs) == 0 { + return math.NaN() + } + m := 0.0 + for i, x := range xs { + if x <= 0 { + return math.NaN() + } + lx := math.Log(x) + m += (lx - m) / float64(i+1) + } + return math.Exp(m) +} + +// GeoMean returns the geometric mean of the Sample. All samples +// values must be positive. +func (s Sample) GeoMean() float64 { + return GeoMean(s.Xs) +} + +// Variance returns the sample variance of xs. +func Variance(xs []float64) float64 { + if len(xs) == 0 { + return math.NaN() + } else if len(xs) <= 1 { + return 0 + } + + // Based on Wikipedia's presentation of Welford 1962 + // (http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Online_algorithm). + // This is more numerically stable than the standard two-pass + // formula and not prone to massive cancellation. + mean, M2 := 0.0, 0.0 + for n, x := range xs { + delta := x - mean + mean += delta / float64(n+1) + M2 += delta * (x - mean) + } + return M2 / float64(len(xs)-1) +} + +// Variance returns the variance of xs +func (s Sample) Variance() float64 { + return Variance(s.Xs) +} + +// StdDev returns the sample standard deviation of xs. +func StdDev(xs []float64) float64 { + return math.Sqrt(Variance(xs)) +} + +// StdDev returns the sample standard deviation of the Sample. +func (s Sample) StdDev() float64 { + return StdDev(s.Xs) +} + +// Percentile returns the pctileth value from the Sample. This uses +// interpolation method R8 from Hyndman and Fan (1996). +// +// pctile will be capped to the range [0, 1]. If len(xs) == 0 or all +// weights are 0, returns NaN. +// +// Percentile(0.5) is the median. Percentile(0.25) and +// Percentile(0.75) are the first and third quartiles, respectively. +// +// This is constant time if s.Sorted and s.Weights == nil. +func (s *Sample) Percentile(pctile float64) float64 { + if len(s.Xs) == 0 { + return math.NaN() + } else if pctile <= 0 { + min, _ := s.Bounds() + return min + } else if pctile >= 1 { + _, max := s.Bounds() + return max + } + + if !s.Sorted { + s.Sort() + } + + N := float64(len(s.Xs)) + //n := pctile * (N + 1) // R6 + n := 1/3.0 + pctile*(N+1/3.0) // R8 + kf, frac := math.Modf(n) + k := int(kf) + if k <= 0 { + return s.Xs[0] + } else if k >= len(s.Xs) { + return s.Xs[len(s.Xs)-1] + } + return s.Xs[k-1] + frac*(s.Xs[k]-s.Xs[k-1]) +} + +// IQR returns the interquartile range of the Sample. +// +// This is constant time if s.Sorted and s.Weights == nil. +func (s Sample) IQR() float64 { + if !s.Sorted { + s = *s.Copy().Sort() + } + return s.Percentile(0.75) - s.Percentile(0.25) +} + +// Sort sorts the samples in place in s and returns s. +// +// A sorted sample improves the performance of some algorithms. +func (s *Sample) Sort() *Sample { + if s.Sorted || sort.Float64sAreSorted(s.Xs) { + // All set + } else { + sort.Float64s(s.Xs) + } + s.Sorted = true + return s +} + +// Copy returns a copy of the Sample. +// +// The returned Sample shares no data with the original, so they can +// be modified (for example, sorted) independently. +func (s Sample) Copy() *Sample { + xs := make([]float64, len(s.Xs)) + copy(xs, s.Xs) + return &Sample{xs, s.Sorted} +} + +// FilterOutliers updates this sample in-place by removing all the values that are outliers +func (s *Sample) FilterOutliers() { + // Discard outliers. + q1, q3 := s.Percentile(0.25), s.Percentile(0.75) + lo, hi := q1-1.5*(q3-q1), q3+1.5*(q3-q1) + nn := 0 + for _, value := range s.Xs { + if lo <= value && value <= hi { + s.Xs[nn] = value + nn++ + } + } + s.Xs = s.Xs[:nn] +} + +// Clear resets this sample so it contains 0 values +func (s *Sample) Clear() { + s.Xs = s.Xs[:0] + s.Sorted = false +} diff --git a/go/mathstats/sample_test.go b/go/mathstats/sample_test.go new file mode 100644 index 00000000000..fb9d6dbc6ee --- /dev/null +++ b/go/mathstats/sample_test.go @@ -0,0 +1,21 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mathstats + +import "testing" + +func TestSamplePercentile(t *testing.T) { + s := Sample{Xs: []float64{15, 20, 35, 40, 50}} + testFunc(t, "Percentile", s.Percentile, map[float64]float64{ + -1: 15, + 0: 15, + .05: 15, + .30: 19.666666666666666, + .40: 27, + .95: 50, + 1: 50, + 2: 50, + }) +} diff --git a/go/mathstats/tdist.go b/go/mathstats/tdist.go new file mode 100644 index 00000000000..5376669f32e --- /dev/null +++ b/go/mathstats/tdist.go @@ -0,0 +1,33 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mathstats + +import "math" + +// A TDist is a Student's t-distribution with V degrees of freedom. +type TDist struct { + V float64 +} + +func (t TDist) PDF(x float64) float64 { + return math.Exp(lgamma((t.V+1)/2)-lgamma(t.V/2)) / + math.Sqrt(t.V*math.Pi) * math.Pow(1+(x*x)/t.V, -(t.V+1)/2) +} + +func (t TDist) CDF(x float64) float64 { + if x == 0 { + return 0.5 + } else if x > 0 { + return 1 - 0.5*mathBetaInc(t.V/(t.V+x*x), t.V/2, 0.5) + } else if x < 0 { + return 1 - t.CDF(-x) + } else { + return math.NaN() + } +} + +func (t TDist) Bounds() (float64, float64) { + return -4, 4 +} diff --git a/go/mathstats/tdist_test.go b/go/mathstats/tdist_test.go new file mode 100644 index 00000000000..b30ba95662b --- /dev/null +++ b/go/mathstats/tdist_test.go @@ -0,0 +1,95 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mathstats + +import "testing" + +func TestT(t *testing.T) { + testFunc(t, "PDF(%v|v=1)", TDist{1}.PDF, map[float64]float64{ + -10: 0.0031515830315226806, + -9: 0.0038818278802901312, + -8: 0.0048970751720583188, + -7: 0.0063661977236758151, + -6: 0.0086029698968592104, + -5: 0.012242687930145799, + -4: 0.018724110951987692, + -3: 0.031830988618379075, + -2: 0.063661977236758149, + -1: 0.15915494309189537, + 0: 0.31830988618379075, + 1: 0.15915494309189537, + 2: 0.063661977236758149, + 3: 0.031830988618379075, + 4: 0.018724110951987692, + 5: 0.012242687930145799, + 6: 0.0086029698968592104, + 7: 0.0063661977236758151, + 8: 0.0048970751720583188, + 9: 0.0038818278802901312}) + testFunc(t, "PDF(%v|v=5)", TDist{5}.PDF, map[float64]float64{ + -10: 4.0989816415343313e-05, + -9: 7.4601664362590413e-05, + -8: 0.00014444303269563934, + -7: 0.00030134402928803911, + -6: 0.00068848154013743002, + -5: 0.0017574383788078445, + -4: 0.0051237270519179133, + -3: 0.017292578800222964, + -2: 0.065090310326216455, + -1: 0.21967979735098059, + 0: 0.3796066898224944, + 1: 0.21967979735098059, + 2: 0.065090310326216455, + 3: 0.017292578800222964, + 4: 0.0051237270519179133, + 5: 0.0017574383788078445, + 6: 0.00068848154013743002, + 7: 0.00030134402928803911, + 8: 0.00014444303269563934, + 9: 7.4601664362590413e-05}) + + testFunc(t, "CDF(%v|v=1)", TDist{1}.CDF, map[float64]float64{ + -10: 0.03172551743055356, + -9: 0.035223287477277272, + -8: 0.039583424160565539, + -7: 0.045167235300866547, + -6: 0.052568456711253424, + -5: 0.06283295818900117, + -4: 0.077979130377369324, + -3: 0.10241638234956672, + -2: 0.14758361765043321, + -1: 0.24999999999999978, + 0: 0.5, + 1: 0.75000000000000022, + 2: 0.85241638234956674, + 3: 0.89758361765043326, + 4: 0.92202086962263075, + 5: 0.93716704181099886, + 6: 0.94743154328874657, + 7: 0.95483276469913347, + 8: 0.96041657583943452, + 9: 0.96477671252272279}) + testFunc(t, "CDF(%v|v=5)", TDist{5}.CDF, map[float64]float64{ + -10: 8.5473787871481787e-05, + -9: 0.00014133998712194845, + -8: 0.00024645333028622187, + -7: 0.00045837375719920225, + -6: 0.00092306914479700695, + -5: 0.0020523579900266612, + -4: 0.0051617077404157259, + -3: 0.015049623948731284, + -2: 0.05096973941492914, + -1: 0.18160873382456127, + 0: 0.5, + 1: 0.81839126617543867, + 2: 0.9490302605850709, + 3: 0.98495037605126878, + 4: 0.99483829225958431, + 5: 0.99794764200997332, + 6: 0.99907693085520299, + 7: 0.99954162624280074, + 8: 0.99975354666971372, + 9: 0.9998586600128780}) +} diff --git a/go/mathstats/ttest.go b/go/mathstats/ttest.go new file mode 100644 index 00000000000..218ad8c0807 --- /dev/null +++ b/go/mathstats/ttest.go @@ -0,0 +1,170 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mathstats + +import ( + "errors" + "math" +) + +// A LocationHypothesis specifies the alternative hypothesis of a +// location test such as a t-test or a Mann-Whitney U-test. The +// default (zero) value is to test against the alternative hypothesis +// that they differ. +type LocationHypothesis int + +const ( + // LocationLess specifies the alternative hypothesis that the + // location of the first sample is less than the second. This + // is a one-tailed test. + LocationLess LocationHypothesis = -1 + + // LocationDiffers specifies the alternative hypothesis that + // the locations of the two samples are not equal. This is a + // two-tailed test. + LocationDiffers LocationHypothesis = 0 + + // LocationGreater specifies the alternative hypothesis that + // the location of the first sample is greater than the + // second. This is a one-tailed test. + LocationGreater LocationHypothesis = 1 +) + +// A TTestResult is the result of a t-test. +type TTestResult struct { + // N1 and N2 are the sizes of the input samples. For a + // one-sample t-test, N2 is 0. + N1, N2 int + + // T is the value of the t-statistic for this t-test. + T float64 + + // DoF is the degrees of freedom for this t-test. + DoF float64 + + // AltHypothesis specifies the alternative hypothesis tested + // by this test against the null hypothesis that there is no + // difference in the means of the samples. + AltHypothesis LocationHypothesis + + // P is p-value for this t-test for the given null hypothesis. + P float64 +} + +func newTTestResult(n1, n2 int, t, dof float64, alt LocationHypothesis) *TTestResult { + dist := TDist{dof} + var p float64 + switch alt { + case LocationDiffers: + p = 2 * (1 - dist.CDF(math.Abs(t))) + case LocationLess: + p = dist.CDF(t) + case LocationGreater: + p = 1 - dist.CDF(t) + } + return &TTestResult{N1: n1, N2: n2, T: t, DoF: dof, AltHypothesis: alt, P: p} +} + +// A TTestSample is a sample that can be used for a one or two sample +// t-test. +type TTestSample interface { + Weight() float64 + Mean() float64 + Variance() float64 +} + +var ( + ErrSampleSize = errors.New("sample is too small") + ErrZeroVariance = errors.New("sample has zero variance") + ErrMismatchedSamples = errors.New("samples have different lengths") +) + +// TwoSampleTTest performs a two-sample (unpaired) Student's t-test on +// samples x1 and x2. This is a test of the null hypothesis that x1 +// and x2 are drawn from populations with equal means. It assumes x1 +// and x2 are independent samples, that the distributions have equal +// variance, and that the populations are normally distributed. +func TwoSampleTTest(x1, x2 TTestSample, alt LocationHypothesis) (*TTestResult, error) { + n1, n2 := x1.Weight(), x2.Weight() + if n1 == 0 || n2 == 0 { + return nil, ErrSampleSize + } + v1, v2 := x1.Variance(), x2.Variance() + if v1 == 0 && v2 == 0 { + return nil, ErrZeroVariance + } + + dof := n1 + n2 - 2 + v12 := ((n1-1)*v1 + (n2-1)*v2) / dof + t := (x1.Mean() - x2.Mean()) / math.Sqrt(v12*(1/n1+1/n2)) + return newTTestResult(int(n1), int(n2), t, dof, alt), nil +} + +// TwoSampleWelchTTest performs a two-sample (unpaired) Welch's t-test +// on samples x1 and x2. This is like TwoSampleTTest, but does not +// assume the distributions have equal variance. +func TwoSampleWelchTTest(x1, x2 TTestSample, alt LocationHypothesis) (*TTestResult, error) { + n1, n2 := x1.Weight(), x2.Weight() + if n1 <= 1 || n2 <= 1 { + // TODO: Can we still do this with n == 1? + return nil, ErrSampleSize + } + v1, v2 := x1.Variance(), x2.Variance() + if v1 == 0 && v2 == 0 { + return nil, ErrZeroVariance + } + + dof := math.Pow(v1/n1+v2/n2, 2) / + (math.Pow(v1/n1, 2)/(n1-1) + math.Pow(v2/n2, 2)/(n2-1)) + s := math.Sqrt(v1/n1 + v2/n2) + t := (x1.Mean() - x2.Mean()) / s + return newTTestResult(int(n1), int(n2), t, dof, alt), nil +} + +// PairedTTest performs a two-sample paired t-test on samples x1 and +// x2. If μ0 is non-zero, this tests if the average of the difference +// is significantly different from μ0. If x1 and x2 are identical, +// this returns nil. +func PairedTTest(x1, x2 []float64, μ0 float64, alt LocationHypothesis) (*TTestResult, error) { + if len(x1) != len(x2) { + return nil, ErrMismatchedSamples + } + if len(x1) <= 1 { + // TODO: Can we still do this with n == 1? + return nil, ErrSampleSize + } + + dof := float64(len(x1) - 1) + + diff := make([]float64, len(x1)) + for i := range x1 { + diff[i] = x1[i] - x2[i] + } + sd := StdDev(diff) + if sd == 0 { + // TODO: Can we still do the test? + return nil, ErrZeroVariance + } + t := (Mean(diff) - μ0) * math.Sqrt(float64(len(x1))) / sd + return newTTestResult(len(x1), len(x2), t, dof, alt), nil +} + +// OneSampleTTest performs a one-sample t-test on sample x. This tests +// the null hypothesis that the population mean is equal to μ0. This +// assumes the distribution of the population of sample means is +// normal. +func OneSampleTTest(x TTestSample, μ0 float64, alt LocationHypothesis) (*TTestResult, error) { + n, v := x.Weight(), x.Variance() + if n == 0 { + return nil, ErrSampleSize + } + if v == 0 { + // TODO: Can we still do the test? + return nil, ErrZeroVariance + } + dof := n - 1 + t := (x.Mean() - μ0) * math.Sqrt(n) / math.Sqrt(v) + return newTTestResult(int(n), 0, t, dof, alt), nil +} diff --git a/go/mathstats/ttest_test.go b/go/mathstats/ttest_test.go new file mode 100644 index 00000000000..0c9b78fdb9f --- /dev/null +++ b/go/mathstats/ttest_test.go @@ -0,0 +1,71 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mathstats + +import "testing" + +func TestTTest(t *testing.T) { + s1 := Sample{Xs: []float64{2, 1, 3, 4}} + s2 := Sample{Xs: []float64{6, 5, 7, 9}} + + check := func(want, got *TTestResult) { + if want.N1 != got.N1 || want.N2 != got.N2 || + !aeq(want.T, got.T) || !aeq(want.DoF, got.DoF) || + want.AltHypothesis != got.AltHypothesis || + !aeq(want.P, got.P) { + t.Errorf("want %+v, got %+v", want, got) + } + } + check3 := func(test func(alt LocationHypothesis) (*TTestResult, error), n1, n2 int, t, dof float64, pless, pdiff, pgreater float64) { + want := &TTestResult{N1: n1, N2: n2, T: t, DoF: dof} + + want.AltHypothesis = LocationLess + want.P = pless + got, _ := test(want.AltHypothesis) + check(want, got) + + want.AltHypothesis = LocationDiffers + want.P = pdiff + got, _ = test(want.AltHypothesis) + check(want, got) + + want.AltHypothesis = LocationGreater + want.P = pgreater + got, _ = test(want.AltHypothesis) + check(want, got) + } + + check3(func(alt LocationHypothesis) (*TTestResult, error) { + return TwoSampleTTest(s1, s1, alt) + }, 4, 4, 0, 6, + 0.5, 1, 0.5) + check3(func(alt LocationHypothesis) (*TTestResult, error) { + return TwoSampleWelchTTest(s1, s1, alt) + }, 4, 4, 0, 6, + 0.5, 1, 0.5) + + check3(func(alt LocationHypothesis) (*TTestResult, error) { + return TwoSampleTTest(s1, s2, alt) + }, 4, 4, -3.9703446152237674, 6, + 0.0036820296121056195, 0.0073640592242113214, 0.9963179703878944) + check3(func(alt LocationHypothesis) (*TTestResult, error) { + return TwoSampleWelchTTest(s1, s2, alt) + }, 4, 4, -3.9703446152237674, 5.584615384615385, + 0.004256431565689112, 0.0085128631313781695, 0.9957435684343109) + + check3(func(alt LocationHypothesis) (*TTestResult, error) { + return PairedTTest(s1.Xs, s2.Xs, 0, alt) + }, 4, 4, -17, 3, + 0.0002216717691559955, 0.00044334353831207749, 0.999778328230844) + + check3(func(alt LocationHypothesis) (*TTestResult, error) { + return OneSampleTTest(s1, 0, alt) + }, 4, 0, 3.872983346207417, 3, + 0.9847668541689145, 0.030466291662170977, 0.015233145831085482) + check3(func(alt LocationHypothesis) (*TTestResult, error) { + return OneSampleTTest(s1, 2.5, alt) + }, 4, 0, 0, 3, + 0.5, 1, 0.5) +} diff --git a/go/mathstats/util_test.go b/go/mathstats/util_test.go new file mode 100644 index 00000000000..68fac8488f4 --- /dev/null +++ b/go/mathstats/util_test.go @@ -0,0 +1,44 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package mathstats + +import ( + "fmt" + "math" + "sort" + "strings" + "testing" +) + +// aeq returns true if expect and got are equal to 8 significant +// figures (1 part in 100 million). +func aeq(expect, got float64) bool { + if expect < 0 && got < 0 { + expect, got = -expect, -got + } + return expect*0.99999999 <= got && got*0.99999999 <= expect +} + +func testFunc(t *testing.T, name string, f func(float64) float64, vals map[float64]float64) { + xs := make([]float64, 0, len(vals)) + for x := range vals { + xs = append(xs, x) + } + sort.Float64s(xs) + + for _, x := range xs { + want, got := vals[x], f(x) + if math.IsNaN(want) && math.IsNaN(got) || aeq(want, got) { + continue + } + var label string + if strings.Contains(name, "%v") { + label = fmt.Sprintf(name, x) + } else { + label = fmt.Sprintf("%s(%v)", name, x) + } + t.Errorf("want %s=%v, got %v", label, want, got) + } +} diff --git a/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go b/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go index 836cd13b4cb..48d6d2854c5 100644 --- a/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go +++ b/go/vt/vttablet/tabletmanager/vreplication/vcopier_test.go @@ -35,10 +35,8 @@ import ( func TestPlayerCopyCharPK(t *testing.T) { defer deleteTablet(addTablet(100)) - savedPacketSize := *vstreamer.PacketSize - // PacketSize of 1 byte will send at most one row at a time. - *vstreamer.PacketSize = 1 - defer func() { *vstreamer.PacketSize = savedPacketSize }() + reset := vstreamer.AdjustPacketSize(1) + defer reset() savedCopyTimeout := copyTimeout // copyTimeout should be low enough to have time to send one row. @@ -138,10 +136,8 @@ func TestPlayerCopyCharPK(t *testing.T) { func TestPlayerCopyVarcharPKCaseInsensitive(t *testing.T) { defer deleteTablet(addTablet(100)) - savedPacketSize := *vstreamer.PacketSize - // PacketSize of 1 byte will send at most one row at a time. - *vstreamer.PacketSize = 1 - defer func() { *vstreamer.PacketSize = savedPacketSize }() + reset := vstreamer.AdjustPacketSize(1) + defer reset() savedCopyTimeout := copyTimeout // copyTimeout should be low enough to have time to send one row. @@ -244,10 +240,8 @@ func TestPlayerCopyVarcharPKCaseInsensitive(t *testing.T) { func TestPlayerCopyVarcharCompositePKCaseSensitiveCollation(t *testing.T) { defer deleteTablet(addTablet(100)) - savedPacketSize := *vstreamer.PacketSize - // PacketSize of 1 byte will send at most one row at a time. - *vstreamer.PacketSize = 1 - defer func() { *vstreamer.PacketSize = savedPacketSize }() + reset := vstreamer.AdjustPacketSize(1) + defer reset() savedCopyTimeout := copyTimeout // copyTimeout should be low enough to have time to send one row. @@ -532,10 +526,8 @@ func TestPlayerCopyTables(t *testing.T) { func TestPlayerCopyBigTable(t *testing.T) { defer deleteTablet(addTablet(100)) - savedPacketSize := *vstreamer.PacketSize - // PacketSize of 1 byte will send at most one row at a time. - *vstreamer.PacketSize = 1 - defer func() { *vstreamer.PacketSize = savedPacketSize }() + reset := vstreamer.AdjustPacketSize(1) + defer reset() savedCopyTimeout := copyTimeout // copyTimeout should be low enough to have time to send one row. @@ -650,10 +642,8 @@ func TestPlayerCopyBigTable(t *testing.T) { func TestPlayerCopyWildcardRule(t *testing.T) { defer deleteTablet(addTablet(100)) - savedPacketSize := *vstreamer.PacketSize - // PacketSize of 1 byte will send at most one row at a time. - *vstreamer.PacketSize = 1 - defer func() { *vstreamer.PacketSize = savedPacketSize }() + reset := vstreamer.AdjustPacketSize(1) + defer reset() savedCopyTimeout := copyTimeout // copyTimeout should be low enough to have time to send one row. diff --git a/go/vt/vttablet/tabletserver/vstreamer/engine.go b/go/vt/vttablet/tabletserver/vstreamer/engine.go index 8485e16858e..c8a5b60a591 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/engine.go +++ b/go/vt/vttablet/tabletserver/vstreamer/engine.go @@ -373,5 +373,5 @@ func (vse *Engine) setWatch() { } func getPacketSize() int64 { - return int64(*PacketSize) + return int64(*defaultPacketSize) } diff --git a/go/vt/vttablet/tabletserver/vstreamer/packet_size.go b/go/vt/vttablet/tabletserver/vstreamer/packet_size.go new file mode 100644 index 00000000000..628104b7f30 --- /dev/null +++ b/go/vt/vttablet/tabletserver/vstreamer/packet_size.go @@ -0,0 +1,210 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vstreamer + +import ( + "flag" + "time" + + "vitess.io/vitess/go/mathstats" +) + +// defaultPacketSize is the suggested packet size for VReplication streamer. +var defaultPacketSize = flag.Int("vstream_packet_size", 250000, "Suggested packet size for VReplication streamer. This is used only as a recommendation. The actual packet size may be more or less than this amount.") + +// useDynamicPacketSize controls whether to use dynamic packet size adjustments to increase performance while streaming +var useDynamicPacketSize = flag.Bool("vstream_dynamic_packet_size", true, "Enable dynamic packet sizing for VReplication. This will adjust the packet size during replication to improve performance.") + +// PacketSizer is a controller that adjusts the size of the packets being sent by the vstreamer at runtime +type PacketSizer interface { + ShouldSend(byteCount int) bool + Record(byteCount int, duration time.Duration) + Limit() int +} + +// DefaultPacketSizer creates a new PacketSizer using the default settings. +// If dynamic packet sizing is enabled, this will return a dynamicPacketSizer. +func DefaultPacketSizer() PacketSizer { + if *useDynamicPacketSize { + return newDynamicPacketSizer(*defaultPacketSize) + } + return newFixedPacketSize(*defaultPacketSize) +} + +// AdjustPacketSize temporarily adjusts the default packet sizes to the given value. +// Calling the returned cleanup function resets them to their original value. +// This function is only used for testing. +func AdjustPacketSize(size int) func() { + originalSize := *defaultPacketSize + originalDyn := *useDynamicPacketSize + + *defaultPacketSize = size + *useDynamicPacketSize = false + + return func() { + *defaultPacketSize = originalSize + *useDynamicPacketSize = originalDyn + } +} + +type fixedPacketSizer struct { + baseSize int +} + +func newFixedPacketSize(baseSize int) PacketSizer { + return &fixedPacketSizer{baseSize: baseSize} +} + +func (ps *fixedPacketSizer) Limit() int { + return ps.baseSize +} + +// ShouldSend checks whether the given byte count is large enough to be sent as a packet while streaming +func (ps *fixedPacketSizer) ShouldSend(byteCount int) bool { + return byteCount >= ps.baseSize +} + +// Record records the total duration it took to send the given byte count while streaming +func (ps *fixedPacketSizer) Record(_ int, _ time.Duration) {} + +type dynamicPacketSizer struct { + // currentSize is the last size for the packet that is safe to use + currentSize int + + // currentMetrics are the performance metrics for the current size + currentMetrics *mathstats.Sample + + // candidateSize is the target size for packets being tested + candidateSize int + + // candidateMetrics are the performance metrics for this new metric + candidateMetrics *mathstats.Sample + + // grow is the growth rate with which we adjust the packet size + grow int + + // calls is the amount of calls to the packet sizer + calls int + + // settled is true when the last experiment has finished and arrived at a new target packet size + settled bool + + // elapsed is the time elapsed since the last experiment was settled + elapsed time.Duration +} + +func newDynamicPacketSizer(baseSize int) PacketSizer { + return &dynamicPacketSizer{ + currentSize: baseSize, + currentMetrics: &mathstats.Sample{}, + candidateMetrics: &mathstats.Sample{}, + candidateSize: baseSize, + grow: baseSize / 4, + } +} + +func (ps *dynamicPacketSizer) Limit() int { + return ps.candidateSize +} + +// ShouldSend checks whether the given byte count is large enough to be sent as a packet while streaming +func (ps *dynamicPacketSizer) ShouldSend(byteCount int) bool { + return byteCount >= ps.candidateSize +} + +type change int8 + +const ( + notChanging change = iota + gettingFaster + gettingSlower +) + +func (ps *dynamicPacketSizer) changeInThroughput() change { + const PValueMargin = 0.1 + + t, err := mathstats.TwoSampleWelchTTest(ps.currentMetrics, ps.candidateMetrics, mathstats.LocationDiffers) + if err != nil { + return notChanging + } + if t.P < PValueMargin { + if ps.candidateMetrics.Mean() > ps.currentMetrics.Mean() { + return gettingFaster + } + return gettingSlower + } + return notChanging +} + +func (ps *dynamicPacketSizer) reset() { + ps.currentMetrics.Clear() + ps.candidateMetrics.Clear() + ps.calls = 0 + ps.settled = false + ps.elapsed = 0 +} + +// Record records the total duration it took to send the given byte count while streaming +func (ps *dynamicPacketSizer) Record(byteCount int, d time.Duration) { + const ExperimentDelay = 5 * time.Second + const CheckFrequency = 16 + const GrowthFrequency = 32 + const InitialCandidateLen = 32 + const SettleCandidateLen = 64 + + if ps.settled { + ps.elapsed += d + if ps.elapsed < ExperimentDelay { + return + } + ps.reset() + } + + ps.calls++ + ps.candidateMetrics.Xs = append(ps.candidateMetrics.Xs, float64(byteCount)/float64(d)) + if ps.calls%CheckFrequency == 0 { + ps.candidateMetrics.Sorted = false + ps.candidateMetrics.FilterOutliers() + + if len(ps.currentMetrics.Xs) == 0 { + if len(ps.candidateMetrics.Xs) >= InitialCandidateLen { + ps.currentMetrics, ps.candidateMetrics = ps.candidateMetrics, ps.currentMetrics + } + return + } + + change := ps.changeInThroughput() + switch change { + case notChanging, gettingSlower: + if len(ps.candidateMetrics.Xs) >= SettleCandidateLen { + ps.candidateSize = ps.currentSize + ps.settled = true + } else { + if change == notChanging && ps.calls%GrowthFrequency == 0 { + ps.candidateSize += ps.grow + } + } + + case gettingFaster: + ps.candidateMetrics, ps.currentMetrics = ps.currentMetrics, ps.candidateMetrics + ps.candidateMetrics.Clear() + + ps.candidateSize += ps.grow + ps.currentSize = ps.candidateSize + } + } +} diff --git a/go/vt/vttablet/tabletserver/vstreamer/packet_size_test.go b/go/vt/vttablet/tabletserver/vstreamer/packet_size_test.go new file mode 100644 index 00000000000..ae71603ea52 --- /dev/null +++ b/go/vt/vttablet/tabletserver/vstreamer/packet_size_test.go @@ -0,0 +1,115 @@ +/* +Copyright 2021 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package vstreamer + +import ( + "math" + "math/rand" + "testing" + "time" +) + +type polynomial []float64 + +func (p polynomial) fit(x float64) float64 { + var y float64 + for i, exp := range p { + y += exp * math.Pow(x, float64(i)) + } + return y +} + +func simulate(t *testing.T, ps PacketSizer, base, mustSend int, interpolate func(float64) float64) (time.Duration, int) { + t.Helper() + + var elapsed time.Duration + var sent int + var sentPkt int + packetRange := float64(base) * 10.0 + + packetSize := 0 + for sent < mustSend { + packetSize += rand.Intn(base / 100) + + if ps.ShouldSend(packetSize) { + x := float64(packetSize) / packetRange + y := interpolate(x) + d := time.Duration(float64(time.Microsecond) * y * float64(packetSize)) + ps.Record(packetSize, d) + + sent += packetSize + elapsed += d + sentPkt++ + + packetSize = 0 + } + } + return elapsed, sentPkt +} + +func TestPacketSizeSimulation(t *testing.T) { + cases := []struct { + name string + baseSize int + p polynomial + error time.Duration + }{ + { + name: "growth with tapper", + baseSize: 25000, + p: polynomial{0.767, 1.278, -12.048, 25.262, -21.270, 6.410}, + }, + { + name: "growth without tapper", + baseSize: 25000, + p: polynomial{0.473, 5.333, -38.663, 90.731, -87.005, 30.128}, + error: 5 * time.Millisecond, + }, + { + name: "regression", + baseSize: 25000, + p: polynomial{0.247, -0.726, 2.864, -3.022, 2.273, -0.641}, + error: 1 * time.Second, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + seed := time.Now().UnixNano() + rand.Seed(seed) + + // Simulate a replication using the given polynomial and the dynamic packet sizer + ps1 := newDynamicPacketSizer(tc.baseSize) + elapsed1, sent1 := simulate(t, ps1, tc.baseSize, tc.baseSize*1000, tc.p.fit) + + // Simulate the same polynomial using a fixed packet size + ps2 := newFixedPacketSize(tc.baseSize) + elapsed2, sent2 := simulate(t, ps2, tc.baseSize, tc.baseSize*1000, tc.p.fit) + + // the simulation for dynamic packet sizing should always be faster then the fixed packet, + // and should also send fewer packets in total + delta := elapsed1 - elapsed2 + if delta > tc.error { + t.Errorf("packet-adjusted simulation is %v slower than fixed approach, seed %d", delta, seed) + } + if sent1 > sent2 { + t.Errorf("packet-adjusted simulation sent more packets (%d) than fixed approach (%d), seed %d", sent1, sent2, seed) + } + // t.Logf("dynamic = (%v, %d), fixed = (%v, %d)", elapsed1, sent1, elapsed2, sent2) + }) + } +} diff --git a/go/vt/vttablet/tabletserver/vstreamer/resultstreamer.go b/go/vt/vttablet/tabletserver/vstreamer/resultstreamer.go index c7897b756ca..231491f0cfc 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/resultstreamer.go +++ b/go/vt/vttablet/tabletserver/vstreamer/resultstreamer.go @@ -19,6 +19,7 @@ package vstreamer import ( "context" "fmt" + "time" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" @@ -39,17 +40,19 @@ type resultStreamer struct { tableName sqlparser.TableIdent send func(*binlogdatapb.VStreamResultsResponse) error vse *Engine + pktsize PacketSizer } func newResultStreamer(ctx context.Context, cp dbconfigs.Connector, query string, send func(*binlogdatapb.VStreamResultsResponse) error, vse *Engine) *resultStreamer { ctx, cancel := context.WithCancel(ctx) return &resultStreamer{ - ctx: ctx, - cancel: cancel, - cp: cp, - query: query, - send: send, - vse: vse, + ctx: ctx, + cancel: cancel, + cp: cp, + query: query, + send: send, + vse: vse, + pktsize: DefaultPacketSizer(), } } @@ -114,16 +117,18 @@ func (rs *resultStreamer) Stream() error { byteCount += s.Len() } - if byteCount >= *PacketSize { + if rs.pktsize.ShouldSend(byteCount) { rs.vse.resultStreamerNumRows.Add(int64(len(response.Rows))) rs.vse.resultStreamerNumPackets.Add(int64(1)) + startSend := time.Now() err = rs.send(response) if err != nil { return err } + rs.pktsize.Record(byteCount, time.Since(startSend)) // empty the rows so we start over, but we keep the // same capacity - response.Rows = nil + response.Rows = response.Rows[:0] byteCount = 0 } } diff --git a/go/vt/vttablet/tabletserver/vstreamer/resultstreamer_test.go b/go/vt/vttablet/tabletserver/vstreamer/resultstreamer_test.go index aeff0b976ec..9853ba9262e 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/resultstreamer_test.go +++ b/go/vt/vttablet/tabletserver/vstreamer/resultstreamer_test.go @@ -30,11 +30,9 @@ func TestStreamResults(t *testing.T) { if testing.Short() { t.Skip() } - oldPacketSize := *PacketSize - defer func() { - *PacketSize = oldPacketSize - }() - *PacketSize = 1 + + reset := AdjustPacketSize(1) + defer reset() engine.resultStreamerNumPackets.Reset() engine.resultStreamerNumRows.Reset() diff --git a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go index 95863d1384f..77481b49254 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go +++ b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer.go @@ -19,6 +19,7 @@ package vstreamer import ( "context" "fmt" + "time" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/dbconfigs" @@ -65,6 +66,7 @@ type rowStreamer struct { pkColumns []int sendQuery string vse *Engine + pktsize PacketSizer } func newRowStreamer(ctx context.Context, cp dbconfigs.Connector, se *schema.Engine, query string, lastpk []sqltypes.Value, vschema *localVSchema, send func(*binlogdatapb.VStreamRowsResponse) error, vse *Engine) *rowStreamer { @@ -79,6 +81,7 @@ func newRowStreamer(ctx context.Context, cp dbconfigs.Connector, se *schema.Engi send: send, vschema: vschema, vse: vse, + pktsize: DefaultPacketSizer(), } } @@ -237,11 +240,9 @@ func (rs *rowStreamer) streamQuery(conn *snapshotConn, send func(*binlogdatapb.V byteCount := 0 for { //log.Infof("StreamResponse for loop iteration starts") - select { - case <-rs.ctx.Done(): + if rs.ctx.Err() != nil { log.Infof("Stream ended because of ctx.Done") return fmt.Errorf("stream ended: %v", rs.ctx.Err()) - default: } // check throttler. @@ -273,19 +274,20 @@ func (rs *rowStreamer) streamQuery(conn *snapshotConn, send func(*binlogdatapb.V } } - if byteCount >= *PacketSize { + if rs.pktsize.ShouldSend(byteCount) { rs.vse.rowStreamerNumRows.Add(int64(len(response.Rows))) rs.vse.rowStreamerNumPackets.Add(int64(1)) response.Lastpk = sqltypes.RowToProto3(lastpk) + startSend := time.Now() err = send(response) if err != nil { log.Infof("Rowstreamer send returned error %v", err) return err } - // empty the rows so we start over, but we keep the - // same capacity - response.Rows = nil + rs.pktsize.Record(byteCount, time.Since(startSend)) + // empty the rows so we start over, but we keep the same capacity + response.Rows = response.Rows[:0] byteCount = 0 } } diff --git a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer_test.go b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer_test.go index ba7a113e95b..4a2a6e94b6e 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/rowstreamer_test.go +++ b/go/vt/vttablet/tabletserver/vstreamer/rowstreamer_test.go @@ -331,9 +331,8 @@ func TestStreamRowsMultiPacket(t *testing.T) { t.Skip() } - savedSize := *PacketSize - *PacketSize = 10 - defer func() { *PacketSize = savedSize }() + reset := AdjustPacketSize(10) + defer reset() execStatements(t, []string{ "create table t1(id int, val varbinary(128), primary key(id))", @@ -360,9 +359,8 @@ func TestStreamRowsCancel(t *testing.T) { t.Skip() } - savedSize := *PacketSize - *PacketSize = 10 - defer func() { *PacketSize = savedSize }() + reset := AdjustPacketSize(10) + defer reset() execStatements(t, []string{ "create table t1(id int, val varbinary(128), primary key(id))", diff --git a/go/vt/vttablet/tabletserver/vstreamer/vstreamer.go b/go/vt/vttablet/tabletserver/vstreamer/vstreamer.go index 32e142828a6..0252e3fa333 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/vstreamer.go +++ b/go/vt/vttablet/tabletserver/vstreamer/vstreamer.go @@ -19,7 +19,6 @@ package vstreamer import ( "bytes" "context" - "flag" "fmt" "io" "strings" @@ -42,9 +41,6 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" ) -// PacketSize is the suggested packet size for VReplication streamer. -var PacketSize = flag.Int("vstream_packet_size", 250000, "Suggested packet size for VReplication streamer. This is used only as a recommendation. The actual packet size may be more or less than this amount.") - // HeartbeatTime is set to slightly below 1s, compared to idleTimeout // set by VPlayer at slightly above 1s. This minimizes conflicts // between the two timeouts. @@ -227,7 +223,7 @@ func (vs *vstreamer) parseEvents(ctx context.Context, events <-chan mysql.Binlog return vs.send(vevents) case binlogdatapb.VEventType_INSERT, binlogdatapb.VEventType_DELETE, binlogdatapb.VEventType_UPDATE, binlogdatapb.VEventType_REPLACE: newSize := len(vevent.GetDml()) - if curSize+newSize > *PacketSize { + if curSize+newSize > *defaultPacketSize { vs.vse.vstreamerNumPackets.Add(1) vevents := bufferedEvents bufferedEvents = []*binlogdatapb.VEvent{vevent} @@ -248,7 +244,7 @@ func (vs *vstreamer) parseEvents(ctx context.Context, events <-chan mysql.Binlog newSize += len(rowChange.After.Values) } } - if curSize+newSize > *PacketSize { + if curSize+newSize > *defaultPacketSize { vs.vse.vstreamerNumPackets.Add(1) vevents := bufferedEvents bufferedEvents = []*binlogdatapb.VEvent{vevent} diff --git a/go/vt/vttablet/tabletserver/vstreamer/vstreamer_test.go b/go/vt/vttablet/tabletserver/vstreamer/vstreamer_test.go index 8047c8047e0..d15c4670716 100644 --- a/go/vt/vttablet/tabletserver/vstreamer/vstreamer_test.go +++ b/go/vt/vttablet/tabletserver/vstreamer/vstreamer_test.go @@ -1299,9 +1299,8 @@ func TestBuffering(t *testing.T) { t.Skip() } - savedSize := *PacketSize - *PacketSize = 10 - defer func() { *PacketSize = savedSize }() + reset := AdjustPacketSize(10) + defer reset() execStatement(t, "create table packet_test(id int, val varbinary(128), primary key(id))") defer execStatement(t, "drop table packet_test")