Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix race in set_tracer_provider() #2182

Merged
merged 8 commits into from
Oct 12, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased](https://github.com/open-telemetry/opentelemetry-python/compare/v1.5.0-0.24b0...HEAD)
- Fix race in `set_tracer_provider()`
([#2182](https://github.com/open-telemetry/opentelemetry-python/pull/2182))
- Automatically load OTEL environment variables as options for `opentelemetry-instrument`
([#1969](https://github.com/open-telemetry/opentelemetry-python/pull/1969))
- `opentelemetry-semantic-conventions` Update to semantic conventions v1.6.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def _translate_spans_with_dropped_attributes():

class TestJaegerExporter(unittest.TestCase):
def setUp(self):
trace_api._reset_globals() # pylint: disable=protected-access
# create and save span to be used in tests
self.context = trace_api.SpanContext(
trace_id=0x000000000000000000000000DEADBEEF,
Expand All @@ -73,6 +74,10 @@ def setUp(self):
self._test_span.end(end_time=3)
# pylint: disable=protected-access

def tearDown(self):
super().tearDown()
trace_api._reset_globals() # pylint: disable=protected-access

@patch("opentelemetry.exporter.jaeger.thrift.trace._TRACER_PROVIDER", None)
def test_constructor_default(self):
# pylint: disable=protected-access
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,14 @@

# pylint: disable=no-member
class TestCollectorSpanExporter(unittest.TestCase):
def setUp(self):
super().setUp()
trace_api._reset_globals() # pylint: disable=protected-access

def tearDown(self):
super().tearDown()
trace_api._reset_globals() # pylint: disable=protected-access

@mock.patch(
"opentelemetry.exporter.opencensus.trace_exporter.trace._TRACER_PROVIDER",
None,
Expand Down
50 changes: 31 additions & 19 deletions opentelemetry-api/src/opentelemetry/trace/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@
)
from opentelemetry.trace.status import Status, StatusCode
from opentelemetry.util import types
from opentelemetry.util._once import Once
from opentelemetry.util._providers import _load_provider

logger = getLogger(__name__)
Expand Down Expand Up @@ -452,8 +453,19 @@ def start_as_current_span(
yield INVALID_SPAN


_TRACER_PROVIDER = None
_PROXY_TRACER_PROVIDER = None
_TRACER_PROVIDER_SET_ONCE = Once()
_TRACER_PROVIDER: Optional[TracerProvider] = None
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()


def _reset_globals() -> None:
aabmass marked this conversation as resolved.
Show resolved Hide resolved
"""WARNING: only use this for tests."""
global _TRACER_PROVIDER_SET_ONCE # pylint: disable=global-statement
global _TRACER_PROVIDER # pylint: disable=global-statement
global _PROXY_TRACER_PROVIDER # pylint: disable=global-statement
_TRACER_PROVIDER_SET_ONCE = Once()
_TRACER_PROVIDER = None
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()


def get_tracer(
Expand All @@ -476,40 +488,40 @@ def get_tracer(
)


def _set_tracer_provider(tracer_provider: TracerProvider, log: bool) -> None:
aabmass marked this conversation as resolved.
Show resolved Hide resolved
def set_tp() -> None:
global _TRACER_PROVIDER # pylint: disable=global-statement
_TRACER_PROVIDER = tracer_provider

did_set = _TRACER_PROVIDER_SET_ONCE.do_once(set_tp)

if not did_set:
logger.warning("Overriding of current TracerProvider is not allowed")


def set_tracer_provider(tracer_provider: TracerProvider) -> None:
"""Sets the current global :class:`~.TracerProvider` object.

This can only be done once, a warning will be logged if any furter attempt
is made.
"""
global _TRACER_PROVIDER # pylint: disable=global-statement

if _TRACER_PROVIDER is not None:
logger.warning("Overriding of current TracerProvider is not allowed")
return

_TRACER_PROVIDER = tracer_provider
_set_tracer_provider(tracer_provider, log=True)


def get_tracer_provider() -> TracerProvider:
"""Gets the current global :class:`~.TracerProvider` object."""
# pylint: disable=global-statement
global _TRACER_PROVIDER
global _PROXY_TRACER_PROVIDER

if _TRACER_PROVIDER is None:
# if a global tracer provider has not been set either via code or env
# vars, return a proxy tracer provider
if OTEL_PYTHON_TRACER_PROVIDER not in os.environ:
if not _PROXY_TRACER_PROVIDER:
_PROXY_TRACER_PROVIDER = ProxyTracerProvider()
return _PROXY_TRACER_PROVIDER

_TRACER_PROVIDER = cast( # type: ignore
"TracerProvider",
_load_provider(OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"),
tracer_provider: TracerProvider = _load_provider(
OTEL_PYTHON_TRACER_PROVIDER, "tracer_provider"
aabmass marked this conversation as resolved.
Show resolved Hide resolved
)
return _TRACER_PROVIDER
_set_tracer_provider(tracer_provider, log=False)
# _TRACER_PROVIDER will have been set by one thread
return cast("TracerProvider", _TRACER_PROVIDER)


@contextmanager # type: ignore
Expand Down
47 changes: 47 additions & 0 deletions opentelemetry-api/src/opentelemetry/util/_once.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from threading import Lock
from typing import Callable


class Once:
owais marked this conversation as resolved.
Show resolved Hide resolved
owais marked this conversation as resolved.
Show resolved Hide resolved
"""Execute a function exactly once and block all callers until the function returns

Same as golang's `sync.Once <https://pkg.go.dev/sync#Once>`_
"""

def __init__(self) -> None:
self._lock = Lock()
self._done = False

def do_once(self, func: Callable[[], None]) -> bool:
"""Execute ``func`` if it hasn't been executed or return.

Will block until ``func`` has been called by one thread.

Returns:
Whether or not ``func`` was executed in this call
"""

# fast path, try to avoid locking
if self._done:
return False

with self._lock:
if not self._done:
func()
self._done = True
return True
return False
72 changes: 62 additions & 10 deletions opentelemetry-api/tests/trace/test_globals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import unittest
from unittest.mock import patch
from unittest.mock import Mock, patch

from opentelemetry import context, trace
from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
from opentelemetry.trace.status import Status, StatusCode


Expand All @@ -27,23 +28,74 @@ def record_exception(

class TestGlobals(unittest.TestCase):
def setUp(self):
self._patcher = patch("opentelemetry.trace._TRACER_PROVIDER")
self._mock_tracer_provider = self._patcher.start()
super().setUp()
trace._reset_globals() # pylint: disable=protected-access

def tearDown(self) -> None:
self._patcher.stop()
def tearDown(self):
super().tearDown()
trace._reset_globals() # pylint: disable=protected-access

def test_get_tracer(self):
@staticmethod
@patch("opentelemetry.trace._TRACER_PROVIDER")
def test_get_tracer(mock_tracer_provider): # type: ignore
"""trace.get_tracer should proxy to the global tracer provider."""
trace.get_tracer("foo", "var")
self._mock_tracer_provider.get_tracer.assert_called_with(
"foo", "var", None
)
mock_provider = unittest.mock.Mock()
mock_tracer_provider.get_tracer.assert_called_with("foo", "var", None)
mock_provider = Mock()
trace.get_tracer("foo", "var", mock_provider)
mock_provider.get_tracer.assert_called_with("foo", "var", None)


class TestGlobalsConcurrency(ConcurrencyTestBase):
def setUp(self):
super().setUp()
trace._reset_globals() # pylint: disable=protected-access

def tearDown(self):
super().tearDown()
trace._reset_globals() # pylint: disable=protected-access

@patch("opentelemetry.trace.logger")
def test_set_tracer_provider_many_threads(self, mock_logger) -> None: # type: ignore
mock_logger.warning = MockFunc()

def do_concurrently() -> Mock:
# first get a proxy tracer
proxy_tracer = trace.ProxyTracerProvider().get_tracer("foo")

# try to set the global tracer provider
mock_tracer_provider = Mock(get_tracer=MockFunc())
trace.set_tracer_provider(mock_tracer_provider)

# start a span through the proxy which will call through to the mock provider
proxy_tracer.start_span("foo")

return mock_tracer_provider

num_threads = 100
mock_tracer_providers = self.run_with_many_threads(
do_concurrently,
num_threads=num_threads,
)

# despite trying to set tracer provider many times, only one of the
aabmass marked this conversation as resolved.
Show resolved Hide resolved
# mock_tracer_providers should have stuck and been called from
# proxy_tracer.start_span()
mock_tps_with_any_call = [
mock
for mock in mock_tracer_providers
if mock.get_tracer.call_count > 0
]

self.assertEqual(len(mock_tps_with_any_call), 1)
self.assertEqual(
mock_tps_with_any_call[0].get_tracer.call_count, num_threads
)

# should have warned everytime except for the successful set
self.assertEqual(mock_logger.warning.call_count, num_threads - 1)


class TestTracer(unittest.TestCase):
def setUp(self):
# pylint: disable=protected-access
Expand Down
15 changes: 11 additions & 4 deletions opentelemetry-api/tests/trace/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,15 @@ class TestSpan(NonRecordingSpan):


class TestProxy(unittest.TestCase):
def test_proxy_tracer(self):
original_provider = trace._TRACER_PROVIDER
def setUp(self) -> None:
super().setUp()
trace._reset_globals()

def tearDown(self) -> None:
super().tearDown()
trace._reset_globals()

def test_proxy_tracer(self):
provider = trace.get_tracer_provider()
# proxy provider
self.assertIsInstance(provider, trace.ProxyTracerProvider)
Expand All @@ -60,6 +66,9 @@ def test_proxy_tracer(self):
# set a real provider
trace.set_tracer_provider(TestProvider())

# get_tracer_provider() now returns the real provider
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)

# tracer provider now returns real instance
self.assertIsInstance(trace.get_tracer_provider(), TestProvider)

Expand All @@ -71,5 +80,3 @@ def test_proxy_tracer(self):
# creates real spans
with tracer.start_span("") as span:
self.assertIsInstance(span, TestSpan)

trace._TRACER_PROVIDER = original_provider
48 changes: 48 additions & 0 deletions opentelemetry-api/tests/util/test_once.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from opentelemetry.test.concurrency_test import ConcurrencyTestBase, MockFunc
from opentelemetry.util._once import Once


class TestOnce(ConcurrencyTestBase):
def test_once_single_thread(self):
once_func = MockFunc()
once = Once()

self.assertEqual(once_func.call_count, 0)

# first call should run
called = once.do_once(once_func)
self.assertTrue(called)
self.assertEqual(once_func.call_count, 1)

# subsequent calls do nothing
called = once.do_once(once_func)
self.assertFalse(called)
self.assertEqual(once_func.call_count, 1)

def test_once_many_threads(self):
once_func = MockFunc()
once = Once()

def run_concurrently() -> bool:
return once.do_once(once_func)

results = self.run_with_many_threads(run_concurrently, num_threads=100)

self.assertEqual(once_func.call_count, 1)

# check that only one of the threads got True
self.assertEqual(results.count(True), 1)
Loading