diff --git a/src/solvers/sgd/mod.rs b/src/solvers/sgd/mod.rs index 8b739e3c..64cab199 100644 --- a/src/solvers/sgd/mod.rs +++ b/src/solvers/sgd/mod.rs @@ -31,7 +31,13 @@ macro_rules! impl_isolver_sgd { for weight_gradient in net.learnable_weights_gradients() { let shape = weight_gradient.read().unwrap().desc().clone(); - let history_tensor = Arc::new(RwLock::new(SharedTensor::new(IBackend::device(&*self.backend), &shape).unwrap())); + let mut tensor = SharedTensor::new(IBackend::device(&*self.backend), + &shape).unwrap(); + + let filler = ::weight::FillerType::Constant { value: 0f32 }; + filler.fill(&mut tensor); + + let history_tensor = Arc::new(RwLock::new(tensor)); self.history.push(history_tensor); } }