From 085e69e8cee839086872fd8fef5cb9874de442ae Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Mon, 15 Jul 2024 13:49:38 -0700 Subject: [PATCH] Fixes in distributed layers --- python/mlx/nn/layers/distributed.py | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/python/mlx/nn/layers/distributed.py b/python/mlx/nn/layers/distributed.py index c29cd81d08..f1cef52e61 100644 --- a/python/mlx/nn/layers/distributed.py +++ b/python/mlx/nn/layers/distributed.py @@ -1,5 +1,6 @@ # Copyright © 2024 Apple Inc. +import math from functools import lru_cache from typing import Optional @@ -168,7 +169,7 @@ def __call__(self, x: mx.array) -> mx.array: if self.group.size() > 1: # Perform the local projection and aggregate the results x = x @ self["weight"].T - x = mx.distributed.all_sum(x, group=group) + x = mx.distributed.all_sum(x, group=self.group) # Add the bias if we have one if "bias" in self: @@ -316,9 +317,9 @@ def from_quantized_linear( bits=quantized_linear_layer.bits, group=group, ) - sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1 - sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1 - sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1 + sl.weight = quantized_linear_layer.weight[r * step : (r + 1) * step] * 1 + sl.scales = quantized_linear_layer.scales[r * step : (r + 1) * step] * 1 + sl.biases = quantized_linear_layer.biases[r * step : (r + 1) * step] * 1 if "bias" in quantized_linear_layer: sl.bias = quantized_linear_layer.bias[r * step : (r + 1) * step] * 1 @@ -413,7 +414,7 @@ def __call__(self, x: mx.array) -> mx.array: bits=self.bits, ) if self.group.size() > 1: - x = mx.distributed.sum_all(x, group=group) + x = mx.distributed.all_sum(x, group=self.group) if "bias" in self: x = x + self["bias"] return x @@ -428,6 +429,8 @@ def from_quantized_linear( N = group.size() r = group.rank() output_dims, input_dims = quantized_linear_layer.weight.shape + step = input_dims // N + step_grouped = quantized_linear_layer.scales.shape[1] // N input_dims *= (32 // quantized_linear_layer.bits) * N sl = cls( @@ -438,9 +441,15 @@ def from_quantized_linear( bits=quantized_linear_layer.bits, group=group, ) - sl.weight = quantized_linear_layer.weight[r : step : (r + 1) * step] * 1 - sl.scales = quantized_linear_layer.scales[r : step : (r + 1) * step] * 1 - sl.biases = quantized_linear_layer.biases[r : step : (r + 1) * step] * 1 + sl.weight = quantized_linear_layer.weight[:, r * step : (r + 1) * step] * 1 + sl.scales = ( + quantized_linear_layer.scales[:, r * step_grouped : (r + 1) * step_grouped] + * 1 + ) + sl.biases = ( + quantized_linear_layer.biases[:, r * step_grouped : (r + 1) * step_grouped] + * 1 + ) if "bias" in quantized_linear_layer: sl.bias = quantized_linear_layer.bias