Skip to content

Commit

Permalink
Merge pull request #468 from dora-rs/specify-conda-env
Browse files Browse the repository at this point in the history
Specify conda env for Python Operators
  • Loading branch information
haixuanTao authored Apr 10, 2024
2 parents cea23e8 + b589049 commit 7126454
Show file tree
Hide file tree
Showing 10 changed files with 155 additions and 29 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,15 @@ jobs:
python-version: "3.10"
- name: "Python Dataflow example"
run: cargo run --example python-dataflow
- uses: conda-incubator/setup-miniconda@v2
with:
auto-activate-base: true
activate-environment: ""
- name: "Python Operator Dataflow example"
run: cargo run --example python-operator-dataflow
shell: bash -l {0}
run: |
conda deactivate
cargo run --example python-operator-dataflow
# ROS2 bridge examples
ros2-bridge-examples:
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions binaries/cli/src/attach.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,13 @@ pub fn attach_dataflow(
CoreNodeKind::Custom(_cn) => (),
CoreNodeKind::Runtime(rn) => {
for op in rn.operators.iter() {
if let dora_core::descriptor::OperatorSource::Python(source) = &op.config.source
if let dora_core::descriptor::OperatorSource::Python(python_source) =
&op.config.source
{
let path = resolve_path(source, &working_dir).wrap_err_with(|| {
format!("failed to resolve node source `{}`", source)
})?;
let path = resolve_path(&python_source.source, &working_dir)
.wrap_err_with(|| {
format!("failed to resolve node source `{}`", python_source.source)
})?;
node_path_lookup
.insert(path, (dataflow_id, node.id.clone(), Some(op.id.clone())));
}
Expand Down
1 change: 1 addition & 0 deletions binaries/daemon/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ bincode = "1.3.3"
async-trait = "0.1.64"
aligned-vec = "0.5.0"
ctrlc = "3.2.5"
which = "5.0.0"
65 changes: 51 additions & 14 deletions binaries/daemon/src/spawn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use dora_core::{
config::{DataId, NodeRunConfig},
daemon_messages::{DataMessage, DataflowId, NodeConfig, RuntimeConfig, Timestamped},
descriptor::{
resolve_path, source_is_url, Descriptor, OperatorSource, ResolvedNode, SHELL_SOURCE,
resolve_path, source_is_url, Descriptor, OperatorDefinition, OperatorSource, PythonSource,
ResolvedNode, SHELL_SOURCE,
},
get_python_path,
message::uhlc::HLC,
Expand All @@ -19,7 +20,7 @@ use dora_node_api::{
arrow_utils::{copy_array_into_sample, required_data_size},
Metadata,
};
use eyre::WrapErr;
use eyre::{ContextCompat, WrapErr};
use std::{
env::consts::EXE_EXTENSION,
path::{Path, PathBuf},
Expand Down Expand Up @@ -149,26 +150,62 @@ pub async fn spawn_node(
})?
}
dora_core::descriptor::CoreNodeKind::Runtime(n) => {
let has_python_operator = n
let python_operators: Vec<&OperatorDefinition> = n
.operators
.iter()
.any(|x| matches!(x.config.source, OperatorSource::Python { .. }));
.filter(|x| matches!(x.config.source, OperatorSource::Python { .. }))
.collect();

let has_other_operator = n
let other_operators = n
.operators
.iter()
.any(|x| !matches!(x.config.source, OperatorSource::Python { .. }));

let mut command = if has_python_operator && !has_other_operator {
let mut command = if !python_operators.is_empty() && !other_operators {
// Use python to spawn runtime if there is a python operator
let python = get_python_path().context("Could not find python in daemon")?;
let mut command = tokio::process::Command::new(python);
command.args([
"-c",
format!("import dora; dora.start_runtime() # {}", node.id).as_str(),
]);
command
} else if !has_python_operator && has_other_operator {

// TODO: Handle multi-operator runtime once sub-interpreter is supported
if python_operators.len() > 2 {
eyre::bail!(
"Runtime currently only support one Python Operator.
This is because pyo4 sub-interpreter is not yet available.
See: https://github.com/PyO4/pyo3/issues/576"
);
}

let python_operator = python_operators
.first()
.context("Runtime had no operators definition.")?;

if let OperatorSource::Python(PythonSource {
source: _,
conda_env: Some(conda_env),
}) = &python_operator.config.source
{
let conda = which::which("conda").context(
"failed to find `conda`, yet a `conda_env` was defined. Make sure that `conda` is available.",
)?;
let mut command = tokio::process::Command::new(conda);
command.args([
"run",
"-n",
&conda_env,
"python",
"-c",
format!("import dora; dora.start_runtime() # {}", node.id).as_str(),
]);
command
} else {
let python = get_python_path()
.context("Could not find python path when spawning runtime node")?;
let mut command = tokio::process::Command::new(python);
command.args([
"-c",
format!("import dora; dora.start_runtime() # {}", node.id).as_str(),
]);
command
}
} else if python_operators.is_empty() && other_operators {
let mut cmd = tokio::process::Command::new(
std::env::current_exe().wrap_err("failed to get current executable path")?,
);
Expand Down
10 changes: 5 additions & 5 deletions binaries/runtime/src/operator/python.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use super::{OperatorEvent, StopReason};
use dora_core::{
config::{NodeId, OperatorId},
descriptor::{source_is_url, Descriptor},
descriptor::{source_is_url, Descriptor, PythonSource},
};
use dora_download::download_file;
use dora_node_api::Event;
Expand Down Expand Up @@ -35,25 +35,25 @@ fn traceback(err: pyo3::PyErr) -> eyre::Report {
pub fn run(
node_id: &NodeId,
operator_id: &OperatorId,
source: &str,
python_source: &PythonSource,
events_tx: Sender<OperatorEvent>,
incoming_events: flume::Receiver<Event>,
init_done: oneshot::Sender<Result<()>>,
dataflow_descriptor: &Descriptor,
) -> eyre::Result<()> {
let path = if source_is_url(source) {
let path = if source_is_url(&python_source.source) {
let target_path = Path::new("build")
.join(node_id.to_string())
.join(format!("{}.py", operator_id));
// try to download the shared library
let rt = tokio::runtime::Builder::new_current_thread()
.enable_all()
.build()?;
rt.block_on(download_file(source, &target_path))
rt.block_on(download_file(&python_source.source, &target_path))
.wrap_err("failed to download Python operator")?;
target_path
} else {
Path::new(source).to_owned()
Path::new(&python_source.source).to_owned()
};

if !path.exists() {
Expand Down
28 changes: 28 additions & 0 deletions examples/python-operator-dataflow/dataflow_conda.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
nodes:
- id: webcam
operator:
python: webcam.py
inputs:
tick: dora/timer/millis/50
outputs:
- image

- id: object_detection
operator:
send_stdout_as: stdout
python: object_detection.py
inputs:
image: webcam/image
outputs:
- bbox
- stdout

- id: plot
operator:
python:
source: plot.py
conda_env: base
inputs:
image: webcam/image
bbox: object_detection/bbox
assistant_message: object_detection/stdout
9 changes: 7 additions & 2 deletions examples/python-operator-dataflow/run.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,13 @@ async fn main() -> eyre::Result<()> {
.await
.context("maturin develop failed")?;

let dataflow = Path::new("dataflow.yml");
run_dataflow(dataflow).await?;
if std::env::var("CONDA_EXE").is_ok() {
let dataflow = Path::new("dataflow.yml");
run_dataflow(dataflow).await?;
} else {
let dataflow = Path::new("dataflow_conda.yml");
run_dataflow(dataflow).await?;
}

Ok(())
}
Expand Down
46 changes: 45 additions & 1 deletion libraries/core/src/descriptor/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -260,9 +260,53 @@ pub struct OperatorConfig {
#[serde(rename_all = "kebab-case")]
pub enum OperatorSource {
SharedLibrary(String),
Python(String),
Python(PythonSource),
Wasm(String),
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(
deny_unknown_fields,
from = "PythonSourceDef",
into = "PythonSourceDef"
)]
pub struct PythonSource {
pub source: String,
pub conda_env: Option<String>,
}

#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum PythonSourceDef {
SourceOnly(String),
WithOptions {
source: String,
conda_env: Option<String>,
},
}

impl From<PythonSource> for PythonSourceDef {
fn from(input: PythonSource) -> Self {
match input {
PythonSource {
source,
conda_env: None,
} => Self::SourceOnly(source),
PythonSource { source, conda_env } => Self::WithOptions { source, conda_env },
}
}
}

impl From<PythonSourceDef> for PythonSource {
fn from(value: PythonSourceDef) -> Self {
match value {
PythonSourceDef::SourceOnly(source) => Self {
source,
conda_env: None,
},
PythonSourceDef::WithOptions { source, conda_env } => Self { source, conda_env },
}
}
}

pub fn source_is_url(source: &str) -> bool {
source.contains("://")
Expand Down
5 changes: 3 additions & 2 deletions libraries/core/src/descriptor/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ pub fn check_dataflow(dataflow: &Descriptor, working_dir: &Path) -> eyre::Result
}
}
}
OperatorSource::Python(path) => {
OperatorSource::Python(python_source) => {
has_python_operator = true;
if source_is_url(path) {
let path = &python_source.source;
if source_is_url(&path) {

Check warning on line 56 in libraries/core/src/descriptor/validate.rs

View workflow job for this annotation

GitHub Actions / Clippy

this expression creates a reference which is immediately dereferenced by the compiler
info!("{path} is a URL."); // TODO: Implement url check.
} else if !working_dir.join(path).exists() {
bail!("no Python library at `{path}`");
Expand Down

0 comments on commit 7126454

Please sign in to comment.