Skip to content

Commit

Permalink
add benchmark on state traversal, and a readme
Browse files Browse the repository at this point in the history
  • Loading branch information
IvyZX committed Dec 10, 2024
1 parent 21585ad commit f043c5e
Show file tree
Hide file tree
Showing 4 changed files with 151 additions and 34 deletions.
15 changes: 15 additions & 0 deletions benchmarks/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Benchmarks

These are mini benchmarks to measure the performance of NNX operations.

Sample profile command:

```shell
python -m cProfile -o ~/tmp/overhead.prof benchmarks/nnx_graph_overhead.py --mode=nnx --depth=100 --total_steps=1000
```

Sample profile inspection (requires [snakeviz](https://jiffyclub.github.io/snakeviz/), e.g. `pip install snakeviz`):

```shell
snakeviz ~/tmp/overhead.prof
```
1 change: 0 additions & 1 deletion benchmarks/nnx_graph_overhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# %%
import jax
import jax.numpy as jnp
import numpy as np
Expand Down
106 changes: 106 additions & 0 deletions benchmarks/nnx_state_traversal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# Copyright 2024 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Example profile command:
# python -m cProfile -o ~/tmp/state_traversal.prof benchmarks/nnx_state_traversal.py --width=4 --depth=4 --total_steps=1000
# Note: the model has roughly width**depth modules, so keep --depth small.
# View profile (need to install snakeviz):
# snakeviz ~/tmp/state_traversal.prof

import jax
from time import time

from flax import nnx

from absl import flags
from absl import app

FLAGS = flags.FLAGS
# Lower bounds reject values the benchmark cannot report on: the per-step
# time divides by `total_steps` and the per-layer time divides by `depth`,
# so either being zero would raise ZeroDivisionError after the timing loop.
flags.DEFINE_integer(
  'total_steps',
  1000,
  'Number of times each traversal is repeated',
  lower_bound=1,
)
flags.DEFINE_integer('width', 4, 'Width of each level', lower_bound=0)
flags.DEFINE_integer('depth', 4, 'Depth of the model', lower_bound=1)


class NestedClass(nnx.Module):
  """Recursively nested module used as a traversal benchmark fixture.

  Every instance holds one `nnx.Variable` named `x` (a vector of ones of
  length `depth + 1`). While `depth` is positive it also holds `width`
  children named `child0`, `child1`, ..., each built with `depth - 1`.
  """

  def __init__(self, width, depth):
    self.x = nnx.Variable(jax.numpy.ones((depth + 1,)))
    # Leaf level: no children to attach.
    if depth <= 0:
      return
    for index in range(width):
      child = NestedClass(width, depth - 1)
      setattr(self, f'child{index}', child)


def _time_calls(fn, args, total_steps):
  """Return the elapsed seconds spent calling `fn(*args)` `total_steps` times."""
  # perf_counter is monotonic and high-resolution, so the measurement is not
  # disturbed by system clock adjustments the way time.time() can be.
  from time import perf_counter

  start = perf_counter()
  for _ in range(total_steps):
    fn(*args)
  return perf_counter() - start


def _report(name, total_time, total_steps, depth):
  """Print total, per-step and per-layer timings for one benchmark."""
  time_per_step = total_time / total_steps
  print(f'### {name} ###')
  print('total time:', total_time)
  print(f'time per step: {time_per_step * 1e6:.2f} µs')
  # Guard against a zero depth, which would otherwise raise ZeroDivisionError.
  if depth > 0:
    print(f'time per layer: {time_per_step / depth * 1e6:.2f} µs')
  else:
    print('time per layer: n/a (depth must be positive)')


def main(argv):
  """Time several `jax.tree_util` traversals over the state of a nested model.

  Reads `total_steps`, `width` and `depth` from the absl flags, builds a
  `NestedClass(width, depth)`, takes `nnx.state` of it, and then times
  `tree_flatten_with_path`, `tree_map_with_path` and `tree_flatten` on that
  state, printing a report after each one.

  Args:
    argv: Command-line arguments forwarded by `app.run`; only printed.

  Raises:
    ValueError: If `total_steps` is smaller than 1, since per-step averages
      would be undefined.
  """
  print(argv)
  total_steps: int = FLAGS.total_steps
  width: int = FLAGS.width
  depth: int = FLAGS.depth
  if total_steps < 1:
    raise ValueError(f'--total_steps must be at least 1, got {total_steps}')

  model = NestedClass(width, depth)
  to_test = nnx.state(model)

  # Include depth: the per-layer figures below are derived from it.
  print(f'{total_steps=}, {width=}, {depth=}')

  # (report name, function under test, positional arguments)
  benchmarks = (
    (
      'tree_flatten_with_path',
      jax.tree_util.tree_flatten_with_path,
      (to_test,),
    ),
    (
      'tree_map_with_path',
      jax.tree_util.tree_map_with_path,
      (lambda _, x: x, to_test),  # identity map: measures traversal only
    ),
    (
      'tree_flatten',
      jax.tree_util.tree_flatten,
      (to_test,),
    ),
  )
  for name, fn, args in benchmarks:
    total_time = _time_calls(fn, args, total_steps)
    _report(name, total_time, total_steps, depth)


if __name__ == '__main__':
  app.run(main)
Loading

0 comments on commit f043c5e

Please sign in to comment.