Skip to content

Commit

Permalink
Pass correct node size for ZeRO++ (#4085)
Browse files: browse the repository at this point in the history
* Pass correct node size

* formatting

---------

Co-authored-by: Connor Holmes <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
  • Loading branch information
3 people authored Aug 9, 2023
1 parent 977254c commit f0463b4
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 8 deletions.
14 changes: 7 additions & 7 deletions csrc/quantization/pt_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,8 @@ std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
int in_groups,
int out_groups,
int num_bits,
quantize::Type quant_type)
quantize::Type quant_type,
int devices_per_node)
{
auto scales_options = at::TensorOptions()
.dtype(at::kFloat)
Expand All @@ -201,25 +202,24 @@ std::vector<at::Tensor> quantized_reduction(at::Tensor& input_vals,
.requires_grad(false);

std::vector<long int> sz(input_vals.sizes().begin(), input_vals.sizes().end());
const int gpu_per_node = 16; // depend on machine in_groups/out_groups;
sz[sz.size() - 1] = sz.back() / gpu_per_node; // num of GPU per nodes
const int elems_per_in_tensor = at::numel(input_vals) / gpu_per_node;
sz[sz.size() - 1] = sz.back() / devices_per_node; // num of GPU per nodes
const int elems_per_in_tensor = at::numel(input_vals) / devices_per_node;
auto output = torch::empty(sz, output_options);

const int elems_per_in_group = elems_per_in_tensor / (in_groups / gpu_per_node);
const int elems_per_in_group = elems_per_in_tensor / (in_groups / devices_per_node);
const int elems_per_out_group = elems_per_in_tensor / out_groups;

launch_dequant_reduce((int8_t*)output.data_ptr(),
(float*)scales.data_ptr(),
(const int8_t*)input_vals.data_ptr(),
(const float*)input_scales.data_ptr(),
gpu_per_node,
devices_per_node,
num_bits,
quant_type,
out_groups,
elems_per_out_group,
elems_per_in_tensor,
in_groups / gpu_per_node,
in_groups / devices_per_node,
elems_per_in_group,
at::cuda::getCurrentCUDAStream());
return {output, scales};
Expand Down
3 changes: 2 additions & 1 deletion deepspeed/runtime/comm/coalesced_collectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ def all_to_all_quant_reduce(tensors: List[Tensor], groups: {}) -> List[Tensor]:
all_to_all_single(local_output, intra_quant_int4, group=groups[f'local_{intra_idx}'])
all_to_all_single(scale_output, intra_q_scales, group=groups[f'local_{intra_idx}'])
global_input_tensor, global_scales = quantizer_module.quantized_reduction(
local_output, scale_output, intra_quant_group, inter_quant_group, 4, quantizer_module.Symmetric)
local_output, scale_output, intra_quant_group, inter_quant_group, 4, quantizer_module.Symmetric,
local_world_size)
global_output = torch.empty_like(global_input_tensor)
global_scale_output = torch.empty_like(global_scales)
all_to_all_single(global_output, global_input_tensor, group=groups[f'global_{inter_idx}'])
Expand Down

0 comments on commit f0463b4

Please sign in to comment.