Read speeds decrease 2x when reading with fewer processes #195

Open
heiner opened this issue Sep 6, 2024 · 10 comments
@heiner

heiner commented Sep 6, 2024

The issue

Given a specific checkpoint, load it in two different settings:

  1. Load it with 64 nodes, 512 GPUs, 512 processes (1 GPU / process).
  2. Load it with 64 nodes, 512 GPUs, 64 processes (8 GPUs / process).

What I observe:

  1. Using 512 processes, reading takes ~20 seconds.
  2. Using 64 processes, reading takes ~40 seconds (2x).

The checkpoint in question is also written with 512 processes (see below for repro). Except for the number of processes, nothing else changes (sharding etc. stays the same).

To reproduce:

Download this file and run it in a context with 64 nodes, 8 GPUs each. Make sure the hostfile contains the hostnames of the 64 nodes. (mpirun isn't essential here; it's just a way to spawn these processes.)

To create the checkpoint:

mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 -np 512 -npernode 8 python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234 2>&1 | grep -v 'The transformations API'

To load the checkpoint with 512 processes:

mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 -np 512 -npernode 8 python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234 2>&1 | grep -v 'The transformations API'

This takes ~20 sec for me.

To load the checkpoint with 64 processes:

mpirun -hostfile hostfile -mca pml ob1 -mca btl ^openib -mca btl_tcp_if_include eth0 -np 64 -npernode 1 python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234 2>&1 | grep -v 'The transformations API'

This takes ~40 sec for me.

The issue doesn't seem to be in Orbax because the same happens with a plain jax.experimental.serialization.async_deserialize.
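For reference, a bare-TensorStore sketch of the kind of read involved, useful for timing the read path in isolation (untested; the path is a placeholder and the spec just mirrors the ocdbt-on-file layout the checkpoint uses):

import time
import tensorstore as ts

# Hypothetical spec mirroring the ocdbt-on-file layout of the checkpoint;
# the base path and the array path "0" are placeholders.
spec = {
    "driver": "zarr",
    "kvstore": {
        "driver": "ocdbt",
        "base": {"driver": "file", "path": "/data/heiner/ckpttest4-64"},
        "path": "0",
    },
}

start = time.time()
arr = ts.open(spec, read=True).result()
data = arr.read().result()  # pull the whole array into memory
print(f"Read {data.nbytes} bytes in {time.time() - start:.2f} sec")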

@heiner heiner changed the title from Read speeds increase 2x when reading with fewer processes to Read speeds decrease 2x when reading with fewer processes on Sep 6, 2024
@laramiel
Collaborator

laramiel commented Sep 11, 2024

I don't have access to a cluster; is there a local method to run this?

I'm trying something like:

 sudo apt install libopenmpi-dev
 python3 -m venv ts-venv
 source ts-venv/bin/activate
 python3 -m pip install jax orbax numpy tensorstore mpi4py
 
 mpirun -np 1 python3 $(pwd)/ts_mpitest.py /tmp/data1/ $(hostname):3345

Edit: So I hacked your file to run on my machine with num processes = 1, replacing some large values and shrinking your large arrays considerably, so this won't be anything like what you've got. To dump the TensorStore metrics at the end:

import tensorstore as ts

for x in ts.experimental_collect_matching_metrics():
  print(x)

The command, the tensorstore spec, and the resulting metrics look something like:

$ mpirun -np 1 python3 $(pwd)/ts_mpitest.py /usr/local/google/tmp/data1/ $(hostname):3345

Loading existing checkpoint
Starting checkpoint load
{'driver': 'zarr', 'kvstore': {'driver': 'ocdbt', 'base': {'driver': 'file', 'path': '/usr/local/google/tmp/data1'}, 'path': '0', 'experimental_read_coalescing_threshold_bytes': 1000000, 'experimental_read_coalescing_merged_bytes': 500000000000, 'experimental_read_coalescing_interval': '1ms', 'cache_pool': 'cache_pool#ocdbt'}, 'recheck_cached_data': False, 'recheck_cached_metadata': False}
...
Loaded checkpoint from /usr/local/google/tmp/data1/ in 107.53 sec
{'name': '/tensorstore/cache/chunk_cache/reads', 'values': [{'value': 33}]}
{'name': '/tensorstore/cache/hit_count', 'values': [{'value': 44}]}
{'name': '/tensorstore/cache/kvs_cache_read', 'values': [{'category': 'changed', 'value': 46}]}
{'name': '/tensorstore/cache/miss_count', 'values': [{'value': 48}]}
{'name': '/tensorstore/futures/force_callbacks', 'values': [{'value': 303}]}
{'name': '/tensorstore/futures/live', 'values': [{'max_value': 162, 'value': 1}]}
{'name': '/tensorstore/futures/not_needed_callbacks', 'values': [{'value': 45}]}
{'name': '/tensorstore/futures/ready_callbacks', 'values': [{'value': 436}]}
{'name': '/tensorstore/internal/riegeli/noncontiguous_bytes', 'values': [{'value': 54767124480}]}
{'name': '/tensorstore/internal/thread/schedule_at/insert_histogram_ms', 'values': [{'0': 0, '1': 46, 'count': 46, 'mean': 0.0, 'sum_of_squared_deviation': 0.0}]}
{'name': '/tensorstore/internal/thread/schedule_at/next_event', 'values': [{'value': 'infinite-future'}]}
{'name': '/tensorstore/internal/thread/schedule_at/queued_ops', 'values': [{'max_value': 15, 'value': 0}]}
{'name': '/tensorstore/kvstore/file/batch_read', 'values': [{'value': 27}]}
{'name': '/tensorstore/kvstore/file/bytes_read', 'values': [{'value': 50616803970}]}
{'name': '/tensorstore/kvstore/file/open_read', 'values': [{'value': 27}]}
{'name': '/tensorstore/kvstore/file/read', 'values': [{'value': 27}]}
{'name': '/tensorstore/kvstore/file/read_latency_ms', 'values': [{'0': 0, '1': 2, '10': 0, '11': 0, '12': 0, '13': 0, '14': 3, '15': 6, '16': 7, '17': 6, '2': 0, '3': 0, '4': 0, '5': 2, '6': 0, '7': 1, '8': 0, '9': 0, 'count': 27, 'mean': 20684.2962962963, 'sum_of_squared_deviation': 9117420933.62963}]}
{'name': '/tensorstore/kvstore/ocdbt/read', 'values': [{'value': 45}]}
{'name': '/tensorstore/thread_pool/active', 'values': [{'max_value': 24, 'value': 13}]}
{'name': '/tensorstore/thread_pool/max_delay_ns', 'values': [{'max_value': 16552002634}]}
{'name': '/tensorstore/thread_pool/started', 'values': [{'value': 24}]}
{'name': '/tensorstore/thread_pool/steal_count', 'values': [{'value': 37.0}]}
{'name': '/tensorstore/thread_pool/task_providers', 'values': [{'max_value': 2, 'value': 0}]}
{'name': '/tensorstore/thread_pool/total_queue_time_ns', 'values': [{'value': 95758202782.0}]}
{'name': '/tensorstore/thread_pool/work_time_ns', 'values': [{'value': 1072415922369.0}]}

@heiner
Author

heiner commented Sep 11, 2024

Hey Laramie - thanks for taking a look!

Unfortunately, I haven't managed to create a smaller repro yet. I'll run with experimental_collect_matching_metrics and get back to you soon.

More generally, do you know of any settings I might need to change to increase the per-process throughput? Or, failing that, is there a (possibly hacky) way to have separate, independent TensorStore clients within a single process? I suspect we're hitting some kind of per-process limit (thread pool, TCP/IP connections, etc.) here.
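For example, something like this (an untested sketch of what I mean, not something I've verified helps): give each open its own ts.Context so it gets its own resource pools rather than the shared per-process defaults:

import tensorstore as ts

# Untested sketch: one Context per open instead of the shared per-process
# defaults. The spec is a placeholder mirroring the checkpoint's
# ocdbt-on-file layout.
spec = {
    "driver": "zarr",
    "kvstore": {
        "driver": "ocdbt",
        "base": {"driver": "file", "path": "/data/heiner/ckpttest4-64"},
        "path": "0",
    },
}
arr = ts.open(spec, read=True, context=ts.Context()).result()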

@laramiel
Collaborator

laramiel commented Sep 11, 2024

At the tensorstore layer this is using an ocdbt kvstore on top of a file kvstore.
Tensorstore has some context settings for files which you could try: https://google.github.io/tensorstore/kvstore/file/index.html

Try setting "file_io_concurrency", which defaults to max(4, hardware_concurrency).

https://en.cppreference.com/w/cpp/thread/thread/hardware_concurrency
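For example, something along these lines sets it via the context embedded in the spec (the limit value is arbitrary, just something to experiment with; the path and layout come from your checkpoint):

# Sketch only: raise the per-process file I/O concurrency limit via the
# spec's "context" member. 256 is an arbitrary value to try.
spec = {
    "driver": "zarr",
    "kvstore": {
        "driver": "ocdbt",
        "base": {"driver": "file", "path": "/data/heiner/ckpttest4-64"},
        "path": "0",
    },
    "context": {"file_io_concurrency": {"limit": 256}},
}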

You could also add detailed logging to the file operations via TENSORSTORE_VERBOSE_LOGGING=file=2

How many hosts are in your hostfile? And what is the underlying filesystem?

@heiner
Author

heiner commented Sep 11, 2024

There are 64 nodes (it says so in the issue description above). The file system is a distributed file system à la Lustre or VAST.

I already tried setting file_io_concurrency manually and it didn't seem to help.

@rdyro

rdyro commented Sep 11, 2024

I don't work on tensorstore directly, but one setting I've found sometimes helps with loading performance is ocdbt_target_data_file_size:

import time

import orbax.checkpoint as ocp

log = print  # stand-in for the logging helper assumed by the original snippet

def save(state, path, ocdbt_target_data_file_size: int = 2 * 1024 ** 3):
  start = time.time()
  ocp.PyTreeCheckpointer(use_ocdbt=True, use_zarr3=True).save(
    path, ocp.args.PyTreeSave(
      item=state, ocdbt_target_data_file_size=ocdbt_target_data_file_size))
  log(f"Saved checkpoint to {path} in {time.time() - start:.2f} sec")

def load(path, shape_dtype):
  start = time.time()
  state = ocp.PyTreeCheckpointer(use_ocdbt=True, use_zarr3=True).restore(
    path, ocp.args.PyTreeRestore(
      shape_dtype, restore_args=ocp.checkpoint_utils.construct_restore_args(shape_dtype),
  ))
  end = time.time()
  log(f"Loaded checkpoint from {path} in {end - start:.2f} sec")
  return state

2 GB is the default, but going smaller might help
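e.g., reusing the save helper above with an arbitrary smaller value (the best size will depend on the filesystem):

save(state, "/data/heiner/ckpttest4-64/", ocdbt_target_data_file_size=512 * 1024 ** 2)  # try 512 MiB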

@laramiel
Collaborator

laramiel commented Sep 11, 2024

I imagine that a lot of the performance will have to do with specific details about how the filesystem interaction happens.
So this is basically running either a single process per node (n=64) or 8 processes per node (n=512).

If it's related to file_io_concurrency, that implies going from something like 8x8 threads issuing I/O per node to something like 1x8 threads issuing I/O (if hardware_concurrency is, for example, 8).
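(hardware_concurrency usually equals the logical CPU count; a quick way to check what a node reports:)

import os
print(os.cpu_count())  # typically what std::thread::hardware_concurrency() returns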

I would be interested to see the output of the tensorstore counters for the various configs.

Edit: Looking at orbax, it appears that file_io_concurrency has been set to an adequately large value:

https://github.com/google/orbax/blob/d27fcdd8e9227fcd3d631554f17fc90e4c04e150/checkpoint/orbax/checkpoint/type_handlers.py#L58

It would be nice to get a pprof of these; is that possible?

@laramiel
Collaborator

OK, I figured out an inconsistency with our internal build that makes logging hard to use in Python. Once I get that fixed, it will be easier to debug what's going on.

@laramiel
Collaborator

laramiel commented Oct 3, 2024

You should now be able to set this environment variable and look at the I/O timing across runs:

TENSORSTORE_VERBOSE_LOGGING=file=1,file_detail=2
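For your mpirun setup, that could look something like this (untested; -x tells Open MPI to export the variable to the remote ranks, and the -mca flags from your original commands are elided for brevity):

TENSORSTORE_VERBOSE_LOGGING=file=1,file_detail=2 \
  mpirun -hostfile hostfile -x TENSORSTORE_VERBOSE_LOGGING -np 64 -npernode 1 \
  python ts_multigpu.py /data/heiner/ckpttest4-64/ $(hostname):1234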

@laramiel
Collaborator

laramiel commented Oct 17, 2024

I just submitted a tscli change that will help me create a better test harness/benchmark for this case.
If you could run this and let me know what the output is, I'd appreciate it.
If the parameter names (in the path component) have meaning, it's fine to redact them in a consistent way.

git clone https://github.com/google/tensorstore
cd tensorstore
./bazelisk.py build //tensorstore/tscli
alias tscli=$(pwd)/bazel-bin/tensorstore/tscli/tscli

for x in $(tscli search file:/// /usr/local/google/tmp/data1/); do
    echo
    tscli print_spec --spec "$x"
    tscli print_stats --spec "$x"
    echo
done

@laramiel
Collaborator

I have been running a variant of this with my updated multi_read_benchmark. We found some internal tensorstore chunk cache contention; fixing it may help here. It was alleviated in 5927385
