
Reenable the distributed checkpointing test #8424

Merged: 1 commit into master on Dec 2, 2024

Conversation

JackCaoG (Collaborator) commented on Nov 27, 2024

This is a follow-up to #8386.

In the previous PR I found that, somewhere during fallback, PyTorch will try to update an existing XLATensor with a CPU tensor of a different shape. In that case we need to remove the sharding spec, otherwise there will be a shape mismatch. However, I found that in the distributed checkpointing path we swap the existing XLATensor with the CPU tensor, and it seems like we want to keep the sharding spec there.
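
For context, a minimal sketch of what "sharding spec on an XLATensor" means under SPMD; the mesh shape, tensor size, and axis names below are illustrative and not taken from this PR:

```python
import numpy as np
import torch
import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

# Enable SPMD mode and build a 1-D mesh over all addressable devices.
xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

# Shard dim 0 across the 'data' axis; dim 1 stays replicated.
t = torch.randn(16, 128).to(xm.xla_device())
xs.mark_sharding(t, mesh, ('data', None))

# The attached sharding spec can be inspected via an internal helper used in
# torch_xla's own tests; this is the spec the fallback path has to drop when
# the replacement CPU tensor has a different shape.
print(torch_xla._XLAC._get_xla_sharding_spec(t))
```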

@jonb377 one concern I have is that the test only covers the single-host case. In an actual multi-host setup, wouldn't the CPU tensor have a different (sharded) shape than the one the sharding spec describes? I am not sure if we have such a test somewhere. Even if we clear the sharding spec, after a torch_xla.sync() the tensor will be moved back to the device, but most likely replicated. I am a bit worried that I am breaking distributed checkpointing here.
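
For reference, the checkpointing flow the re-enabled test exercises roughly follows torch_xla's documented SPMD save/load path with torch.distributed.checkpoint; the checkpoint directory and state dict key below are placeholders:

```python
import torch.distributed.checkpoint as dist_cp
import torch_xla.experimental.distributed_checkpoint as xc

CHECKPOINT_DIR = '/tmp/spmd_ckpt'  # placeholder path
state_dict = {'weight': t}         # 't' is the sharded XLATensor from the sketch above

# Save: the SPMD planner resolves each shard to CPU data before writing.
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    planner=xc.SPMDSavePlanner(),
)

# Load: the restored CPU data is swapped back into the existing XLATensor.
# This is the path where the sharding spec should be preserved rather than
# cleared. A real multi-host run may also need a CPU (gloo) process group
# for coordination.
dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    planner=xc.SPMDLoadPlanner(),
)
```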

JackCaoG added the tpuci label on Nov 27, 2024
JackCaoG marked this pull request as ready for review on Nov 28, 2024
JackCaoG merged commit 591c397 into master on Dec 2, 2024 (12 checks passed)