Apply isort and black reformatting
Signed-off-by: akoumpa <[email protected]>
akoumpa committed Jul 26, 2024
1 parent ce6b090 commit 12ce03c
Showing 2 changed files with 4 additions and 3 deletions.
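
For context, isort normalizes import ordering and black normalizes code layout; this commit is simply their output over the two touched files. Below is a minimal sketch of the same pipeline through the tools' public Python APIs, assuming default settings rather than NeMo's exact configuration:

import black
import isort

src = (
    "import numpy as np\n"
    "from torch import nn\n"
    "import torch\n"
)

# isort moves the plain `import torch` ahead of `from torch import nn`,
# the same reordering applied to nemo/lightning/io/state.py below.
tidy = isort.code(src)

# black then normalizes layout: comment spacing, trailing commas, blank lines.
print(black.format_str(tidy, mode=black.Mode()))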
4 changes: 2 additions & 2 deletions nemo/collections/llm/gpt/model/mixtral.py
@@ -123,7 +123,7 @@ def config(self) -> MixtralConfig8x7B:
             hidden_size=config.hidden_size,
             ffn_hidden_size=config.intermediate_size,
             max_position_embeddings=config.max_position_embeddings,  # TODO
-            seq_length=4096, #config.max_position_embeddings,
+            seq_length=4096,  # config.max_position_embeddings,
             # RoPE
             position_embedding_type='rope',
             rotary_base=config.rope_theta,
@@ -143,7 +143,7 @@ def config(self) -> MixtralConfig8x7B:
             # CPU init
             use_cpu_initialization=True,
             perform_initialization=False,
-            params_dtype=getattr(config, "torch_dtype", torch.bfloat16)
+            params_dtype=getattr(config, "torch_dtype", torch.bfloat16),
         )
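
Both edits above are standard black normalizations: at least one space after `#` in a comment, two spaces before an inline comment, and a trailing comma on the last argument of a call that stays multi-line. A small illustration with a hypothetical snippet, assuming black's default Mode:

import black

src = (
    "config = MixtralConfig8x7B(\n"
    "    seq_length=4096, #config.max_position_embeddings\n"
    '    params_dtype=getattr(config, "torch_dtype", torch.bfloat16)\n'
    ")\n"
)

# black rewrites the inline comment as `  # config.max_position_embeddings`
# and appends a trailing comma after the final argument.
print(black.format_str(src, mode=black.Mode()))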


3 changes: 2 additions & 1 deletion nemo/lightning/io/state.py
@@ -4,8 +4,8 @@
 from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload
 
 import numpy as np
-from torch import nn
 import torch
+from torch import nn
 
 SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module)
 TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module)
@@ -19,6 +19,7 @@ class TransformCTX:
     target: nn.Module
     target_state: dict
 
+
 @torch.no_grad
 def apply_transforms(
     source: nn.Module,
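
The extra blank line added before `@torch.no_grad` is black's two-blank-lines rule for top-level definitions, consistent with PEP 8. A minimal reproduction, again via the public API with a hypothetical input:

import black

src = "target_state = {}\n@torch.no_grad\ndef apply_transforms(source):\n    return source\n"

# black separates the decorated top-level function from the preceding
# statement with two blank lines.
print(black.format_str(src, mode=black.Mode()))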
