diff --git a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc
index 15bac8f7d7c6a..4813382211368 100644
--- a/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc
+++ b/orttraining/orttraining/training_ops/cuda/nn/layer_norm.cc
@@ -89,7 +89,12 @@ Status LayerNormGrad<T, U, simplified>::ComputeInternal(OpKernelContext* p_op_ke
     bias_grad_data = reinterpret_cast<CudaT*>(bias_grad->template MutableData<T>());
   }
 
+#ifndef USE_ROCM
   const int part_size = 16;
+#else
+  // Optimization for ROCm MI100
+  const int part_size = 64;
+#endif
   auto part_grad_gamma = GetScratchBuffer<CudaU>(part_size * n2);
   auto part_grad_beta = GetScratchBuffer<CudaU>(part_size * n2);
 
@@ -138,7 +143,12 @@ Status InvertibleLayerNormGrad<T, U>::ComputeInternal(OpKernelContext* p_op_kern
   auto scale_grad_data = reinterpret_cast<CudaT*>(scale_grad->template MutableData<T>());
   auto bias_grad_data = reinterpret_cast<CudaT*>(bias_grad->template MutableData<T>());
 
+#ifndef USE_ROCM
   const int part_size = 16;
+#else
+  // Optimization for ROCm MI100
+  const int part_size = 64;
+#endif
   auto part_grad_gamma = GetScratchBuffer<CudaU>(part_size * n2);
   auto part_grad_beta = GetScratchBuffer<CudaU>(part_size * n2);