Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

simplify scope tracking in burn-import #2207

Merged
merged 2 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions crates/burn-import/src/burn/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,21 +314,17 @@ impl<PS: PrecisionSettings> BurnGraph<PS> {
.flat_map(to_tensor)
.for_each(|tensor| {
self.scope
.tensor_register_variable(&tensor, node_position + 1)
})
});

self.nodes
.iter()
.enumerate()
.for_each(|(node_position, node)| {
.tensor_register_variable(&tensor, node_position + 1);
});
// Since the graph is guaranteed to be a DAG, we can safely register future uses
// of the inputs (which are the previous nodes' outputs)
node.input_types()
.into_iter()
.flat_map(to_tensor)
.for_each(|tensor| {
self.scope
.tensor_register_future_use(&tensor, node_position)
})
});
});
}

Expand Down
39 changes: 13 additions & 26 deletions crates/burn-import/src/burn/scope.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use std::collections::HashMap;
/// The scope struct ensures that ownership rules are respected during the forward pass.
#[derive(Clone, Debug, Default)]
pub struct Scope {
variables: HashMap<Ident, Vec<TensorVariable>>,
variables: HashMap<Ident, TensorVariable>,
}

#[derive(Clone, Debug, new)]
Expand All @@ -19,20 +19,13 @@ struct TensorVariable {
impl Scope {
/// Declare a new tensor variable.
pub fn tensor_register_variable(&mut self, tensor: &TensorType, node_position: usize) {
if let Some(variables) = self.variables.get_mut(&tensor.name) {
for variable in variables.iter_mut() {
if variable.node_position == node_position {
variable.references += 1;
return;
}
if let Some(variable) = self.variables.get_mut(&tensor.name) {
if variable.node_position == node_position {
variable.references += 1;
}

variables.push(TensorVariable::new(0, node_position));
} else {
self.variables.insert(
tensor.name.clone(),
vec![TensorVariable::new(0, node_position)],
);
self.variables
.insert(tensor.name.clone(), TensorVariable::new(0, node_position));
}
}

Expand All @@ -42,12 +35,9 @@ impl Scope {
///
/// We need to know all futures use of a variable in advance.
pub fn tensor_register_future_use(&mut self, tensor: &TensorType, node_position: usize) {
if let Some(variables) = self.variables.get_mut(&tensor.name) {
for variable in variables.iter_mut().rev() {
if node_position >= variable.node_position {
variable.references += 1;
break;
}
if let Some(variable) = self.variables.get_mut(&tensor.name) {
if node_position >= variable.node_position {
variable.references += 1;
}
} else {
panic!("No variable with name {}", tensor.name);
Expand All @@ -56,16 +46,13 @@ impl Scope {

/// Use a tensor variable, cloning it if it was registered multiple times and the tensor will still be used afterward.
pub fn tensor_use_owned(&mut self, tensor: &TensorType, node_position: usize) -> TokenStream {
if let Some(variables) = self.variables.get_mut(&tensor.name) {
if let Some(variable) = self.variables.get_mut(&tensor.name) {
let mut count = 0;
let name = &tensor.name;

for variable in variables.iter_mut().rev() {
if node_position >= variable.node_position {
variable.references -= 1;
count = variable.references;
break;
}
if node_position >= variable.node_position {
variable.references -= 1;
count = variable.references;
}

if count > 0 {
Expand Down
Loading