forked from google/flax
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request google#680 from jheek:lift-md
PiperOrigin-RevId: 345646943
- Loading branch information
Showing
4 changed files
with
137 additions
and
124 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# Design Note: transformation lifting | ||
|
||
## Introduction | ||
|
||
JAX uses a functional api meaning that it deals only which pure functions only. | ||
A pure function is defined as a function where the output only depends on the function arguments. | ||
Therefore, mutable state outside the function should not affect the function itself and the function | ||
itself should not cause side effects in objects that live outside of the function. | ||
|
||
Python functions do not have to be pure because they allow side effects or mutations to occur. | ||
For JAX restricting the API to pure functions has a number of advantages: | ||
|
||
1. It becames easier to reason about functions locally | ||
2. Both stochasticity and determinism are explicit because a function can only return a different output if the arguments are changed. | ||
3. Functional transforms which would otherwise be ambigious. | ||
|
||
## Functionalization | ||
|
||
TODO | ||
|
||
## Lifting | ||
|
||
TODO | ||
|
||
## Alternatives | ||
|
||
TODO |
Oops, something went wrong.