From 38537c42a0ce53b54198e3746c6941c281e077bb Mon Sep 17 00:00:00 2001 From: "Ma, Guokai" Date: Wed, 26 Jul 2023 12:17:12 -0400 Subject: [PATCH] faster allreduce with omp parallel for reduce kernel --- csrc/cpu/comm/ccl.cpp | 5 +++++ op_builder/cpu/comm.py | 3 +++ 2 files changed, 8 insertions(+) diff --git a/csrc/cpu/comm/ccl.cpp b/csrc/cpu/comm/ccl.cpp index 69859cc0c0d2..4a06b3239419 100644 --- a/csrc/cpu/comm/ccl.cpp +++ b/csrc/cpu/comm/ccl.cpp @@ -8,6 +8,7 @@ #include #include #include +#include #include #include #include @@ -188,6 +189,7 @@ void reduce_all_buffers(struct allreduce_workspace* workspace, // num_elements must be divisible by 16 (caller check) void reduce_bf16_buffers(int num_elements, int num_buffers, struct allreduce_workspace* workspace) { +#pragma omp parallel for for (int i = 0; i < num_elements * 2; i += VECTOR_LENGTH_IN_BYTES) { auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(workspace[0].buffer + i))); switch (num_buffers) { @@ -205,6 +207,7 @@ void reduce_bf16_buffers(int num_elements, int num_buffers, struct allreduce_wor void reduce_2_bf16_buffers(int num_elements, void* in_out, void* in1) { +#pragma omp parallel for for (int i = 0; i < num_elements * 2; i += VECTOR_LENGTH_IN_BYTES) { auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in_out + i))); auto in1_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)((char*)in1 + i))); @@ -222,6 +225,7 @@ void reduce_2_bf16_buffers(int num_elements, void* in_out, void* in1) // num_elements must be divisible by 16 (caller check) void reduce_fp32_buffers(int num_elements, int num_buffers, struct allreduce_workspace* workspace) { +#pragma omp parallel for for (int i = 0; i < num_elements * 4; i += VECTOR_LENGTH_IN_BYTES) { auto inout_val = _mm256_loadu_ps((float*)(workspace[0].buffer + i)); switch (num_buffers) { @@ -239,6 +243,7 @@ void reduce_fp32_buffers(int num_elements, int num_buffers, struct allreduce_wor void reduce_2_fp32_buffers(int num_elements, void* in_out, void* in1) { +#pragma omp parallel for for (int i = 0; i < num_elements * 4; i += VECTOR_LENGTH_IN_BYTES) { auto inout_val = _mm256_loadu_ps((float*)((char*)in_out + i)); auto in1_val = _mm256_loadu_ps((float*)((char*)in1 + i)); diff --git a/op_builder/cpu/comm.py b/op_builder/cpu/comm.py index c076ee48376d..ec908eb0622b 100644 --- a/op_builder/cpu/comm.py +++ b/op_builder/cpu/comm.py @@ -25,6 +25,9 @@ def include_paths(self): includes = ['csrc/cpu/includes'] return includes + def cxx_args(self): + return ['-O2', '-fopenmp'] + def is_compatible(self, verbose=True): # TODO: add soft compatibility check for private binary release. # a soft check, as in we know it can be trivially changed.