From 277e09fb4054c6f748d5cb2dfae7d07f2aeb811e Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Mon, 1 Jul 2024 12:08:07 -0700
Subject: [PATCH] Registry docs update (#1323)

---
 README.md                     |   6 +-
 llmfoundry/layers_registry.py | 152 ++++++++++++++++++++++++-----
 llmfoundry/registry.py        | 179 ++++++++++++++++++++++++++--------
 3 files changed, 272 insertions(+), 65 deletions(-)

diff --git a/README.md b/README.md
index 16c765c7e4..7d39b7a829 100644
--- a/README.md
+++ b/README.md
@@ -282,6 +282,8 @@ We provide two commands currently:

 Use `--help` on any of these commands for more information.

+These commands can also help you understand what each registry is composed of, as each registry contains a docstring that will be printed out. The general concept is that each registry defines an interface, and components registered to that registry must implement that interface. If there is a part of the library that is not currently extendable, but you think it should be, please open an issue!
+
 ## How to register

 There are a few ways to register a new component:
@@ -289,8 +291,9 @@ There are a few ways to register a new component:
 ### Python entrypoints

 You can specify registered components via a Python entrypoint if you are building your own package with registered components.
+This would be the expected usage if you are building a large extension to LLM Foundry and are going to override many components. Note that components registered via entrypoints will override components registered directly in code.

-For example, the following would register the `WandBLogger` class, under the key `wandb`, in the `llm_foundry.loggers` registry:
+For example, the following would register the `MyLogger` class, under the key `my_logger`, in the `llm_foundry.loggers` registry:

 ```yaml
@@ -359,6 +362,7 @@ code_paths:
   ...
 ```
+One of these would be the expected usage if you are building a small extension to LLM Foundry, only overriding a few components, and thus do not want to create an entire package.

 # Learn more about LLM Foundry!

diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py
index e618d03dc8..50a4906ec1 100644
--- a/llmfoundry/layers_registry.py
+++ b/llmfoundry/layers_registry.py
@@ -7,32 +7,65 @@

 from llmfoundry.utils.registry_utils import create_registry

-_norm_description = (
-    'The norms registry is used to register classes that implement normalization layers.'
+_norms_description = (
+    """The norms registry is used to register classes that implement normalization layers.
+
+    One example of this is torch.nn.LayerNorm. See norm.py for examples.
+
+    Args:
+        normalized_shape (Union[int, List[int], torch.Size]): The shape of the input tensor.
+        device (Optional[torch.device]): The device to use for the normalization layer.
+
+    Returns:
+        torch.nn.Module: The normalization layer.
+    """
 )
 norms = create_registry(
     'llmfoundry',
     'norms',
     generic_type=Type[torch.nn.Module],
     entry_points=True,
-    description=_norm_description,
+    description=_norms_description,
 )
-_fc_description = (
-    'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).'
-    +
-    'These classes should take in_features and out_features in as args, at a minimum.'
+
+_fcs_description = (
+    """The fcs registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear).
+
+    See fc.py for examples.
+
+    Args:
+        in_features (int): The number of input features.
+        out_features (int): The number of output features.
+        kwargs (Dict[str, Any]): Additional keyword arguments to pass to the layer.
+
+    Returns:
+        torch.nn.Module: The fully connected layer.
+    """
 )
 fcs = create_registry(
     'llmfoundry',
     'fcs',
     generic_type=Type[torch.nn.Module],
     entry_points=True,
-    description=_fc_description,
+    description=_fcs_description,
 )

 _ffns_description = (
-    'The ffns registry is used to register functions that build ffn layers.' +
-    'See ffn.py for examples.'
+    """The ffns registry is used to register functions that build FFN layers.
+
+    These layers are generally composed of fc layers and activation functions.
+    One example is MPTMLP. See ffn.py for examples.
+
+    Args:
+        d_model (int): The size of the input and output tensors.
+        expansion_ratio (float): The expansion ratio for the hidden layer.
+        device (Optional[str]): The device to use for the layer.
+        bias (bool): Whether or not to include a bias term.
+        kwargs (Dict[str, Any]): Additional keyword arguments to pass to the layer.
+
+    Returns:
+        torch.nn.Module: The FFN layer.
+    """
 )
 ffns = create_registry(
     'llmfoundry',
@@ -43,8 +76,21 @@
 )

 _ffns_with_norm_description = (
-    'The ffns_with_norm registry is used to register functions that build ffn layers that apply a normalization layer.'
-    + 'See ffn.py for examples.'
+    """The ffns_with_norm registry is used to register functions that build FFN layers with normalization.
+
+    The resulting layer will have ._has_norm set on it.
+    One example is te.LayerNormMLP. See ffn.py for examples.
+
+    Args:
+        d_model (int): The size of the input and output tensors.
+        expansion_ratio (float): The expansion ratio for the hidden layer.
+        device (Optional[str]): The device to use for the layer.
+        bias (bool): Whether or not to include a bias term.
+        kwargs (Dict[str, Any]): Additional keyword arguments to pass to the layer.
+
+    Returns:
+        torch.nn.Module: The FFN layer.
+    """
 )
 ffns_with_norm = create_registry(
     'llmfoundry',
@@ -58,6 +104,16 @@
     'The ffns_with_megablocks registry is used to register functions that build ffn layers using MegaBlocks.'
     + 'See ffn.py for examples.'
 )
+_ffns_with_megablocks_description = (
+    """The ffns_with_megablocks registry is used to register functions that build FFN layers using MegaBlocks.
+
+    The resulting layer will have ._uses_megablocks set on it.
+    One example is megablocks.layers.dmoe.dMoE. See ffn.py for examples.
+
+    Returns:
+        torch.nn.Module: The FFN layer.
+    """
+)
 ffns_with_megablocks = create_registry(
     'llmfoundry',
     'ffns_with_megablocks',
@@ -67,8 +123,17 @@
 )

 _attention_classes_description = (
-    'The attention_classes registry is used to register classes that implement attention layers. See '
-    + 'attention.py for expected constructor signature.'
+    """The attention_classes registry is used to register classes that implement attention layers.
+
+    The kwargs are passed directly to the constructor of the class.
+    One example is GroupedQueryAttention. See attention.py for examples.
+
+    Args:
+        kwargs (Dict[str, Any]): Additional keyword arguments to pass to the layer.
+
+    Returns:
+        torch.nn.Module: The attention layer.
+    """
 )
 attention_classes = create_registry(
     'llmfoundry',
@@ -79,8 +144,29 @@
 )

 _attention_implementations_description = (
-    'The attention_implementations registry is used to register functions that implement the attention operation.'
-    + 'See attention.py for expected function signature.'
+    """The attention_implementations registry is used to register functions that implement the attention operation.
+
+    One example is 'flash'. See attention.py for examples.
+
+    Args:
+        query (torch.Tensor): The query tensor.
+        key (torch.Tensor): The key tensor.
+        value (torch.Tensor): The value tensor.
+        n_heads (int): The number of attention heads.
+        kv_n_heads (int): The number of attention heads for the key and value tensors.
+        past_key_value (Optional[tuple[torch.Tensor, torch.Tensor]]): The past key and value tensors.
+        softmax_scale (Optional[float]): The scale to apply before the softmax. Defaults to None.
+        attn_bias (Optional[torch.Tensor]): The bias to add to the attention scores. Defaults to None.
+        is_causal (bool): Whether to apply a causal mask. Defaults to False.
+        dropout_p (float): The dropout probability. Defaults to 0.0.
+        training (bool): Whether the model is in training mode. Defaults to True.
+        needs_weights (bool): Whether to return the attention weights. Defaults to False.
+        kwargs (Dict[str, Any]): Additional keyword arguments the implementation accepts.
+
+    Returns:
+        tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor, torch.Tensor]]]:
+            The output tensor, the attention weights, and the past key and value tensors.
+    """
 )
 attention_implementations = create_registry(
     'llmfoundry',
@@ -91,9 +177,17 @@
 )

 _param_init_fns_description = (
-    'The param_init_fns registry is used to register functions that initialize parameters.'
-    +
-    'These will be called on a module to initialize its parameters. See param_init_fns.py for examples.'
+    """The param_init_fns registry is used to register functions that initialize parameters.
+
+    These functions should take in a torch.nn.Module, additional kwargs, and initialize the parameters of the module.
+    Generally they can call generic_param_init_fn_ with an appropriate partial function. See param_init_fns.py for examples.
+
+    Note: These functions should take in arbitrary kwargs, and discard any they don't need.
+
+    Args:
+        module (torch.nn.Module): The module to initialize.
+        kwargs (Dict[str, Any]): Additional keyword arguments to use for initialization.
+    """
 )
 param_init_fns = create_registry(
     'llmfoundry',
@@ -103,9 +197,23 @@
     description=_param_init_fns_description,
 )

-_module_init_fns_description = """The module_init_fns registry is used to register functions that initialize specific modules.
-These functions should return True if they initialize the module, and False otherwise. This allows them to be called without knowing their contents.
-They should take in the module, init_div_is_residual, and div_is_residual arguments."""
+_module_init_fns_description = (
+    """The module_init_fns registry is used to register functions that initialize specific modules.
+
+    These functions should return True if they initialize the module, and False otherwise.
+    This allows them to be called without knowing their contents. They should take in the module and additional kwargs.
+    If multiple functions can initialize the module, the one that is registered first will be used, so it is recommended to
+    override an existing function if you want to change existing initialization behavior, and to add new functions for new
+    layer types. See param_init_fns.py for details.
+
+    Args:
+        module (torch.nn.Module): The module to initialize.
+        kwargs (Dict[str, Any]): Additional keyword arguments to use for initialization.
+
+    Returns:
+        bool: Whether or not the module was initialized.
+    """
+)
 module_init_fns = create_registry(
     'llmfoundry',
     'module_init_fns',
diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py
index f0c6486a1a..5924070497 100644
--- a/llmfoundry/registry.py
+++ b/llmfoundry/registry.py
@@ -27,11 +27,17 @@

 from llmfoundry.utils.registry_utils import create_registry

 _loggers_description = (
-    'The loggers registry is used to register classes that implement the LoggerDestination interface. '
-    +
-    'These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers '
-    +
-    'will be constructed by directly passing along the specified kwargs to the constructor.'
+    """The loggers registry is used to register classes that implement the LoggerDestination interface.
+
+    These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers
+    will be constructed by directly passing along the specified kwargs to the constructor. See loggers/ for examples.
+
+    Args:
+        kwargs (Dict[str, Any]): The kwargs to pass to the LoggerDestination constructor.
+
+    Returns:
+        LoggerDestination: The logger destination.
+    """
 )
 loggers = create_registry(
     'llmfoundry',
@@ -42,11 +48,17 @@
 )

 _callbacks_description = (
-    'The callbacks registry is used to register classes that implement the Callback interface. '
-    +
-    'These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer. '
-    +
-    'The callbacks will be constructed by directly passing along the specified kwargs to the constructor.'
+    """The callbacks registry is used to register classes that implement the Callback interface.
+
+    These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer.
+    The callbacks will be constructed by directly passing along the specified kwargs to the constructor. See callbacks/ for examples.
+
+    Args:
+        kwargs (Dict[str, Any]): The kwargs to pass to the Callback constructor.
+
+    Returns:
+        Callback: The callback.
+    """
 )
 callbacks = create_registry(
     'llmfoundry',
@@ -57,9 +69,18 @@
 )

 _callbacks_with_config_description = (
-    'The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface. '
-    +
-    'These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor.'
+    """The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface.
+
+    These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor.
+    See callbacks/ for examples.
+
+    Args:
+        config (DictConfig): The training config.
+        kwargs (Dict[str, Any]): The kwargs to pass to the Callback constructor.
+
+    Returns:
+        Callback: The callback.
+    """
 )
 callbacks_with_config = create_registry(
     'llm_foundry.callbacks_with_config',
@@ -69,10 +90,18 @@
 )

 _optimizers_description = (
-    'The optimizers registry is used to register classes that implement the Optimizer interface. '
-    +
-    'The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the '
-    + 'specified kwargs to the constructor, along with the model parameters.'
+    """The optimizers registry is used to register classes that implement the Optimizer interface.
+
+    The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the
+    specified kwargs to the constructor, along with the model parameters. See optim/ for examples.
+
+    Args:
+        params (Iterable[torch.nn.Parameter]): The model parameters.
+        kwargs (Dict[str, Any]): The kwargs to pass to the Optimizer constructor.
+
+    Returns:
+        Optimizer: The optimizer.
+    """
 )
 optimizers = create_registry(
     'llmfoundry',
@@ -83,10 +112,17 @@
 )

 _algorithms_description = (
-    'The algorithms registry is used to register classes that implement the Algorithm interface. '
-    +
-    'The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the '
-    + 'specified kwargs to the constructor.'
+    """The algorithms registry is used to register classes that implement the Algorithm interface.
+
+    The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the
+    specified kwargs to the constructor. See algorithms/ for examples.
+
+    Args:
+        kwargs (Dict[str, Any]): The kwargs to pass to the Algorithm constructor.
+
+    Returns:
+        Algorithm: The algorithm.
+    """
 )
 algorithms = create_registry(
     'llmfoundry',
@@ -97,10 +133,17 @@
 )

 _schedulers_description = (
-    'The schedulers registry is used to register classes that implement the ComposerScheduler interface. '
-    +
-    'The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the '
-    + 'specified kwargs to the constructor.'
+    """The schedulers registry is used to register classes that implement the ComposerScheduler interface.
+
+    The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the
+    specified kwargs to the constructor. See optim/ for examples.
+
+    Args:
+        kwargs (Dict[str, Any]): The kwargs to pass to the ComposerScheduler constructor.
+
+    Returns:
+        ComposerScheduler: The scheduler.
+    """
 )
 schedulers = create_registry(
     'llmfoundry',
@@ -111,11 +154,18 @@
 )

 _models_description = (
-    'The models registry is used to register classes that implement the ComposerModel interface. '
-    +
-    'The model constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. '
-    +
-    'Note: This will soon be updated to take in named kwargs instead of a config directly.'
+    """The models registry is used to register classes that implement the ComposerModel interface.
+
+    The model constructor should accept a PreTrainedTokenizerBase named `tokenizer`, and the rest of its constructor kwargs.
+    See models/ for examples.
+
+    Args:
+        tokenizer (PreTrainedTokenizerBase): The tokenizer.
+        kwargs (Dict[str, Any]): The kwargs to pass to the ComposerModel constructor.
+
+    Returns:
+        ComposerModel: The model.
+    """
 )
 models = create_registry(
     'llmfoundry',
@@ -126,9 +176,19 @@
 )

 _dataloaders_description = (
-    'The dataloaders registry is used to register functions that create a DataSpec. The function should take '
-    +
-    'a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.'
+    """The dataloaders registry is used to register functions that create a DataSpec given a config.
+
+    The function should take a PreTrainedTokenizerBase, a device batch size, and the rest of its constructor kwargs,
+    and return a DataSpec. See data/ for examples.
+
+    Args:
+        tokenizer (PreTrainedTokenizerBase): The tokenizer.
+        device_batch_size (Union[int, float]): The device batch size.
+        kwargs (Dict[str, Any]): The kwargs to pass to the builder function.
+
+    Returns:
+        DataSpec: The dataspec.
+    """
 )
 dataloaders = create_registry(
     'llmfoundry',
@@ -141,14 +201,19 @@
 )

 _dataset_replication_validators_description = (
-    """Validates the dataset replication args.
+    """The dataset_replication_validators registry is used to register functions that validate the replication factor.
+
+    The function should return the replication factor and the dataset device batch size. See data/ for examples.
+
     Args:
         cfg (DictConfig): The dataloader config.
         tokenizer (PreTrainedTokenizerBase): The tokenizer
         device_batch_size (Union[int, float]): The device batch size.
+
     Returns:
         replication_factor (int): The replication factor for dataset.
-        dataset_batch_size (int): The dataset device batch size."""
+        dataset_batch_size (int): The dataset device batch size.
+    """
 )
 dataset_replication_validators = create_registry(
     'llmfoundry',
@@ -161,14 +226,19 @@
 )

 _collators_description = (
-    """Returns the data collator.
+    """The collators registry is used to register functions that create the collate function for the DataLoader.
+
+    See data/ for examples.
+
     Args:
         cfg (DictConfig): The dataloader config.
         tokenizer (PreTrainedTokenizerBase): The tokenizer
         dataset_batch_size (Union[int, float]): The dataset device batch size.
+
     Returns:
         collate_fn (Any): The collate function.
-        dataloader_batch_size (int): The batch size for dataloader. In case of packing, this might be the packing ratio times the dataset device batch size."""
+        dataloader_batch_size (int): The batch size for dataloader. In case of packing, this might be the packing ratio times the dataset device batch size.
+    """
 )
 collators = create_registry(
     'llmfoundry',
@@ -180,12 +250,17 @@
 )

 _data_specs_description = (
-    """Returns the get_data_spec function.
+    """The data_specs registry is used to register functions that create a DataSpec given a dataloader.
+
+    See data/ for examples.
+
     Args:
         dl (Union[Iterable, TorchDataloader]): The dataloader.
         dataset_cfg (DictConfig): The dataset config.
+
     Returns:
-        dataspec (DataSpec): The dataspec."""
+        dataspec (DataSpec): The dataspec.
+    """
 )
 data_specs = create_registry(
     'llmfoundry',
@@ -197,7 +272,17 @@
 )

 _metrics_description = (
-    'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.'
+    """The metrics registry is used to register classes that implement the torchmetrics.Metric interface.
+
+    The metric will be passed to the metrics arg of the Trainer. The metric will be constructed by directly passing along the
+    specified kwargs to the constructor. See metrics/ for examples.
+
+    Args:
+        kwargs (Dict[str, Any]): The kwargs to pass to the Metric constructor.
+
+    Returns:
+        Metric: The metric.
+    """
 )
 metrics = create_registry(
     'llmfoundry',
@@ -208,7 +293,17 @@
 )

 _icl_datasets_description = (
-    'The ICL datasets registry is used to register an torch.utils.data.Dataset class which can be used for ICL tasks.'
+    """The ICL datasets registry is used to register classes that implement the InContextLearningDataset interface.
+
+    The dataset will be constructed along with an Evaluator. The dataset will be constructed by directly passing along the
+    specified kwargs to the constructor. See eval/ for examples.
+
+    Args:
+        kwargs (Dict[str, Any]): The kwargs to pass to the Dataset constructor.
+
+    Returns:
+        InContextLearningDataset: The dataset.
+    """
 )
 icl_datasets = create_registry(
     'llmfoundry',
@@ -226,7 +321,7 @@
     """The config_transforms registry is used to register functions that transform the training config

     The config will be transformed before it is used anywhere else. Note: By default ALL registered transforms will be applied to the train config
-    and NONE to the eval config. Each transform should return the modified config.
+    and NONE to the eval config. Each transform should return the modified config. See utils/config_utils.py for examples.

     Args:
         cfg (Dict[str, Any]): The training config.
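
For readers extending LLM Foundry along the lines this patch documents, here is a minimal sketch of the direct, in-code registration path described in the README changes above. It is illustrative only: it assumes the registries returned by `create_registry` expose the catalogue-style `register(name, func=...)` API, and it reuses the hypothetical `MyLogger` / `my_logger` names from the README example.

```python
# my_logger.py -- a minimal sketch of direct, in-code registration.
# Assumptions (not part of the patch itself): the registries created by
# create_registry support the catalogue-style register(name, func=...) call,
# and MyLogger / 'my_logger' are the hypothetical names from the README.
from typing import Any, Dict

from composer.loggers import LoggerDestination

from llmfoundry.registry import loggers


class MyLogger(LoggerDestination):
    """Toy logger that just prints hyperparameters to stdout."""

    def log_hyperparameters(self, hyperparameters: Dict[str, Any]) -> None:
        # A real logger would forward these to an experiment tracker.
        print(f'hparams: {hyperparameters}')


loggers.register('my_logger', func=MyLogger)
```

A hypothetical training-config fragment could then load this file and reference the logger by its registered key, mirroring the `code_paths` mechanism shown in the README hunk above:

```yaml
code_paths:
  - my_logger.py
loggers:
  my_logger: {}
```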