diff --git a/aurora/model/aurora.py b/aurora/model/aurora.py index 9eba914..efe5541 100644 --- a/aurora/model/aurora.py +++ b/aurora/model/aurora.py @@ -349,10 +349,9 @@ def adapt_checkpoint_max_history_size(self, checkpoint) -> Any: device=weight.device, dtype=weight.dtype) # Copy the existing weights to the new tensor by duplicating the histories provided into any new history dimensions - for j in range(new_weight.shape[2]): - if j < weight.shape[2]: - # only fill existing weights, others are zeros - new_weight[:, :, j, :, :] = weight[:, :, j, :, :] + for j in range(weight.shape[2]): + # only fill existing weights, others are zeros + new_weight[:, :, j, :, :] = weight[:, :, j, :, :] checkpoint[name] = new_weight return checkpoint