Passing cuda device as an argument to Train, Predict nodes #188

Closed
Mohinta2892 opened this issue Apr 17, 2023 · 3 comments

@Mohinta2892

It would be helpful if we could pass a CUDA device explicitly to the Train and Predict nodes on DGX systems, so that multiple models can run simultaneously on different cards of a single machine.
Currently both scripts default to "cuda:0" whenever a CUDA device is available.

I have added this in a fork of the repo and wanted to discuss whether it would be useful to integrate here.

Best,
Samia
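
To make the request concrete, here is a minimal sketch of how an explicit device string could be resolved. The helper `resolve_device` and its parameters are hypothetical and not part of gunpowder; `cuda_available` and `device_count` stand in for what `torch.cuda.is_available()` and `torch.cuda.device_count()` would report, so the sketch runs without a GPU.

```python
def resolve_device(requested, cuda_available, device_count):
    """Return the device string a node should run on.

    Hypothetical helper (an assumption, not gunpowder's API).
    `requested` is a string such as "cpu", "cuda" or "cuda:2".
    """
    kind, sep, index = requested.partition(":")
    if kind not in ("cpu", "cuda") or (sep and not index.isdigit()):
        raise ValueError(f"unrecognised device string: {requested!r}")

    # Fall back to the CPU when it is requested or when no GPU is usable.
    if kind == "cpu" or not cuda_available:
        return "cpu"

    # A bare "cuda" keeps the current behaviour of using the first card.
    card = int(index) if sep else 0
    if card >= device_count:
        raise ValueError(f"{requested!r} requested, but only {device_count} card(s) present")
    return f"cuda:{card}"


# One model per card on a four-GPU machine.
devices = [resolve_device(f"cuda:{i}", True, 4) for i in range(4)]
```

With something like this, each Train or Predict node would receive its own device string instead of always landing on the first card.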

@pattonw
Collaborator

pattonw commented Oct 6, 2023

Yes, this would be quite helpful to integrate into gunpowder directly

@Mohinta2892
Author

Thanks for replying, Will!
I shall raise a pull request soon-ish.

Mohinta2892 added a commit to Mohinta2892/gunpowder that referenced this issue Nov 2, 2023
Allow passing cuda device to Predict. Issue funkelab#188
Mohinta2892 added a commit to Mohinta2892/gunpowder that referenced this issue Nov 2, 2023
pattonw added a commit that referenced this issue Dec 19, 2023
commit 1686b949766b76960534ede1105751591fd91c9f
Author: William Patton <[email protected]>
Date:   Tue Dec 19 08:43:11 2023 -0700

    black reformatting

commit 26d2c7cfff3f2702f56a5bb4249a0811f54b45ef
Author: Mohinta2892 <[email protected]>
Date:   Thu Nov 2 19:09:15 2023 +0000

    Revert "black reformatted"

    This reverts commit 66dd69b.

    Only format changed files, since black does not consider formatting history

commit a273fd3813fc16b516c2438ad5af0c4ee3f0686b
Author: Samia Mohinta <[email protected]>
Date:   Thu Nov 2 17:12:26 2023 +0000

    black reformatted

commit bb37769eec33af5921386f283e2579055bb34e6d
Author: Samia Mohinta <[email protected]>
Date:   Thu Nov 2 16:40:32 2023 +0000

    add device arg

    Allow passing cuda device to Predict. Issue #188

commit a3b3588a1406d609ae95370cf2c5339872616011
Author: Samia Mohinta <[email protected]>
Date:   Thu Nov 2 16:39:09 2023 +0000

    add device arg

    allow passing cuda device to Train
@pattonw
Collaborator

pattonw commented Jan 3, 2024

solved in a pull request

@pattonw pattonw closed this as completed Jan 3, 2024
pattonw added a commit that referenced this issue Aug 30, 2024
changelog:
features:

add probability for applying augmentation nodes via an overridable can_skip method.
errors now print in reversed order
improve CSVReader to use the built-in Python csv reader
torch predict supports arrays as positional arguments
add funlib.persistence.Array source
add ScanCallback
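
As an illustration of the first feature, the following sketch shows one way a per-node application probability with an overridable skip check might look. The classes `MaybeSkipNode` and `Flip`, the parameter `p`, and the `process` method are assumptions made for this example; they are not gunpowder's actual classes or signatures.

```python
import random


class MaybeSkipNode:
    """Hypothetical stand-in for an augmentation node (not gunpowder's class)."""

    def __init__(self, p=1.0, rng=None):
        if not 0.0 <= p <= 1.0:
            raise ValueError("p must lie in [0, 1]")
        self.p = p
        self.rng = rng if rng is not None else random.Random()

    def can_skip(self):
        # random() is in [0, 1), so p=1.0 never skips and p=0.0 always skips.
        # Subclasses may override this to add their own skip conditions.
        return self.rng.random() >= self.p

    def augment(self, values):
        raise NotImplementedError

    def process(self, values):
        # Pass the data through unchanged when the node is skipped.
        if self.can_skip():
            return list(values)
        return self.augment(values)


class Flip(MaybeSkipNode):
    """Toy augmentation: reverse the sequence."""

    def augment(self, values):
        return list(reversed(values))


always = Flip(p=1.0).process([1, 2, 3])
never = Flip(p=0.0).process([1, 2, 3])
```

Keeping the decision in a single method means a subclass can combine the random draw with its own conditions without touching the processing logic.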
bugfixes:

pytorch train: hooks are now registered in the start method. Registering them earlier caused problems when trying to run the model in some multithreaded use cases.
Deform Augment subsampling fixed
avoid np.sctypes["float"] to work on numpy >= 2.0
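
The hook bugfix can be illustrated without torch. The commit notes in this thread attribute the problem to local functions not being pickle-able, which matters for the "spawn" start method. The sketch below uses two hypothetical classes, `EagerNode` and `LazyNode` (assumptions for illustration, not gunpowder code), and checks whether each node's state survives pickling.

```python
import pickle


class EagerNode:
    """Creates its hook in __init__, so the state holds a local function."""

    def __init__(self):
        def hook(value):
            return value * 2

        self.hook = hook


class LazyNode:
    """Defers hook creation to start(), keeping the initial state plain."""

    def __init__(self):
        self.hook = None

    def start(self):
        def hook(value):
            return value * 2

        self.hook = hook


def state_is_picklable(node):
    """Return True if the node's attribute dict can be pickled."""
    try:
        pickle.dumps(vars(node))
        return True
    except (pickle.PicklingError, AttributeError):
        # Local functions cannot be pickled by reference.
        return False


eager_ok = state_is_picklable(EagerNode())
lazy = LazyNode()
lazy_ok_before_start = state_is_picklable(lazy)
lazy.start()  # in practice this would run inside the worker process
```

The lazy variant can be handed to a new process first and only builds its hooks once it is running there.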


* remove duplicated for loop

* increment patch number

* ArraySpec docs

fix documentation to be more accurate around nonspatial arrays

* ArraySpec bug fix:

allow None roi/voxel size for spatial arrays

* Add probability option to aug nodes

* Revert can_skip to private method

* fix the deform augment test

no longer assumes a deformed label will still exist in an array

* better bounds on required packages

* ignore missing imports from packages that don't provide type hints

* fix typehint mistakes

* format pyproject.toml

* black format

* move register hooks to the start method

This is to get around local functions (i.e. the hooks) not being
pickle-able which we need for the "spawn" start function
(spawn is the default on windows and recent macs)

* fix typo

* support non-spatial arrays in ArraySource

* overhaul torch tests

* remove multiprocess set start method monkey patch

We want to test with both fork and spawn start methods, but this
seems to interfere with the torch tests

* only deploy docs on tagged commits to main

* minor black formatting and configuration changes

* properly skip torch tests if torch not installed

* black formatting

* avoid testing on python 3.7, instead use 3.11

numpy is no longer releasing updates for python 3.7, they are on 1.24
but the last release for 3.7 was 1.21.
I don't think we need to support it either, but we should test on 3.11

* add typed libraries to dev dependencies

* test subsampling in deform augment

test fails

* fix bugs associated with subsampling

* deform augment

fix bug with checking dims of graph_raster_voxel_size

* Add progress callback to Scan node

* pass torch train test

if using start method = "spawn" and the "start_subprocess" flag
for the predict node, we now pass our test.

* pass torch train test

if using the start method "spawn", and the "spawn_subprocess" flag for
the train node, we now pass our test

* remove extra error printing

* switch error printing order

Now prints the errors in reverse order of execution so the initial pipeline error is printed first

* black format docs and examples

* Squashed commit of the following:

commit 1686b949766b76960534ede1105751591fd91c9f
Author: William Patton <[email protected]>
Date:   Tue Dec 19 08:43:11 2023 -0700

    black reformatting

commit 26d2c7cfff3f2702f56a5bb4249a0811f54b45ef
Author: Mohinta2892 <[email protected]>
Date:   Thu Nov 2 19:09:15 2023 +0000

    Revert "black reformatted"

    This reverts commit 66dd69b.

    Only format changed files, since black does not consider formatting history

commit a273fd3813fc16b516c2438ad5af0c4ee3f0686b
Author: Samia Mohinta <[email protected]>
Date:   Thu Nov 2 17:12:26 2023 +0000

    black reformatted

commit bb37769eec33af5921386f283e2579055bb34e6d
Author: Samia Mohinta <[email protected]>
Date:   Thu Nov 2 16:40:32 2023 +0000

    add device arg

    Allow passing cuda device to Predict. Issue #188

commit a3b3588a1406d609ae95370cf2c5339872616011
Author: Samia Mohinta <[email protected]>
Date:   Thu Nov 2 16:39:09 2023 +0000

    add device arg

    allow passing cuda device to Train

* parameterize tests for cuda devices

currently failing a few of them, some are expected failures.

* Added support for reflect padding

Squashed commit of the following:

commit 0fb29c8
Author: William Patton <[email protected]>
Date:   Tue Jan 2 08:54:17 2024 -0800

    replace custom padding code with np.pad

commit c6928bd
Author: William Patton <[email protected]>
Date:   Tue Jan 2 08:54:06 2024 -0800

    simplify/expand padding test

    test padding on both sides

commit 3782525
Author: William Patton <[email protected]>
Date:   Tue Dec 19 11:30:31 2023 -0700

    pass the fixed tests

commit a7027c6
Author: William Patton <[email protected]>
Date:   Tue Dec 19 10:37:48 2023 -0700

    fix the test case

commit 531d81d
Author: William Patton <[email protected]>
Date:   Tue Dec 19 10:06:44 2023 -0700

    update the pad tests

    parametrized the use of constant or reflect padding.

    Now avoids using the unittest framework

commit 443c666
Author: Manan Lalit <[email protected]>
Date:   Fri Nov 3 00:09:33 2023 -0400

    Replace .ndim by len()

commit a7503d7
Author: lmanan <[email protected]>
Date:   Thu Nov 2 11:52:27 2023 -0400

    Update pad.py to include reflective padding

* Fix bug in rasterize graph

we were using `graph.data.items()` to iterate over nodes instead of `graph.nodes`

Squashed commit of the following:

commit d027f5a260a1e2a9cf851efca85b7318434675d6
Author: William Patton <[email protected]>
Date:   Tue Jan 2 09:44:21 2024 -0800

    refactor rasterize_points test to use pytest

commit eadb0476d8475b55120486df6cf30f95b6df86f4
Author: William Patton <[email protected]>
Date:   Tue Jan 2 09:25:11 2024 -0800

    remove extra roi handling

    The node only needs to request the data it needs for its
    own operations.
    If you request a mask for a set of points that extend outside
    the bounds of your mask you will get an error

commit 29507f1f21d69cf76e34e7b0f05cd780100fd68b
Author: William Patton <[email protected]>
Date:   Tue Jan 2 09:22:21 2024 -0800

    remove type cast

    we do a bitwise operation during the `__rasterize` call, which
    fails if you change the dtype

commit 96e93e53ce0bc8240357259dab92f1ca64a08199
Author: William Patton <[email protected]>
Date:   Tue Jan 2 09:21:16 2024 -0800

    remove matplotlib

commit eb2977a187a1cad95da54a515c84ce44d73b8315
Author: Samia Mohinta <[email protected]>
Date:   Thu Dec 14 15:27:41 2023 +0000

    fix mask intersection with request

    outputs must match request rois when a mask is provided

commit 682189dac2ef6b94876bd30df813717da6530060
Author: Samia Mohinta <[email protected]>
Date:   Thu Dec 14 15:25:12 2023 +0000

    Update rasterize_graph.py

commit e36dcf179ccd1aec6a5cafd31e7a9a858352faa1
Author: Mohinta2892 <[email protected]>
Date:   Thu Nov 2 19:18:00 2023 +0000

    reformat rasterize_graph and rasterize_points

commit 42da2702e746f702d3d07144ab9fc1d4352b0c0d
Author: Samia Mohinta <[email protected]>
Date:   Thu Nov 2 14:23:01 2023 +0000

    Test for issue #193

    Test added to pass mask to `RasterizeGraph()` via `RasterizationSettings`.

commit b17cfad413f5ad7f48045a2167ec20d89674d939
Author: Samia Mohinta <[email protected]>
Date:   Thu Nov 2 14:19:42 2023 +0000

    fix for issue #193

    lines 224-226: replace graph.data.items() with graph.nodes
    lines 255-257: explicitly cast the boolean mask data to the original dtype of mask_array

* ruff: remove unused imports and fix small typos.

* Custom BatchRequestError handling in pipeline.request_batch

We can filter out some more of the excess error traceback that isn't helpful to the readers.

* black formatting

* mypy workflow use dev dependencies

* avoid testing on python 3.8, it doesn't support typing very well

* fix type hint for logdir in torch train

* switch order of decorators to avoid trying to determine if cuda is available if torch isn't installed

* check if torch is installed before checking if cuda is available

* update funlib.geometry version for mypy typing

* remove batch.id

replaced in tensorflow predict node debug statements with the request. This better indicates
the roi being predicted on.
replaced in snapshot node with an internal counter

* Provide the separator to the csv points source

* Use csv reader in csv points source

* Test csv points source with new dev dependencies

* Update required python to 3.9

* Black test cases

* Fix typos in pytest unordered dependency

* Remove pytest unordered dependency

* Black and ruff CSVPointsSource tests

* Correctly read and document ids in CsvPointsSource

* Automatically detect header in CSVPointsSource

* Test all CSVPointsSource functionality

* add support for args as inputs to predict.py

It's often not so straightforward to know the keyword argument name for the forward function of your model, especially if you use something like `torch.nn.Sequential`.

* black reformat pad.py test

* remove excessive seed setting. I don't think this is necessary, since as soon as the seeds are set the rest of the tests are deterministic

* Pytorch Train: let users specify model inputs as args instead of kwargs

* PyTorch Train: add tests for using arg indexes for model inputs

* depend on overhauled funlib.persistence

* add funlib.persistence array source

* black formatting fix

* Add basic `ArraySource` node that accepts any `funlib.persistence.Array`

* add ArraySource to docs

* fix dtype checking for float types for numpy >= 2.0

* add documentation for gradients argument of torch `Train` node

* add typehint for dict

* black reformatting

* add support for python 3.12

* remove distutils

* black formatting
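
The positional-input support mentioned in the list above (model inputs given as args rather than kwargs) can be sketched as follows. This assumes an `inputs` mapping whose integer keys mean positional slots and whose string keys mean keyword names; the helper `split_inputs` and the plain-dict `batch` are hypothetical and only illustrate the idea.

```python
def split_inputs(inputs, batch):
    """Build (args, kwargs) for a model call.

    Hypothetical helper: `inputs` maps an int position or a str keyword
    name to a key in `batch`; `batch` is a plain dict of arrays here.
    """
    positional = {k: v for k, v in inputs.items() if isinstance(k, int)}
    keyword = {k: v for k, v in inputs.items() if isinstance(k, str)}

    # Positions must be 0, 1, 2, ... with no gaps.
    expected = list(range(len(positional)))
    if sorted(positional) != expected:
        raise ValueError(f"positional inputs must be numbered {expected}")

    args = [batch[positional[i]] for i in expected]
    kwargs = {name: batch[key] for name, key in keyword.items()}
    return args, kwargs


def forward(x, scale=1):
    """Toy model whose first argument name the caller need not know."""
    return [v * scale for v in x]


args, kwargs = split_inputs({0: "raw", "scale": "factor"}, {"raw": [1, 2], "factor": 3})
result = forward(*args, **kwargs)
```

Integer keys avoid having to look up the parameter name of a wrapped model's forward function.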

---------

Co-authored-by: sheridana <[email protected]>
Co-authored-by: Jan Funke <[email protected]>
Co-authored-by: Caroline Malin-Mayor <[email protected]>