YAX: JAX/FLAX Module Tracing, Evaluation, and Mutation #24605
Unanswered
daskol asked this question in Show and tell
Replies: 1 comment
A piece of advice - if you write marketing copy (and this is marketing copy), don't make it so long that nobody wants to read it.
-
YAX: Yet Another X
JAX/FLAX Module Tracing, Evaluation, and Mutation.
https://github.com/daskol/yax
Rationale
The module-level API provided by JAX/FLAX, PyTorch, Keras, and other deep learning frameworks is an easy, clear, and descriptive way to define a neural network. However, model examination and mutation are not so straightforward. For example, it is generally impossible to change the activation function of a layer somewhere deep in a model. Another example is lorafying, that is, substituting fully-connected layers with LoRA-like adapters (qax solves this for vanilla LoRA in JAX/FLAX; HuggingFace's peft is a collection of adapters for PyTorch). What is an easy and general way to modify a model?
A model defined with FLAX basically consists of a weight dictionary and a routine that computes the model outputs from the weight dictionary and the inputs. One can easily operate on the weight dictionary: flatten or unflatten it, change its structure, delete leaves, add internal nodes, and so on. Weight manipulation is trivial. The non-trivial part is changing the internal model structure.
In order to demonstrate this, consider the model below and try to replace its internal `nn.Dense(42)` module with an attention layer. As was said before, it is trivial to change the weight dictionary (just replace the relevant params subtree). However, one also wants to replace the application of the internal module `nn.Dense(42).__call__` to its input with an application of the desired attention module. The only way to do this is to rewrite the method `Model.__call__`, as sketched below. Generally speaking, this is a nice solution: it is simple, but it requires some copy-pasting. Moreover, it does not scale well: a change to a deeply nested module requires rewriting and copy-pasting all of its parent modules. Thus we need a general approach that requires only local changes and applies them to the model.
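A minimal sketch of such a model and of the manual rewrite (the module names, layer sizes, and the attention configuration here are illustrative, not taken from the original snippets):

```python
import flax.linen as nn


class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(42)(x)  # the internal module we would like to replace
        x = nn.relu(x)
        return nn.Dense(10)(x)


# The manual fix: copy the parent module and edit its __call__ by hand.
class PatchedModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        # Self-attention in place of nn.Dense(42); every parent module up
        # the hierarchy would need the same copy-paste treatment.
        x = nn.MultiHeadDotProductAttention(num_heads=2)(x, x)
        x = nn.relu(x)
        return nn.Dense(10)(x)
```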
Approach
The main idea is to use the JAX trace/tracer facility to build an intermediate representation of a FLAX module. This intermediate representation is called a Module eXpression (MoX) and is essentially an extension of Jaxpr. It represents a model as a tree structure in which internal nodes correspond to modules and leaves correspond to primitives, including Jaxprs.
In order to build a MoX, the JAX monad transformer facility and the FLAX interceptors API are used to build a module expression tree (a MoX is actually a tree). A `ModuleTrace` is pushed to the top of the global trace stack, and inputs are wrapped in `ModuleTracer`s, each of which keeps only the shape and element type (dtype) of a source tensor. The `ModuleTrace` binds primitives to abstract expression evaluation in order to capture the data dependencies between inputs and outputs and to check shape and dtype correctness, much like `jax.jit` does.
Tracing
A module expression can be built in a similar way to a Jaxpr (see `jax.make_jaxpr`). The resulting MoX can be pretty printed, although the output is not very presentable for now; a sketch follows.
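As a rough illustration, here is how tracing might look, assuming that `make_mox` mirrors the `jax.make_jaxpr` interface (the exact signature and the way a MoX is printed may differ):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

import yax  # https://github.com/daskol/yax


class Model(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(42)(x)
        x = nn.relu(x)
        return nn.Dense(10)(x)


model = Model()
x = jnp.ones((1, 16))
params = model.init(jax.random.PRNGKey(0), x)

# Assumed to work like jax.make_jaxpr: trace model.apply abstractly
# (shapes and dtypes only) and return its module expression.
mox = yax.make_mox(model.apply)(params, x)
print(mox)  # pretty-printed tree of modules and primitives
```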
Evaluation
Since a MoX is abstract with respect to its inputs (only shapes and dtypes are recorded), there is no problem in evaluating it on concrete inputs and applying JIT.
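Continuing the snippet above, evaluation could look roughly like this; `eval_mox` is an assumed name for whatever the actual evaluation entry point is:

```python
# Hypothetical evaluation entry point: re-run the traced computation on
# concrete parameters and inputs; the result should match model.apply.
y = yax.eval_mox(mox, params, x)

# Evaluation is a pure function of (params, x), so it can be jitted.
y_jit = jax.jit(lambda p, v: yax.eval_mox(mox, p, v))(params, x)
```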
Querying
MoX provides tools powered by XPath for model exploration and examination. Specifically, a MoX can help answer questions like: "Which `nn.Dense` modules have 10 features?" A sketch of such a query is shown below.
Modification
With an expressive query language like XPath, modifying an original model on the fly becomes easy. For example, one can replace all ReLU activation functions with GELU or substitute all `nn.Dense` layers with LoRA adapters (see the code snippets in the Use Cases below).
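Continuing the snippet above, a query might look roughly like this; the `query` helper and the exact XPath vocabulary are assumptions, since only XPath-powered querying itself is stated:

```python
# Hypothetical XPath-style query: select every nn.Dense module whose
# `features` attribute equals 10. The function name and path syntax are
# illustrative.
for node in yax.query('//Dense[@features=10]', mox):
    print(node)
```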
Use Cases
Assume that we have a fancy model FancyDiffusionGANTransformerRNNModel. We can initialize it from scratch or load it from a checkpoint downloaded from the HuggingFace Hub as usual. But it turns out that it takes a lot of memory to train or fine-tune, so we decide to apply some optimization. Thus we build a module expression `old_mox` for it with `make_mox`.
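A sketch of this setup; the model class is the placeholder from the text, assumed to be an ordinary flax.linen module defined elsewhere, and `make_mox` follows the same assumed interface as above:

```python
import jax
import jax.numpy as jnp

import yax

# Placeholder model from the text, assumed to be defined elsewhere.
model = FancyDiffusionGANTransformerRNNModel()

batch = jnp.ones((1, 128, 768))  # illustrative input shape
params = model.init(jax.random.PRNGKey(0), batch)  # or weights restored from a checkpoint

# Module expression of the original model.
old_mox = yax.make_mox(model.apply)(params, batch)
```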
Activation Functions
Let's assume that the original model uses ReLU everywhere, but we want to replace it with the GELU activation function, which requires more compute but gives faster convergence and more stable training. Moreover, some GELU variants save memory as well (see few-bit). The idea is pretty straightforward: make a MoX from the new activation function and replace all ReLUs with it, as follows.
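Continuing from `old_mox` above; the `sub` helper and the XPath expression are assumed names, while the overall workflow (trace the replacement, substitute it into the old MoX) is the one described here:

```python
import flax.linen as nn

# Trace the new activation function to obtain a MoX for it.
gelu_mox = yax.make_mox(nn.gelu)(jnp.ones((1, 128, 768)))

# Hypothetical substitution helper: replace every node matched by the XPath
# expression with the GELU expression and return a new module expression.
new_mox = yax.sub('//relu', gelu_mox, old_mox)

# The modified model can be evaluated and jitted as before.
y = yax.eval_mox(new_mox, params, batch)
```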
Low-Rank Adaptation (LoRA)
Another use case stems from a popular parameter-efficient fine-tuning approach. The idea is to add a low-rank correction to the original fully-connected layers and to train only the weights of that correction (see LoRA). The workflow is the same: instantiate a new module, initialize it, obtain its MoX, and replace the original layers with the new one.
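A sketch of the LoRA substitution, continuing from `old_mox`; the `LoRADense` module below is written from scratch for illustration, and `sub` remains an assumed name:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn


class LoRADense(nn.Module):
    """Fully-connected layer plus a trainable low-rank correction.

    In practice the base kernel would be copied from the original layer and
    frozen; only the low-rank factors would be trained.
    """
    features: int
    rank: int = 8

    @nn.compact
    def __call__(self, x):
        base = nn.Dense(self.features, name='base')(x)
        a = nn.Dense(self.rank, use_bias=False, name='lora_a')(x)
        b = nn.Dense(self.features, use_bias=False, name='lora_b')(a)
        return base + b


# Instantiate, initialize, and trace the adapter.
adapter = LoRADense(features=42)
dummy = jnp.ones((1, 128, 768))
adapter_params = adapter.init(jax.random.PRNGKey(1), dummy)
adapter_mox = yax.make_mox(adapter.apply)(adapter_params, dummy)

# Hypothetical substitution: swap every nn.Dense node for the adapter.
lora_mox = yax.sub('//Dense', adapter_mox, old_mox)
```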