Skip to content

Commit

Permalink
Reapply "Refactor to move ChartState to vegafusion core (#519)" (#520)
Browse files Browse the repository at this point in the history
This reverts commit 713b159.
  • Loading branch information
jonmmease committed Oct 17, 2024
1 parent 713b159 commit 94b9d6d
Show file tree
Hide file tree
Showing 32 changed files with 3,266 additions and 997 deletions.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ prost = { version = "0.12.3" }
prost-types = { version = "0.12.3" }
object_store = { version = "0.11.0" }
lazy_static = { version = "1.5" }
async-trait = "0.1.73"

[workspace.dependencies.datafusion]
version = "42.0.0"
Expand Down
2,681 changes: 2,376 additions & 305 deletions pixi.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pixi.toml
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,8 @@ ruff = ">=0.6.9,<0.7"
mypy = ">=1.11.2,<2"
pixi-pycharm = ">=0.0.8,<0.0.9"
scipy = "1.14.1.*"
pandas = ">=2.2.3,<3"
pyarrow = ">=13.0.0,<14"

# Dependencies are those required at runtime by the Python packages
[dependencies]
Expand Down
10 changes: 10 additions & 0 deletions vegafusion-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ workspace = true
[dependencies.prost-types]
workspace = true

[dependencies.chrono-tz]
workspace = true

[dependencies.async-trait]
workspace = true

[dependencies.sqlparser]
workspace = true
optional = true
Expand All @@ -49,6 +55,10 @@ version = "1.6.9"
[dependencies.datafusion-common]
workspace = true

[dependencies.vegafusion-dataframe]
path = "../vegafusion-dataframe"
version = "1.6.9"

[dependencies.pyo3]
workspace = true
optional = true
Expand Down
236 changes: 236 additions & 0 deletions vegafusion-core/src/chart_state.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
use crate::{
data::dataset::VegaFusionDataset,
planning::{
apply_pre_transform::apply_pre_transform_datasets,
plan::SpecPlan,
stitch::CommPlan,
watch::{ExportUpdateArrow, ExportUpdateJSON, ExportUpdateNamespace},
},
proto::gen::{
pretransform::PreTransformSpecWarning,
tasks::{NodeValueIndex, TaskGraph, TzConfig, Variable, VariableNamespace},
},
runtime::VegaFusionRuntimeTrait,
spec::chart::ChartSpec,
task_graph::{graph::ScopedVariable, task_value::TaskValue},
};
use datafusion_common::ScalarValue;
use std::{
collections::{HashMap, HashSet},
sync::{Arc, Mutex},
};
use vegafusion_common::{
data::{scalar::ScalarValueHelpers, table::VegaFusionTable},
error::{Result, ResultWithContext, VegaFusionError},
};

#[derive(Clone)]
pub struct ChartState {
input_spec: ChartSpec,
transformed_spec: ChartSpec,
plan: SpecPlan,
inline_datasets: HashMap<String, VegaFusionDataset>,
task_graph: Arc<Mutex<TaskGraph>>,
task_graph_mapping: Arc<HashMap<ScopedVariable, NodeValueIndex>>,
server_to_client_value_indices: Arc<HashSet<NodeValueIndex>>,
warnings: Vec<PreTransformSpecWarning>,
}

impl ChartState {
pub async fn try_new(
runtime: &dyn VegaFusionRuntimeTrait,
spec: ChartSpec,
inline_datasets: HashMap<String, VegaFusionDataset>,
tz_config: TzConfig,
row_limit: Option<u32>,
) -> Result<Self> {
let dataset_fingerprints = inline_datasets
.iter()
.map(|(k, ds)| (k.clone(), ds.fingerprint()))
.collect::<HashMap<_, _>>();

let plan = SpecPlan::try_new(&spec, &Default::default())?;

let task_scope = plan
.server_spec
.to_task_scope()
.with_context(|| "Failed to create task scope for server spec")?;
let tasks = plan
.server_spec
.to_tasks(&tz_config, &dataset_fingerprints)
.unwrap();
let task_graph = TaskGraph::new(tasks, &task_scope).unwrap();
let task_graph_mapping = task_graph.build_mapping();
let server_to_client_value_indices: Arc<HashSet<_>> = Arc::new(
plan.comm_plan
.server_to_client
.iter()
.map(|scoped_var| task_graph_mapping.get(scoped_var).unwrap().clone())
.collect(),
);

// Gather values of server-to-client values using query_request
let indices: Vec<NodeValueIndex> = plan
.comm_plan
.server_to_client
.iter()
.map(|var| task_graph_mapping.get(var).unwrap().clone())
.collect();

let response_task_values = runtime
.query_request(Arc::new(task_graph.clone()), &indices, &inline_datasets)
.await?;

let mut init = Vec::new();
for response_value in response_task_values {
let variable = response_value
.variable
.with_context(|| "Missing variable for response value".to_string())?;

let scope = response_value.scope;
let proto_value = response_value
.value
.with_context(|| "Missing value for response value".to_string())?;

let value = TaskValue::try_from(&proto_value).with_context(|| {
"Deserialization failed for value of response value".to_string()
})?;

init.push(ExportUpdateArrow {
namespace: ExportUpdateNamespace::try_from(variable.ns()).unwrap(),
name: variable.name.clone(),
scope,
value,
});
}

let (transformed_spec, warnings) =
apply_pre_transform_datasets(&spec, &plan, init, row_limit)?;

Ok(Self {
input_spec: spec,
transformed_spec,
plan,
inline_datasets,
task_graph: Arc::new(Mutex::new(task_graph)),
task_graph_mapping: Arc::new(task_graph_mapping),
server_to_client_value_indices,
warnings,
})
}

pub async fn update(
&self,
runtime: &dyn VegaFusionRuntimeTrait,
updates: Vec<ExportUpdateJSON>,
) -> Result<Vec<ExportUpdateJSON>> {
let mut task_graph = self.task_graph.lock().map_err(|err| {
VegaFusionError::internal(format!("Failed to acquire task graph lock: {:?}", err))
})?;
let server_to_client = self.server_to_client_value_indices.clone();
let mut indices: Vec<NodeValueIndex> = Vec::new();
for export_update in &updates {
let var = match export_update.namespace {
ExportUpdateNamespace::Signal => Variable::new_signal(&export_update.name),
ExportUpdateNamespace::Data => Variable::new_data(&export_update.name),
};
let scoped_var: ScopedVariable = (var, export_update.scope.clone());
let node_value_index = self
.task_graph_mapping
.get(&scoped_var)
.with_context(|| format!("No task graph node found for {scoped_var:?}"))?
.clone();

let value = match export_update.namespace {
ExportUpdateNamespace::Signal => {
TaskValue::Scalar(ScalarValue::from_json(&export_update.value)?)
}
ExportUpdateNamespace::Data => {
TaskValue::Table(VegaFusionTable::from_json(&export_update.value)?)
}
};

indices.extend(task_graph.update_value(node_value_index.node_index as usize, value)?);
}

// Filter to update nodes in the comm plan
let indices: Vec<_> = indices
.iter()
.filter(|&node| server_to_client.contains(node))
.cloned()
.collect();

let cloned_task_graph = task_graph.clone();

// Drop the MutexGuard before await call to avoid warning
drop(task_graph);

let response_task_values = runtime
.query_request(
Arc::new(cloned_task_graph),
indices.as_slice(),
&self.inline_datasets,
)
.await?;

let mut response_updates = response_task_values
.into_iter()
.map(|response_value| {
let variable = response_value
.variable
.with_context(|| "Missing variable for response value".to_string())?;

let scope = response_value.scope;
let proto_value = response_value
.value
.with_context(|| "missing value for response value: {:?}".to_string())?;

let value = TaskValue::try_from(&proto_value).with_context(|| {
"Deserialization failed for value of response value: {:?}".to_string()
})?;

Ok(ExportUpdateJSON {
namespace: match variable.ns() {
VariableNamespace::Signal => ExportUpdateNamespace::Signal,
VariableNamespace::Data => ExportUpdateNamespace::Data,
VariableNamespace::Scale => {
return Err(VegaFusionError::internal("Unexpected scale variable"))
}
},
name: variable.name.clone(),
scope: scope.clone(),
value: value.to_json()?,
})
})
.collect::<Result<Vec<_>>>()?;

// Sort for deterministic ordering
response_updates.sort_by_key(|update| update.name.clone());

Ok(response_updates)
}

pub fn get_input_spec(&self) -> &ChartSpec {
&self.input_spec
}

pub fn get_server_spec(&self) -> &ChartSpec {
&self.plan.server_spec
}

pub fn get_client_spec(&self) -> &ChartSpec {
&self.plan.client_spec
}

pub fn get_transformed_spec(&self) -> &ChartSpec {
&self.transformed_spec
}

pub fn get_comm_plan(&self) -> &CommPlan {
&self.plan.comm_plan
}

pub fn get_warnings(&self) -> &Vec<PreTransformSpecWarning> {
&self.warnings
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use crate::error::Result;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
use std::sync::Arc;
use vegafusion_common::data::table::VegaFusionTable;
use vegafusion_core::error::Result;
use vegafusion_dataframe::dataframe::DataFrame;

#[derive(Clone)]
Expand Down
1 change: 1 addition & 0 deletions vegafusion-core/src/data/mod.rs
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
pub mod dataset;
pub mod tasks;
2 changes: 2 additions & 0 deletions vegafusion-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
extern crate lazy_static;
extern crate core;

pub mod chart_state;
pub mod data;
pub mod expression;
pub mod patch;
pub mod planning;
pub mod proto;
pub mod runtime;
pub mod spec;
pub mod task_graph;
pub mod transform;
Expand Down
Loading

0 comments on commit 94b9d6d

Please sign in to comment.