diff --git a/sdk/trace/sampling.go b/sdk/trace/sampling.go index 25f0377c..c40994a8 100644 --- a/sdk/trace/sampling.go +++ b/sdk/trace/sampling.go @@ -128,13 +128,29 @@ func NeverSample() Sampler { return alwaysOffSampler{} } -// AlwaysParentSample returns a Sampler that samples a trace only -// if the parent span is sampled. -// This Sampler is a passthrough to the ProbabilitySampler with -// a fraction of value 0. -func AlwaysParentSample() Sampler { - return &probabilitySampler{ - traceIDUpperBound: 0, - description: "AlwaysParentSampler", +// ParentSample returns a Sampler that samples a trace only +// if the the span has a parent span and it is sampled. If the span has +// parent span but it is not sampled, neither will this span. If the span +// does not have a parent the fallback Sampler is used to determine if the +// span should be sampled. +func ParentSample(fallback Sampler) Sampler { + return parentSampler{fallback} +} + +type parentSampler struct { + fallback Sampler +} + +func (ps parentSampler) ShouldSample(p SamplingParameters) SamplingResult { + if p.ParentContext.IsValid() { + if p.ParentContext.IsSampled() { + return SamplingResult{Decision: RecordAndSampled} + } + return SamplingResult{Decision: NotRecord} } + return ps.fallback.ShouldSample(p) +} + +func (ps parentSampler) Description() string { + return fmt.Sprintf("ParentOrElse{%s}", ps.fallback.Description()) } diff --git a/sdk/trace/sampling_test.go b/sdk/trace/sampling_test.go index ff89c451..fe9ae080 100644 --- a/sdk/trace/sampling_test.go +++ b/sdk/trace/sampling_test.go @@ -23,7 +23,21 @@ import ( ) func TestAlwaysParentSampleWithParentSampled(t *testing.T) { - sampler := sdktrace.AlwaysParentSample() + sampler := sdktrace.ParentSample(sdktrace.AlwaysSample()) + traceID, _ := trace.IDFromHex("4bf92f3577b34da6a3ce929d0e0e4736") + spanID, _ := trace.SpanIDFromHex("00f067aa0ba902b7") + parentCtx := trace.SpanContext{ + TraceID: traceID, + SpanID: spanID, + TraceFlags: trace.FlagsSampled, + } + if sampler.ShouldSample(sdktrace.SamplingParameters{ParentContext: parentCtx}).Decision != sdktrace.RecordAndSampled { + t.Error("Sampling decision should be RecordAndSampled") + } +} + +func TestNeverParentSampleWithParentSampled(t *testing.T) { + sampler := sdktrace.ParentSample(sdktrace.NeverSample()) traceID, _ := trace.IDFromHex("4bf92f3577b34da6a3ce929d0e0e4736") spanID, _ := trace.SpanIDFromHex("00f067aa0ba902b7") parentCtx := trace.SpanContext{ @@ -37,7 +51,7 @@ func TestAlwaysParentSampleWithParentSampled(t *testing.T) { } func TestAlwaysParentSampleWithParentNotSampled(t *testing.T) { - sampler := sdktrace.AlwaysParentSample() + sampler := sdktrace.ParentSample(sdktrace.AlwaysSample()) traceID, _ := trace.IDFromHex("4bf92f3577b34da6a3ce929d0e0e4736") spanID, _ := trace.SpanIDFromHex("00f067aa0ba902b7") parentCtx := trace.SpanContext{ @@ -48,3 +62,17 @@ func TestAlwaysParentSampleWithParentNotSampled(t *testing.T) { t.Error("Sampling decision should be NotRecord") } } + +func TestParentSampleWithNoParent(t *testing.T) { + params := sdktrace.SamplingParameters{} + + sampler := sdktrace.ParentSample(sdktrace.AlwaysSample()) + if sampler.ShouldSample(params).Decision != sdktrace.RecordAndSampled { + t.Error("Sampling decision should be RecordAndSampled") + } + + sampler = sdktrace.ParentSample(sdktrace.NeverSample()) + if sampler.ShouldSample(params).Decision != sdktrace.NotRecord { + t.Error("Sampling decision should be NotRecord") + } +}