Skip to content

Commit

Permalink
fix/reshape: fix reshaping of network input blobs
Browse files Browse the repository at this point in the history
  • Loading branch information
hobofan committed Feb 22, 2016
1 parent 8ea9eac commit 20d97e9
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,14 @@ impl<B: IBackend + LayerOps<f32> + 'static> Network<B> {
for layer in &mut self.layers {
for (blob_index, blob_name) in layer.input_blob_names().to_owned().iter().enumerate() {
if blob_name == &self.input_blob_names[i] {
let reshaped_shape = layer.input_blobs_data[blob_index].read().unwrap().desc().clone();
layer.input_blobs_data[blob_index] = inp.clone();
// reshape input tensor to the reshaped shape
let old_shape = layer.input_blobs_data[blob_index].read().unwrap().desc().clone();
if old_shape.size() != reshaped_shape.size() {
panic!("The provided input does not have the expected shape");
}
layer.input_blobs_data[blob_index].write().unwrap().reshape(&reshaped_shape).unwrap();
}
}
}
Expand Down

0 comments on commit 20d97e9

Please sign in to comment.