[Colossalai-Ascend] Support llama2-7b, chatglm2-6b finetune and inference on NPU #6118

Status: Open · wants to merge 4 commits into base branch support-npu
@@ -19,6 +19,6 @@ def __len__(self):
def __getitem__(self, idx):
return {
"input_ids": self.input_ids[idx],
"attention_mask": self.attention_mask[idx],
# "attention_mask": self.attention_mask[idx],
"labels": self.input_ids[idx],
}
519 changes: 519 additions & 0 deletions applications/Colossal-LLaMA/train_chatglm.py

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions colossalai/booster/booster.py
@@ -20,7 +20,6 @@
import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.quantization import BnbQuantizationConfig

from .accelerator import Accelerator
from .mixed_precision import MixedPrecision, mixed_precision_factory
@@ -242,7 +241,7 @@ def enable_lora(
model: nn.Module,
pretrained_dir: Optional[str] = None,
lora_config: "peft.LoraConfig" = None,
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
bnb_quantization_config=None,
quantize=False,
) -> nn.Module:
"""
@@ -279,6 +278,8 @@ def enable_lora(
ranks=[0],
)
else:
from colossalai.quantization import BnbQuantizationConfig

bnb_quantization_config = BnbQuantizationConfig(
load_in_4bit=True,
bnb_4bit_compute_dtype=torch.bfloat16,
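The hunk above defers the `BnbQuantizationConfig` import into the quantization branch, so environments without bitsandbytes (such as Ascend NPU installs) can still import the booster. A minimal sketch of the same pattern, using a hypothetical helper `_default_bnb_config` and only the kwargs that appear in this diff:

```python
import torch


def _default_bnb_config(quantize: bool):
    """Hypothetical helper mirroring enable_lora: only touch bitsandbytes when asked to."""
    if not quantize:
        return None
    # Deferred import: it can only fail when quantization is actually requested.
    from colossalai.quantization import BnbQuantizationConfig

    return BnbQuantizationConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
```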
35 changes: 27 additions & 8 deletions colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -27,7 +27,7 @@
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
@@ -1256,13 +1256,32 @@ def configure(
# Replace with distributed implementation if exists
optimizer = cast_to_distributed(optimizer)

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_config["partition_grad"] = False
zero_stage = 0
try:
from colossalai.nn.optimizer import DistGaloreAwamW

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_config["partition_grad"] = False
zero_stage = 0
except ImportError:
if zero_stage > 0 and self.dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_config["partition_grad"] = False
zero_stage = 0

# if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0:
# self.logger.warning(
# "Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
# ranks=[0],
# )
# zero_config["partition_grad"] = False
# zero_stage = 0

if not isinstance(model, ModelWrapper):
# Shouldn't use pp (frequent grad accumulation) with torch ddp
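A condensed sketch of the guard introduced above, assuming `optimizer`, `zero_stage`, `zero_config` and `self.dp_size` are in scope as they are inside `configure`; in the `ImportError` fallback the same ZeRO downgrade is applied without the `isinstance` check:

```python
try:
    from colossalai.nn.optimizer import DistGaloreAwamW

    galore_conflict = isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and self.dp_size > 0
except ImportError:
    # galore_torch / bitsandbytes are unavailable (e.g. on NPU)
    galore_conflict = zero_stage > 0 and self.dp_size > 0

if galore_conflict:
    zero_config["partition_grad"] = False
    zero_stage = 0
```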
28 changes: 19 additions & 9 deletions colossalai/booster/plugin/low_level_zero_plugin.py
@@ -32,7 +32,7 @@
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.nn.optimizer import cast_to_distributed
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
@@ -513,14 +513,24 @@ def configure(

# Replace with the distributed implementation if exists
optimizer = cast_to_distributed(optimizer)

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_optim_kwargs["partition_grad"] = False
zero_stage = 0
try:
from colossalai.nn.optimizer import DistGaloreAwamW

if isinstance(optimizer, DistGaloreAwamW) and zero_stage > 0 and dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_optim_kwargs["partition_grad"] = False
zero_stage = 0
except ImportError:
if zero_stage > 0 and dp_size > 0:
self.logger.warning(
"Galore is only supported for Tensor Parallel and vanilla Data Parallel yet. Disabling ZeRO.",
ranks=[0],
)
zero_optim_kwargs["partition_grad"] = False
zero_stage = 0

if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
optimizer: LowLevelZeroOptimizer = LowLevelZeroOptimizer(
52 changes: 34 additions & 18 deletions colossalai/nn/optimizer/__init__.py
@@ -1,17 +1,17 @@
from galore_torch import GaLoreAdafactor, GaLoreAdamW

# try:
# from galore_torch import GaLoreAdafactor, GaLoreAdamW
# except TypeError:
# pass
from colossalai.logging import get_dist_logger

from .came import CAME
from .cpu_adam import CPUAdam
from .distributed_adafactor import DistributedAdaFactor
from .distributed_came import DistributedCAME
from .distributed_galore import DistGaloreAwamW
from .distributed_lamb import DistributedLamb
from .fused_adam import FusedAdam
from .fused_lamb import FusedLAMB
from .fused_sgd import FusedSGD
from .galore import GaLoreAdamW8bit
from .hybrid_adam import HybridAdam
from .lamb import Lamb
from .lars import Lars
@@ -27,31 +27,47 @@
"CPUAdam",
"HybridAdam",
"DistributedLamb",
"DistGaloreAwamW",
"GaLoreAdamW",
"GaLoreAdafactor",
"GaLoreAdamW8bit",
# "DistGaloreAwamW",
# "GaLoreAdamW",
# "GaLoreAdafactor",
# "GaLoreAdamW8bit",
"CAME",
"DistributedCAME",
"Adafactor",
"DistributedAdaFactor",
]

optim2DistOptim = {
GaLoreAdamW8bit: DistGaloreAwamW,
Lamb: DistributedLamb,
CAME: DistributedCAME,
Adafactor: DistributedAdaFactor,
}

try:
from galore_torch import GaLoreAdamW

from .distributed_galore import DistGaloreAwamW
from .galore import GaLoreAdamW8bit

optim2DistOptim[GaLoreAdamW8bit] = DistGaloreAwamW
__all__.append("DistGaloreAwamW")

def cast_to_distributed(optim):
if optim.__class__ in optim2DistOptim:
_logger = get_dist_logger()
_logger.info(f"Converting optimizer {optim.__class__.__name__} to its distributed version.", ranks=[0])

if isinstance(optim, GaLoreAdamW8bit):
return optim2DistOptim[GaLoreAdamW8bit](optim.param_groups, args=optim.args)
return optim2DistOptim[optim.__class__](optim.param_groups)

return optim

def cast_to_distributed(optim):
    if optim.__class__ in optim2DistOptim:
        _logger = get_dist_logger()
        _logger.info(f"Converting optimizer {optim.__class__.__name__} to its distributed version.", ranks=[0])

        if isinstance(optim, GaLoreAdamW8bit):
            return optim2DistOptim[GaLoreAdamW8bit](optim.param_groups, args=optim.args)
        return optim2DistOptim[optim.__class__](optim.param_groups)

    return optim

except:

    def cast_to_distributed(optim):
        if optim.__class__ in optim2DistOptim:
            _logger = get_dist_logger()
            _logger.info(f"Converting optimizer {optim.__class__.__name__} to its distributed version.", ranks=[0])
            return optim2DistOptim[optim.__class__](optim.param_groups)

        return optim
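A usage sketch of the exported helper (assuming the distributed environment has already been set up via `colossalai.launch`, as the booster plugins do before calling it); optimizers without a registered distributed counterpart pass through unchanged, whether or not `galore_torch` is importable:

```python
import torch
from colossalai.nn.optimizer import Lamb, cast_to_distributed

params = [torch.nn.Parameter(torch.randn(4, 4))]
dist_optim = cast_to_distributed(Lamb(params, lr=1e-3))               # -> DistributedLamb
plain_optim = cast_to_distributed(torch.optim.SGD(params, lr=1e-3))   # not in optim2DistOptim, returned as-is
```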
10 changes: 8 additions & 2 deletions colossalai/nn/optimizer/hybrid_adam.py
@@ -157,11 +157,17 @@ def step(self, closure=None, div_scale: float = -1):
div_scale,
)
self._post_update(p, "exp_avg", "exp_avg_sq")

elif target_device.type == "npu":
assert state["exp_avg"].device.type == "npu", "exp_avg should stay on npu"
assert state["exp_avg_sq"].device.type == "npu", "exp_avg should stay on npu"
# record the state by group and update at once
g_l.append(p.grad.data)
p_l.append(p.data)
m_l.append(state["exp_avg"])
v_l.append(state["exp_avg_sq"])
elif target_device.type == "cuda":
assert state["exp_avg"].device.type == "cuda", "exp_avg should stay on cuda"
assert state["exp_avg_sq"].device.type == "cuda", "exp_avg should stay on cuda"

# record the state by group and update at once
g_l.append(p.grad.data)
p_l.append(p.data)
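A hedged end-to-end sketch of the new NPU branch, assuming `torch_npu` is installed so the `"npu"` device type exists and this PR's fused kernels are available:

```python
import torch
import torch_npu  # noqa: F401  (assumption: registers the "npu" device)
from colossalai.nn.optimizer import HybridAdam

model = torch.nn.Linear(16, 16).to("npu")
optimizer = HybridAdam(model.parameters(), lr=1e-3)

loss = model(torch.randn(4, 16, device="npu")).sum()
loss.backward()
optimizer.step()  # exp_avg / exp_avg_sq stay on the NPU, matching the asserts above
```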
16 changes: 15 additions & 1 deletion colossalai/pipeline/p2p.py
@@ -157,12 +157,26 @@ def _check_for_nccl_backend(group):
return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL


def _check_for_hccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg
return torch.distributed.is_hccl_available() and pg.name() == c10d.Backend.HCCL


def _check_device(group):
is_nccl_backend = _check_for_nccl_backend(group)
is_hccl_backend = _check_for_hccl_backend(group)
current_device = torch.device("cpu")

if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
return current_device, is_nccl_backend
elif is_hccl_backend:
current_device = torch.device("npu", torch.cuda.current_device())
return current_device, is_nccl_backend or is_hccl_backend


TensorMetadata = namedtuple("TensorMetadata", ["shape", "dtype", "requires_grad"])
2 changes: 1 addition & 1 deletion colossalai/quantization/bnb.py
@@ -20,7 +20,7 @@

IS_4BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.39.0")
IS_8BIT_BNB_AVAILABLE = BNB_VERSION >= Version("0.37.2")
except ImportError:
except (ImportError, TypeError):
pass


2 changes: 2 additions & 0 deletions colossalai/shardformer/layer/attn.py
@@ -313,6 +313,8 @@ def attention(
AttnMaskType.CAUSAL,
AttnMaskType.PADDED_CAUSAL,
)
if scale is None:
scale = 1.0 / ((q.size(-1)) ** 0.5)
return attn_func(
q,
k,
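The default filled in above is the standard softmax scaling `1 / sqrt(head_dim)`; for example:

```python
head_dim = 128                   # e.g. llama2-7b: hidden size 4096 / 32 heads
scale = 1.0 / (head_dim ** 0.5)  # ≈ 0.0884
```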
8 changes: 4 additions & 4 deletions colossalai/shardformer/modeling/chatglm2.py
@@ -50,7 +50,7 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
attention_mask_type = AttnMaskType.CUSTOM
if attention_mask is not None:
attn_bias = torch.zeros_like(attention_mask, dtype=query_layer.dtype)
attn_bias.masked_fill_(attention_mask, torch.finfo(query_layer.dtype).min)
attn_bias.masked_fill_(attention_mask.bool(), torch.finfo(query_layer.dtype).min)
dropout_p = self.attention_dropout.p if self.training else 0.0
context_layer = ColoAttention.attention(
query_layer,
@@ -180,9 +180,9 @@ def chatglm_model_forward(
],
dim=-1,
)
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
# if full_attention_mask is None:
# if (attention_mask is not None and not attention_mask.cpu().all()) or (past_key_values and seq_length != 1):
# full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

# Support SP + PP
sp_size = shard_config.sequence_parallel_size
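The `.bool()` cast above matters because `masked_fill_` expects a boolean mask; a standalone illustration with an integer padding mask (values are made up):

```python
import torch

attn_bias = torch.zeros(1, 4)
mask = torch.tensor([[0, 0, 1, 1]])                 # 1 marks positions to mask out
attn_bias.masked_fill_(mask.bool(), float("-inf"))  # passing the integer mask directly is rejected
```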
2 changes: 1 addition & 1 deletion examples/language/llama/benchmark.py
@@ -283,7 +283,7 @@ def empty_init():
config,
trust_remote_code=True,
**init_kwargs,
attn_implementation="flash_attention_2",
# attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16,
)
if args.grad_checkpoint:
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -8,7 +8,7 @@ click
fabric
contexttimer
ninja
torch>=2.2.0,<=2.4.0
torch==2.1.0
safetensors
einops
pydantic