How to run a simple inference on Switch base #5

Open
younesbelkada opened this issue Oct 12, 2022 · 1 comment

Hi there!

First of all, awesome work on Switch transformers 🔥
I was wondering if there is a simple example script or set of commands to run a simple inference using the switch_base model?
Thanks!

younesbelkada commented Oct 12, 2022

I finally managed to get a working script. For those who are interested, here are the steps:

1- Prepare the working setup:

git clone --branch=main https://github.com/google-research/t5x
cd t5x
python3 -m pip install -e '.[tpu]' -f \
  https://storage.googleapis.com/jax-releases/libtpu_releases.html

git clone https://github.com/google/flaxformer.git
cd flaxformer
pip3 install '.[testing]'

2- Get the model checkpoints:

export PATH_CHECKPOINTS=...
gcloud storage cp -r gs://t5-data/pretrained_models/t5x/moe/switch_classic/base/e8/checkpoint_500100 $PATH_CHECKPOINTS

3- Create and save a gin file for Switch Transformers, similar to the ones in flaxformer (example below):

# Switch Transformer Base model.
#
# Based on the original Switch Transformer (https://arxiv.org/abs/2101.03961).
#
# Note that unlike the original Switch Transformer, this T5X version does not
# use any jitter noise in the router.
#
# Provides MODEL and NUM_EXPERTS.

from __gin__ import dynamic_registration

from flaxformer.architectures.moe import moe_architecture
from flaxformer.architectures.moe import moe_layers
from flaxformer.architectures.moe import routing
from flaxformer.components import dense
import seqio
from t5x import adafactor

ARCHITECTURE = %gin.REQUIRED

include 'flaxformer/t5x/configs/moe/models/tokens_choose_base.gin'

# Architecture overrides
MLP_DIM = 3072

# MoE overrides
NUM_EXPERTS = 128
# Every other MLP sublayer is replaced by an MoE sublayer.
NUM_ENCODER_SPARSE_LAYERS = 6
NUM_DECODER_SPARSE_LAYERS = 6
TRAIN_EXPERT_CAPACITY_FACTOR = 1.25
EVAL_EXPERT_CAPACITY_FACTOR = 2.
NUM_SELECTED_EXPERTS = 1   # Switch routing
AUX_LOSS_FACTOR = 0.01
ROUTER_Z_LOSS_FACTOR = 0.0
GROUP_SIZE = 8192



# Switch Transformer Base uses relu activations.
dense.MlpBlock.activations = ('relu',)
expert/dense.MlpBlock.activations = ('relu',)

# Switch Transformer Base re-uses the token embedder to compute output logits.
moe_architecture.SparseDecoder.output_logits_factory = None

# Switch Transformer doesn't use BPR in encoder (although most sparse encoders
# generally see a boost from it).
sparse_encoder/routing.TokensChooseMaskedRouter.batch_prioritized_routing = False

Call it, for example, switch_base.gin and save it wherever you like (we'll refer to its path as PATH_GIN_BASE below).
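
For anyone wondering what the capacity factors in that file translate to: the Switch Transformer paper defines expert capacity as (tokens per batch / number of experts) × capacity factor. Here is a rough sketch of that arithmetic, assuming GROUP_SIZE plays the role of "tokens per batch" and that the result is rounded up. The function name is made up for illustration, and flaxformer's own computation may round differently.

```python
import math


def expert_capacity(group_size: int, num_experts: int, capacity_factor: float) -> int:
    """Tokens one expert can accept per group, following the paper's definition.

    Assumption: the result is rounded up; flaxformer may round differently.
    """
    return math.ceil(capacity_factor * group_size / num_experts)


# Values from the gin file above (GROUP_SIZE = 8192, NUM_EXPERTS = 128).
print(expert_capacity(8192, 128, 1.25))  # train capacity: 80
print(expert_capacity(8192, 128, 2.0))   # eval capacity: 128

# Same group size with 8 experts, as in the "e8" checkpoint.
print(expert_capacity(8192, 8, 2.0))     # eval capacity: 2048
```

Tokens routed to an expert beyond this capacity are dropped by the router in the paper's formulation, which is why the eval factor is set higher than the train factor.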

4- Create an infer gin file for Switch Transformers:

from __gin__ import dynamic_registration

import __main__ as infer_script
from t5.data import mixtures
from t5x import partitioning
from t5x import utils

include "t5x/configs/runs/infer.gin"
# Replace the path below with the path to your own switch_base.gin (PATH_GIN_BASE).
include "t5x/examples/moe/switch-c/switch_base.gin"

DROPOUT_RATE = 0.0  # unused but needs to be specified
MIXTURE_OR_TASK_NAME = "wmt_t2t_ende_v003"
TASK_FEATURE_LENGTHS = {"inputs": 64, "targets": 64}

partitioning.PjitPartitioner.num_partitions = 1

utils.DatasetConfig:
  split = "test"
  batch_size = 32

Save it wherever you like and call it, for example, switch_base_infer.gin (we'll refer to its path as PATH_INFER_GIN below).
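
Since both include lines in this file have to resolve to real files, a quick pre-flight check can save a failed run. This is a minimal sketch, assuming includes are looked up relative to a list of root directories (by default just the current directory). The function name and the search_roots parameter are hypothetical and not part of t5x or gin.

```python
import re
from pathlib import Path

# Matches lines such as: include "t5x/configs/runs/infer.gin"
INCLUDE_RE = re.compile(r"^\s*include\s+['\"]([^'\"]+)['\"]")


def missing_includes(gin_file, search_roots=(".",)):
    """Return the include paths in gin_file that are not found under any root."""
    missing = []
    for line in Path(gin_file).read_text().splitlines():
        match = INCLUDE_RE.match(line)
        if match is None:
            continue  # not an include line (comments do not match either)
        relative_path = match.group(1)
        found = any((Path(root) / relative_path).is_file() for root in search_roots)
        if not found:
            missing.append(relative_path)
    return missing
```

Calling missing_includes("switch_base_infer.gin") from the directory where you launch the run returns an empty list when every include was found, and otherwise the paths that still need fixing.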

5- Run the inference script!

Finally, run the command:

python -m t5x.infer \
  --gin_file=$PATH_INFER_GIN \
  --logtostderr \
  --gin.MODEL_DIR=\"~/disk\" \
  --gin.CHECKPOINT_PATH=\"$PATH_CHECKPOINTS\" \
  --gin.INFER_OUTPUT_DIR=\"./\" \
  --gin.NUM_MODEL_PARTITIONS=1 \
  --gin.NUM_EXPERTS=8

Here PATH_CHECKPOINTS is the variable exported in step 2. Make sure the paths in the include statements are correct, otherwise you'll get errors.
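
The escaped quotes in that command are there because gin expects string values to be quoted. If you would rather launch the run from Python and skip the shell escaping, something like the following might work. It is only a sketch: the helper names are mine, and the default values are simply the ones from the command above.

```python
import shlex


def gin_flag(name, value):
    """Format a --gin.NAME=VALUE override, quoting strings the way gin expects."""
    if isinstance(value, str):
        value = '"' + value + '"'
    return f"--gin.{name}={value}"


def build_infer_command(infer_gin, checkpoint_path, output_dir="./",
                        model_dir="~/disk", num_experts=8):
    """Build the argv list for the t5x.infer call from step 5."""
    return [
        "python", "-m", "t5x.infer",
        f"--gin_file={infer_gin}",
        "--logtostderr",
        gin_flag("MODEL_DIR", model_dir),
        gin_flag("CHECKPOINT_PATH", checkpoint_path),
        gin_flag("INFER_OUTPUT_DIR", output_dir),
        gin_flag("NUM_MODEL_PARTITIONS", 1),
        gin_flag("NUM_EXPERTS", num_experts),
    ]


# Placeholder paths, for illustration only.
command = build_infer_command("switch_base_infer.gin", "/path/to/checkpoint")
print(shlex.join(command))
```

The returned list can be passed directly to subprocess.run(command, check=True), in which case no shell is involved and the quotes reach gin unchanged.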
