Merge pull request #485 from robertknight/graph-invalid-input-output-ids
Return planning error if input or output ID is an operator
robertknight authored Dec 25, 2024
2 parents fffb990 + 63ddb0b commit a9d26ee
Showing 1 changed file with 72 additions and 10 deletions.
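
For context, here is a minimal sketch of how the new checks surface to callers, written as a hypothetical extra test case alongside the ones added in the diff below. The helpers it uses (add_value, add_simple_op, AddOne, NodeId::from_u32 and the Graph::run signature) are taken from that diff; the test name and its placement in graph.rs's mod tests are assumptions, not part of this commit.

#[test]
fn planning_rejects_operator_node_ids() {
    // Hypothetical sketch, not part of this commit: a graph with one value
    // input feeding a single AddOne operator.
    let mut g = Graph::new();
    let input_id = g.add_value(None, None, None);
    let (op_id, op_out) = g.add_simple_op("op", AddOne {}, &[input_id]);
    let input = Tensor::from([1.]);

    // Asking for the operator's own node ID as an output now fails during
    // planning with a descriptive PlanningError.
    let result = g.run([(input_id, input.view().into())].into(), &[op_id], None, None);
    assert!(matches!(result, Err(RunError::PlanningError(_))));

    // Supplying the operator's node ID as an input is rejected the same way;
    // only value nodes are accepted there.
    let result = g.run([(op_id, input.view().into())].into(), &[op_out], None, None);
    assert!(matches!(result, Err(RunError::PlanningError(_))));
}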
src/graph.rs: 82 changes (72 additions & 10 deletions)
@@ -1706,10 +1706,32 @@ impl Graph {
         if !all_unique(outputs, |x, y| x == y) {
             return Err(RunError::PlanningError("output IDs are not unique".into()));
         }
+        for (output_index, output_id) in outputs.iter().enumerate() {
+            match self.get_node(*output_id) {
+                Some(Node::Value(_) | Node::Constant(_)) => {}
+                _ => {
+                    return Err(RunError::PlanningError(format!(
+                        "output ID at index {} is not a value node in the graph",
+                        output_index
+                    )));
+                }
+            }
+        }
 
         if !all_unique(inputs, |x_id, y_id| x_id == y_id) {
             return Err(RunError::PlanningError("input IDs are not unique".into()));
         }
+        for (input_index, input_id) in inputs.iter().enumerate() {
+            match self.get_node(*input_id) {
+                Some(Node::Value(_)) => {}
+                _ => {
+                    return Err(RunError::PlanningError(format!(
+                        "input ID at index {} is not a value node in the graph",
+                        input_index
+                    )));
+                }
+            }
+        }
 
         // Build an execution plan via a depth first traversal of the graph
         // starting at the output nodes. A helper struct is used as recursive
@@ -2377,6 +2399,56 @@ mod tests {
         );
     }
 
+    #[test]
+    fn test_invalid_input_id() {
+        let mut g = Graph::new();
+
+        let const_id = g.add_constant(None, Tensor::<f32>::zeros(&[5, 5]));
+        let (op_id, op_out) = g.add_simple_op("op", AddOne {}, &[]);
+        let input = Tensor::from([1.]);
+        let invalid_id = NodeId::from_u32(1234);
+
+        for wrong_input_id in [const_id, op_id, invalid_id] {
+            let result = g.run(
+                [(wrong_input_id, input.view().into())].into(),
+                &[op_out],
+                None,
+                None,
+            );
+            assert_eq!(
+                result,
+                Err(RunError::PlanningError(
+                    "input ID at index 0 is not a value node in the graph".into(),
+                ))
+            );
+        }
+    }
+
+    #[test]
+    fn test_invalid_output_id() {
+        let mut g = Graph::new();
+
+        let input_id = g.add_value(None, None, None);
+        let (op_id, _op_out) = g.add_simple_op("op", AddOne {}, &[input_id]);
+        let input = Tensor::from([1.]);
+        let invalid_id = NodeId::from_u32(1234);
+
+        for wrong_output_id in [op_id, invalid_id] {
+            let result = g.run(
+                [(input_id, input.view().into())].into(),
+                &[wrong_output_id],
+                None,
+                None,
+            );
+            assert_eq!(
+                result,
+                Err(RunError::PlanningError(
+                    "output ID at index 0 is not a value node in the graph".into(),
+                ))
+            );
+        }
+    }
+
     #[test]
     fn test_call_op_with_missing_input() {
         let mut g = Graph::new();
@@ -2398,16 +2470,6 @@
         );
     }
 
-    #[test]
-    fn test_err_if_invalid_output() {
-        let g = Graph::new();
-        let result = g.run(vec![], &[NodeId::from_u32(123)], None, None);
-        assert_eq!(
-            result.err(),
-            Some(RunError::PlanningError("Missing output 123".to_string()))
-        );
-    }
-
     #[test]
     fn test_err_if_missing_operator_input() {
         let mut g = Graph::new();
