Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix StridedMemoryView by deferring the check for whether a capsule is versioned #292

Merged
merged 2 commits into from
Dec 13, 2024

Conversation

leofang
Copy link
Member

@leofang leofang commented Dec 12, 2024

Close #285.

xref: #285 (comment)

@leofang leofang added bug Something isn't working P0 High priority - Must do! cuda.core Everything related to the cuda.core module labels Dec 12, 2024
@leofang leofang added this to the cuda.core beta 2 milestone Dec 12, 2024
@leofang leofang self-assigned this Dec 12, 2024
Copy link

copy-pr-bot bot commented Dec 12, 2024

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@leofang
Copy link
Member Author

leofang commented Dec 12, 2024

@yangcal could you test this instead?

@leofang
Copy link
Member Author

leofang commented Dec 12, 2024

/ok to test

@leofang
Copy link
Member Author

leofang commented Dec 13, 2024

/ok to test

Copy link
Contributor

@ksimpson-work ksimpson-work left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved and locally verified

import jax.numpy as jnp
import jax
from cuda.core.experimental.utils import args_viewable_as_strided_memory
from cuda.core.experimental import Device

@args_viewable_as_strided_memory((0,))
def parse_tensor(arr):
    dev = Device(0)
    dev.set_current()
    stream = dev.create_stream()
    view = arr.view(stream.handle)

arr = jnp.array([1, 2, 3], device = jax.devices("cuda")[0])
parse_tensor(arr)

@leofang
Copy link
Member Author

leofang commented Dec 13, 2024

Thanks, Keenan! Let's merge!

@leofang leofang merged commit ddc1f94 into NVIDIA:main Dec 13, 2024
30 checks passed
@leofang leofang deleted the fix_view branch December 13, 2024 21:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working cuda.core Everything related to the cuda.core module P0 High priority - Must do!
Projects
None yet
Development

Successfully merging this pull request may close these issues.

StridedMemoryView fails with Jax arrays
2 participants