Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce a plugin API to provide all thread local state, and deprecate stdio-specific methods (Cherry-pick of #15890) #15916

Merged
merged 1 commit into from
Jun 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/python/pants/bsp/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pants.bsp.context import BSPContext
from pants.bsp.protocol import BSPConnection
from pants.bsp.rules import rules as bsp_rules
from pants.engine.internals import native_engine
from pants.engine.internals.native_engine import PyThreadLocals
from pants.testutil.rule_runner import RuleRunner


Expand Down Expand Up @@ -93,7 +93,7 @@ def setup_bsp_server(
):
rule_runner = rule_runner or RuleRunner(rules=bsp_rules())
notification_names = notification_names or set()
stdio_destination = native_engine.stdio_thread_get_destination()
thread_locals = PyThreadLocals.get_for_current_thread()

with setup_pipes() as pipes:
context = BSPContext()
Expand All @@ -107,7 +107,7 @@ def setup_bsp_server(
)

def run_bsp_server():
native_engine.stdio_thread_set_destination(stdio_destination)
thread_locals.set_for_current_thread()
conn.run()

bsp_thread = Thread(target=run_bsp_server)
Expand Down
5 changes: 5 additions & 0 deletions src/python/pants/engine/internals/native_engine.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -386,5 +386,10 @@ class PyTypes:
class PyStdioDestination:
pass

class PyThreadLocals:
@classmethod
def get_for_current_thread(cls) -> PyThreadLocals: ...
def set_for_current_thread(self) -> None: ...

class PollTimeout(Exception):
pass
29 changes: 24 additions & 5 deletions src/python/pants/engine/streaming_workunit_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pants.base.specs import Specs
from pants.engine.addresses import Addresses
from pants.engine.fs import Digest, DigestContents, FileDigest, Snapshot
from pants.engine.internals import native_engine
from pants.engine.internals.native_engine import PyThreadLocals
from pants.engine.internals.scheduler import SchedulerSession, Workunit
from pants.engine.internals.selectors import Params
from pants.engine.rules import Get, MultiGet, QueryRule, collect_rules, rule
Expand All @@ -30,6 +30,24 @@
# -----------------------------------------------------------------------------------------------


def thread_locals_get_for_current_thread() -> PyThreadLocals:
"""Gets the engine's thread local state for the current thread.

In order to safely use StreamingWorkunitContext methods from additional threads,
StreamingWorkunit plugins should propagate thread local state from the threads that they are
initialized on to any additional threads that they spawn.
"""
return PyThreadLocals.get_for_current_thread()


def thread_locals_set_for_current_thread(thread_locals: PyThreadLocals) -> None:
"""Sets the engine's thread local state for the current thread.

See `thread_locals_get`.
"""
thread_locals.set_for_current_thread()


@dataclass(frozen=True)
class TargetInfo:
filename: str
Expand Down Expand Up @@ -246,9 +264,9 @@ def __init__(
self.block_until_complete = not allow_async_completion or any(
callback.can_finish_async is False for callback in self.callbacks
)
# Get the parent thread's logging destination. Note that this thread has not yet started
# Get the parent thread's thread locals. Note that this thread has not yet started
# as we are only in the constructor.
self.logging_destination = native_engine.stdio_thread_get_destination()
self.thread_locals = PyThreadLocals.get_for_current_thread()

def poll_workunits(self, *, finished: bool) -> None:
workunits = self.scheduler.poll_workunits(self.max_workunit_verbosity)
Expand All @@ -261,8 +279,9 @@ def poll_workunits(self, *, finished: bool) -> None:
)

def run(self) -> None:
# First, set the thread's logging destination to the parent thread's, meaning the console.
native_engine.stdio_thread_set_destination(self.logging_destination)
# First, set the thread's thread locals to the parent thread's in order to propagate the
# console, workunit stores, etc.
self.thread_locals.set_for_current_thread()
while not self.stop_request.isSet():
self.poll_workunits(finished=False)
self.stop_request.wait(timeout=self.report_interval)
Expand Down
37 changes: 33 additions & 4 deletions src/rust/engine/src/externs/interface.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ use rule_graph::{self, RuleGraph};
use task_executor::Executor;
use workunit_store::{
ArtifactOutput, ObservationMetric, UserMetadataItem, Workunit, WorkunitState, WorkunitStore,
WorkunitStoreHandle,
};

use crate::externs::fs::{todo_possible_store_missing_digest, PyFileDigest};
Expand Down Expand Up @@ -73,6 +74,7 @@ fn native_engine(py: Python, m: &PyModule) -> PyO3Result<()> {
m.add_class::<PySessionCancellationLatch>()?;
m.add_class::<PyStdioDestination>()?;
m.add_class::<PyTasks>()?;
m.add_class::<PyThreadLocals>()?;
m.add_class::<PyTypes>()?;

m.add_class::<externs::PyGeneratorResponseBreak>()?;
Expand Down Expand Up @@ -225,7 +227,7 @@ impl PyTypes {
struct PyScheduler(Scheduler);

#[pyclass]
struct PyStdioDestination(Arc<stdio::Destination>);
struct PyStdioDestination(PyThreadLocals);

/// Represents configuration related to process execution strategies.
///
Expand Down Expand Up @@ -500,6 +502,30 @@ fn py_result_from_root(py: Python, result: Result<Value, Failure>) -> PyResult {
}
}

#[pyclass]
struct PyThreadLocals(Arc<stdio::Destination>, Option<WorkunitStoreHandle>);

impl PyThreadLocals {
fn get() -> Self {
let stdio_dest = stdio::get_destination();
let workunit_store_handle = workunit_store::get_workunit_store_handle();
Self(stdio_dest, workunit_store_handle)
}
}

#[pymethods]
impl PyThreadLocals {
#[classmethod]
fn get_for_current_thread(_cls: &PyType) -> Self {
Self::get()
}

fn set_for_current_thread(&self) {
stdio::set_thread_destination(self.0.clone());
workunit_store::set_thread_workunit_store_handle(self.1.clone());
}
}

#[pyfunction]
fn nailgun_server_create(
py_executor: &externs::scheduler::PyExecutor,
Expand Down Expand Up @@ -1641,15 +1667,18 @@ fn stdio_thread_console_clear() {
stdio::get_destination().console_clear();
}

// TODO: Deprecated, but without easy access to the decorator. Use
// `PyThreadLocals::get_for_current_thread` instead. Remove in Pants 2.17.0.dev0.
#[pyfunction]
fn stdio_thread_get_destination() -> PyStdioDestination {
let dest = stdio::get_destination();
PyStdioDestination(dest)
PyStdioDestination(PyThreadLocals::get())
}

// TODO: Deprecated, but without easy access to the decorator. Use
// `PyThreadLocals::set_for_current_thread` instead. Remove in Pants 2.17.0.dev0.
#[pyfunction]
fn stdio_thread_set_destination(stdio_destination: &PyStdioDestination) {
stdio::set_thread_destination(stdio_destination.0.clone());
stdio_destination.0.set_for_current_thread();
}

// TODO: Needs to be thread-local / associated with the Console.
Expand Down