Commit

dtype fix in mamba
Signed-off-by: arendu <[email protected]>
arendu committed Nov 13, 2024
1 parent 7c78ef4 commit 6c2ce66
Showing 1 changed file with 2 additions and 5 deletions.
@@ -15,6 +15,7 @@
 import torch
 import torch.nn.functional as F
 from omegaconf.dictconfig import DictConfig
+from omegaconf import open_dict
 from pytorch_lightning.trainer.trainer import Trainer
 
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
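
The first hunk adds an import of open_dict from omegaconf alongside the existing DictConfig import; its use is not visible in this diff. For context, a minimal sketch of the omegaconf open_dict API, using a made-up config rather than NeMo's actual model config: open_dict temporarily lifts struct mode so new keys can be added to a DictConfig.

from omegaconf import OmegaConf, open_dict

# Hypothetical config for illustration only; not NeMo's model config.
cfg = OmegaConf.create({"params_dtype": "bf16"})
OmegaConf.set_struct(cfg, True)  # struct mode: setting unknown keys raises an error

with open_dict(cfg):
    # Inside open_dict, struct mode is suspended, so a new key can be added.
    cfg.kv_channels = 128

print(cfg.kv_channels)  # 128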
@@ -67,11 +68,7 @@ def model_provider_func(self, pre_process, post_process):
         self.transformer_config.add_bias_linear = self.cfg.get('add_bias_linear', False)
         self.transformer_config.gated_linear_unit = self.cfg.get('gated_linear_unit', False)
         self.transformer_config.layernorm_epsilon = self.cfg.get('layernorm_epsilon', 1e-5)
-        if self.cfg.get('params_dtype'):
-            self.transformer_config.params_dtype = torch.bfloat16
-        else:
-            self.transformer_config.params_dtype = torch.float32
-        self.transformer_config.params_dtype = torch.bfloat16
+        self.transformer_config.params_dtype = torch.bfloat16 if self.cfg.params_dtype == "bf16" else torch.float32
         if self.cfg.get('kv_channels'):
             self.transformer_config.kv_channels = self.cfg.get('kv_channels')
         if self.cfg.get('squared_relu_activation'):
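
The second hunk is the actual fix: the old block chose params_dtype via an if/else and then unconditionally overwrote it with torch.bfloat16, so float32 was never selected. The replacement keys the dtype off the params_dtype string in the config. A standalone sketch of that selection logic, using a hypothetical helper name (resolve_params_dtype) and the "bf16" convention shown in the diff:

import torch

def resolve_params_dtype(params_dtype: str) -> torch.dtype:
    # Mirrors the one-line conditional in the diff: "bf16" selects bfloat16,
    # any other value falls back to float32.
    return torch.bfloat16 if params_dtype == "bf16" else torch.float32

assert resolve_params_dtype("bf16") == torch.bfloat16
assert resolve_params_dtype("fp32") == torch.float32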
